import json
import os
import re
import time

import requests
from flask import Flask, jsonify, Response, request
import pandas as pd

# Flask configuration
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
# os.environ["WANDB_DISABLED"] = "true"

# Pin the CUDA device
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import logging
import random
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import evaluate
import numpy as np
import torch
import transformers
from datasets import load_dataset
from tqdm import tqdm
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

'''
Request format:
{
    "content": "<the paper's full text, one sentence per line>"
}

Response format:
{
    "code": 200,
    "paper-lable": [
        {
            "index": 0,
            "sentence": "我参加的是17组的小组学习,主题是关于日本方言。我主要负责参",
            "lable": "正文"
        },
        {
            "index": 1,
            "sentence": "1.2.1 小组学习",
            "lable": "三级标题"
        },
        ...
    ]
}
'''
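
# A minimal client sketch (assumes the service is reachable on the host/port
# configured in app.run at the bottom of this file; the sample content is
# illustrative):
#
# import requests
# resp = requests.post(
#     "http://127.0.0.1:28100/predict",
#     json={"content": "摘要\n本文研究……\n1.1 研究背景"},
# )
# print(resp.json())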

# Pick the GPU if one is available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Coarse labels: 标题 (heading), 正文 (body text), 无用类别 (irrelevant)
lable_2_id_fenji = {
    "标题": 0,
    "正文": 1,
    "无用类别": 2
}
id_2_lable_fenji = {v: k for k, v in lable_2_id_fenji.items()}

# Heading labels: levels one to three plus the special headings
lable_2_id_title = {
    "一级标题": 0,
    "二级标题": 1,
    "三级标题": 2,
    "中文摘要标题": 3,
    "致谢标题": 4,
    "英文摘要标题": 5,
    "参考文献标题": 6
}
id_2_lable_title = {v: k for k, v in lable_2_id_title.items()}

# Fine-grained body labels: main text, abstracts, keywords, figures, tables, references
lable_2_id_content = {
    "正文": 0,
    "英文摘要": 1,
    "中文摘要": 2,
    "中文关键词": 3,
    "英文关键词": 4,
    "图": 5,
    "表": 6,
    "参考文献": 7
}
id_2_lable_content = {v: k for k, v in lable_2_id_content.items()}

# One tokenizer shared by all six checkpoints
tokenizer = AutoTokenizer.from_pretrained(
    "data_zong_roberta",
    use_fast=True,
    revision="main",
    trust_remote_code=False,
)

model_name = "data_zong_roberta"
|
||
|
|
config = AutoConfig.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
num_labels=len(lable_2_id_fenji),
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False
|
||
|
|
)
|
||
|
|
model_roberta_zong = AutoModelForSequenceClassification.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
config=config,
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False,
|
||
|
|
ignore_mismatched_sizes=False,
|
||
|
|
).to(device)
|
||
|
|
|
||
|
|
model_name = "data_zong_roberta_no_start"
|
||
|
|
config = AutoConfig.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
num_labels=len(lable_2_id_fenji),
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False
|
||
|
|
)
|
||
|
|
model_roberta_zong_no_start = AutoModelForSequenceClassification.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
config=config,
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False,
|
||
|
|
ignore_mismatched_sizes=False,
|
||
|
|
).to(device)
|
||
|
|
|
||
|
|
model_name = "data_zong_roberta_no_end"
|
||
|
|
config = AutoConfig.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
num_labels=len(lable_2_id_fenji),
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False
|
||
|
|
)
|
||
|
|
model_roberta_zong_no_end = AutoModelForSequenceClassification.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
config=config,
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False,
|
||
|
|
ignore_mismatched_sizes=False,
|
||
|
|
).to(device)
|
||
|
|
|
||
|
|
model_name = "data_title_roberta"
|
||
|
|
config = AutoConfig.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
num_labels=len(lable_2_id_title),
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False
|
||
|
|
)
|
||
|
|
model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
config=config,
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False,
|
||
|
|
ignore_mismatched_sizes=False,
|
||
|
|
).to(device)
|
||
|
|
|
||
|
|
model_name = "data_content_roberta"
|
||
|
|
config = AutoConfig.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
num_labels=len(lable_2_id_content),
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False
|
||
|
|
)
|
||
|
|
model_content_roberta = AutoModelForSequenceClassification.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
config=config,
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False,
|
||
|
|
ignore_mismatched_sizes=False,
|
||
|
|
).to(device)
|
||
|
|
|
||
|
|
model_name = "data_content_roberta_no_end"
|
||
|
|
config = AutoConfig.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
num_labels=len(lable_2_id_content),
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False
|
||
|
|
)
|
||
|
|
model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
config=config,
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False,
|
||
|
|
ignore_mismatched_sizes=False,
|
||
|
|
).to(device)
|
||
|
|
|
||
|
|
model_name = "data_content_roberta_no_start"
|
||
|
|
config = AutoConfig.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
num_labels=len(lable_2_id_content),
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False
|
||
|
|
)
|
||
|
|
model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrained(
|
||
|
|
model_name,
|
||
|
|
config=config,
|
||
|
|
revision="main",
|
||
|
|
trust_remote_code=False,
|
||
|
|
ignore_mismatched_sizes=False,
|
||
|
|
).to(device)
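
# The six loading blocks above differ only in checkpoint path and label count.
# A loader sketch (hypothetical helper, not used above) that would remove the
# duplication:
#
# def load_classifier(path, num_labels):
#     cfg = AutoConfig.from_pretrained(path, num_labels=num_labels,
#                                      revision="main", trust_remote_code=False)
#     return AutoModelForSequenceClassification.from_pretrained(
#         path, config=cfg, revision="main", trust_remote_code=False,
#         ignore_mismatched_sizes=False,
#     ).to(device)
#
# model_roberta_zong = load_classifier("data_zong_roberta", len(lable_2_id_fenji))

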
def gen_zong_cls(content_list):
    # Coarse pass: label every sentence as 标题 (heading), 正文 (body text) or
    # 无用类别 (irrelevant). content_list covers the whole document, so `index`
    # doubles as the position in the list.
    paper_quanwen_lable_list = []
    for index, paper_sen in content_list:
        paper_object_dangqian = "<Start>" + paper_sen + "<End>"

        # View 1: 7 sentences of context on each side (first 30 chars of each)
        paper_start_list = [sen[:30] for _, sen in content_list[max(index - 7, 0):index]]
        paper_end_list = [sen[:30] for _, sen in content_list[index + 1:index + 8]]
        paper_new_start = "\n".join(paper_start_list)
        paper_new_end = "\n".join(paper_end_list)
        paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])

        # View 2: the 15 preceding sentences only
        paper_start_list = [sen[:30] for _, sen in content_list[max(index - 15, 0):index]]
        paper_new_start = "\n".join(paper_start_list)
        paper_qian = "\n".join([paper_new_start, paper_object_dangqian])

        # View 3: the 15 following sentences only
        paper_end_list = [sen[:30] for _, sen in content_list[index + 1:index + 16]]
        paper_new_end = "\n".join(paper_end_list)
        paper_hou = "\n".join([paper_object_dangqian, paper_new_end])

        # Prediction with the target sentence in the middle of its context
        result = tokenizer([paper_zhong], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
        result_on_device = {key: value.to(device) for key, value in result.items()}
        with torch.no_grad():
            logits = model_roberta_zong(**result_on_device)
        predicted_class_idx_zhong = torch.argmax(logits[0], dim=1).item()

        # Prediction with preceding context only
        result = tokenizer([paper_qian], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
        result_on_device = {key: value.to(device) for key, value in result.items()}
        with torch.no_grad():
            logits = model_roberta_zong_no_end(**result_on_device)
        predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item()

        # Prediction with following context only
        result = tokenizer([paper_hou], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
        result_on_device = {key: value.to(device) for key, value in result.items()}
        with torch.no_grad():
            logits = model_roberta_zong_no_start(**result_on_device)
        predicted_class_idx_hou = torch.argmax(logits[0], dim=1).item()

        # Majority vote over the three predictions; fall back to class 0
        # (标题) when all three disagree
        id_2_len = {}
        for i in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]:
            id_2_len[i] = id_2_len.get(i, 0) + 1

        predicted_class_idx = 0
        for i in id_2_len:
            if id_2_len[i] >= 2:
                predicted_class_idx = i
                break

        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_fenji[predicted_class_idx]])
    return paper_quanwen_lable_list
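
# Performance note: each sentence above costs three sequential forward passes.
# A batched sketch (hypothetical, not wired into the service) that classifies
# many texts per model call:
#
# def predict_batch(model, texts, batch_size=16):
#     pred_ids = []
#     for start in range(0, len(texts), batch_size):
#         enc = tokenizer(texts[start:start + batch_size], padding="max_length",
#                         max_length=512, truncation=True, return_tensors="pt")
#         enc = {k: v.to(device) for k, v in enc.items()}
#         with torch.no_grad():
#             out = model(**enc)
#         pred_ids.extend(torch.argmax(out.logits, dim=1).tolist())
#     return pred_ids

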
def gen_title_cls(content_list):
    # Heading pass: assign each heading sentence a level. Slice by the position
    # `pos` in this filtered list; `index` is the sentence's position in the
    # full document and would point at the wrong neighbours here.
    paper_quanwen_lable_list = []
    for pos, (index, paper_sen) in enumerate(content_list):
        # Context: every other heading (first 30 chars of each)
        paper_start_list = [sen[:30] for _, sen in content_list[0:pos]]
        paper_end_list = [sen[:30] for _, sen in content_list[pos + 1:]]
        paper_new_start = "\n".join(paper_start_list)
        paper_new_end = "\n".join(paper_end_list)
        paper_object_dangqian = "<Start>" + paper_sen + "<End>"
        paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])
        paper_zhong = paper_zhong.strip("\n")

        # If the text exceeds a rough 510-character budget (for the model's
        # 512-token limit), grow a window outwards from the target line one
        # neighbour per side at a time and keep the last version that fits.
        if len(paper_zhong) > 510:
            data_paper_list = str(paper_zhong).split("\n")
            start_index = 0
            for i in range(len(data_paper_list)):
                if "<Start>" in data_paper_list[i]:
                    start_index = i
                    break
            left_end = 0
            right_end = len(data_paper_list) - 1
            left = start_index
            right = start_index
            left_end_bool = True
            right_end_bool = True
            old_sen = data_paper_list[start_index]
            while True:
                if left - 1 >= left_end:
                    left = left - 1
                else:
                    left_end_bool = False
                if right + 1 <= right_end:
                    right = right + 1
                else:
                    right_end_bool = False

                new_sen_list = [old_sen]
                if left_end_bool:
                    new_sen_list = [data_paper_list[left]] + new_sen_list
                if right_end_bool:
                    new_sen_list = new_sen_list + [data_paper_list[right]]

                new_sen = "\n".join(new_sen_list)
                if len(new_sen) > 510:
                    break
                old_sen = new_sen
            paper_zhong = old_sen

        # Single prediction with the target heading in the middle of its context
        result = tokenizer([paper_zhong], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
        result_on_device = {key: value.to(device) for key, value in result.items()}
        with torch.no_grad():
            logits = model_title_roberta(**result_on_device)
        predicted_class_idx = torch.argmax(logits[0], dim=1).item()
        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_title[predicted_class_idx]])

    return paper_quanwen_lable_list


def gen_content_cls(content_list):
    # Body-text pass: assign each body sentence a fine-grained label. As in
    # gen_title_cls, slice by the position `pos` in this filtered list, not by
    # the document index.
    paper_quanwen_lable_list = []
    for pos, (index, paper_sen) in enumerate(content_list):
        paper_object_dangqian = "<Start>" + paper_sen + "<End>"

        # View 1: 7 sentences of context on each side (first 30 chars of each)
        paper_start_list = [sen[:30] for _, sen in content_list[max(pos - 7, 0):pos]]
        paper_end_list = [sen[:30] for _, sen in content_list[pos + 1:pos + 8]]
        paper_new_start = "\n".join(paper_start_list)
        paper_new_end = "\n".join(paper_end_list)
        paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])

        # View 2: the 15 preceding sentences only
        paper_start_list = [sen[:30] for _, sen in content_list[max(pos - 15, 0):pos]]
        paper_new_start = "\n".join(paper_start_list)
        paper_qian = "\n".join([paper_new_start, paper_object_dangqian])

        # View 3: the 15 following sentences only
        paper_end_list = [sen[:30] for _, sen in content_list[pos + 1:pos + 16]]
        paper_new_end = "\n".join(paper_end_list)
        paper_hou = "\n".join([paper_object_dangqian, paper_new_end])

        # Prediction with the target sentence in the middle of its context
        result = tokenizer([paper_zhong], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
        result_on_device = {key: value.to(device) for key, value in result.items()}
        with torch.no_grad():
            logits = model_content_roberta(**result_on_device)
        predicted_class_idx_zhong = torch.argmax(logits[0], dim=1).item()

        # Prediction with preceding context only
        result = tokenizer([paper_qian], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
        result_on_device = {key: value.to(device) for key, value in result.items()}
        with torch.no_grad():
            logits = model_content_roberta_no_end(**result_on_device)
        predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item()

        # Prediction with following context only
        result = tokenizer([paper_hou], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
        result_on_device = {key: value.to(device) for key, value in result.items()}
        with torch.no_grad():
            logits = model_content_roberta_no_start(**result_on_device)
        predicted_class_idx_hou = torch.argmax(logits[0], dim=1).item()

        # Majority vote; fall back to class 0 (正文) when all three disagree
        id_2_len = {}
        for i in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]:
            id_2_len[i] = id_2_len.get(i, 0) + 1

        predicted_class_idx = 0
        for i in id_2_len:
            if id_2_len[i] >= 2:
                predicted_class_idx = i
                break

        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_content[predicted_class_idx]])
    return paper_quanwen_lable_list
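
# The majority-vote block appears verbatim in gen_zong_cls and gen_content_cls.
# A shared helper sketch (hypothetical, not called above):
#
# def vote(pred_ids, default=0):
#     """Return the class id predicted by at least two models, else `default`."""
#     for i in set(pred_ids):
#         if pred_ids.count(i) >= 2:
#             return i
#     return default

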
def main(content: str):
    # Reshape the input into the format the models expect: [index, sentence]
    paper_content_list = [[i, j] for i, j in enumerate(content.split("\n"))]

    # Pass 1: classify every sentence as heading, body text or irrelevant
    zong_list = gen_zong_cls(paper_content_list)

    # Separate headings from body text; irrelevant sentences are dropped
    title_data = []
    content_data = []
    for data_dan in zong_list:
        if data_dan[2] == "标题":
            title_data.append([data_dan[0], data_dan[1]])
        if data_dan[2] == "正文":
            content_data.append([data_dan[0], data_dan[1]])

    # Pass 2a: assign each heading its level
    title_list = gen_title_cls(title_data)

    # Pass 2b: assign each body sentence its fine-grained label
    content_list = gen_content_cls(content_data)

    # Merge the two passes and restore document order
    paper_content_list_new = sorted(title_list + content_list, key=lambda item: item[0])

    paper_content_info_list = []
    for data_dan_info in paper_content_list_new:
        paper_content_info_list.append({
            "index": data_dan_info[0],
            "sentence": data_dan_info[1],
            "lable": data_dan_info[2]
        })

    return paper_content_info_list
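
# Direct (non-HTTP) usage sketch — the keys below match the dicts built above;
# the sample sentences are illustrative:
#
# for row in main("第一章 绪论\n本文介绍小组学习的背景。"):
#     print(row["index"], row["lable"], row["sentence"])

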
@app.route("/predict", methods=["POST"])
def search():
    print(request.remote_addr)
    content = request.json["content"]
    response = main(content)
    return jsonify(response)  # return the labelled sentences


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=28100, threaded=True, debug=False)