2 changed files with 479 additions and 0 deletions
@@ -0,0 +1,478 @@
import os

# Pin the visible CUDA device before torch is imported, otherwise the setting
# may be ignored once CUDA has been initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

from collections import Counter

import torch
from flask import Flask, jsonify, request
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Flask configuration
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
# os.environ["WANDB_DISABLED"] = "true"

'''
Request format:
{
    "content": "full text of the paper, one sentence/segment per line"
}

Response format (a JSON array, one entry per classified sentence):
[
    {
        "index": 0,
        "sentence": "我参加的是17组的小组学习,主题是关于日本方言。我主要负责参",
        "lable": "正文"
    },
    {
        "index": 1,
        "sentence": "1.2.1 小组学习",
        "lable": "三级标题"
    },
    ...
]
'''

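# Illustrative client call (assumes the service is reachable on localhost:28100,
# matching app.run() at the bottom of this file):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:28100/predict",
#       json={"content": "中文摘要\n1.2.1 小组学习\n我参加的是17组的小组学习"},
#   )
#   print(resp.json())
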
# Check whether a GPU is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Coarse first-pass labels: heading / body / useless.
lable_2_id_fenji = {
    "标题": 0,
    "正文": 1,
    "无用类别": 2
}
id_2_lable_fenji = {v: k for k, v in lable_2_id_fenji.items()}

# Heading-level labels for the second pass.
lable_2_id_title = {
    "一级标题": 0,
    "二级标题": 1,
    "三级标题": 2,
    "中文摘要标题": 3,
    "致谢标题": 4,
    "英文摘要标题": 5,
    "参考文献标题": 6
}
id_2_lable_title = {v: k for k, v in lable_2_id_title.items()}

# Fine-grained body-text labels for the third pass.
lable_2_id_content = {
    "正文": 0,
    "英文摘要": 1,
    "中文摘要": 2,
    "中文关键词": 3,
    "英文关键词": 4,
    "图": 5,
    "表": 6,
    "参考文献": 7
}
id_2_lable_content = {v: k for k, v in lable_2_id_content.items()}

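# For example, a line such as "1.2.1 小组学习" is expected to be labelled "标题"
# by the first pass and "三级标题" by the title pass (cf. the sample response above).
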
tokenizer = AutoTokenizer.from_pretrained(
    "data_zong_roberta",
    use_fast=True,
    revision="main",
    trust_remote_code=False,
)


def load_classifier(model_path, num_labels):
    """Load a fine-tuned sequence classifier and move it to the device."""
    config = AutoConfig.from_pretrained(
        model_path,
        num_labels=num_labels,
        revision="main",
        trust_remote_code=False,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        config=config,
        revision="main",
        trust_remote_code=False,
        ignore_mismatched_sizes=False,
    ).to(device)
    model.eval()  # inference only: disable dropout
    return model


# Coarse pass: three context views (centred / left-only / right-only).
model_roberta_zong = load_classifier("data_zong_roberta", len(lable_2_id_fenji))
model_roberta_zong_no_start = load_classifier("data_zong_roberta_no_start", len(lable_2_id_fenji))
model_roberta_zong_no_end = load_classifier("data_zong_roberta_no_end", len(lable_2_id_fenji))

# Heading-level pass.
model_title_roberta = load_classifier("data_title_roberta", len(lable_2_id_title))

# Body-text pass: same three context views.
model_content_roberta = load_classifier("data_content_roberta", len(lable_2_id_content))
model_content_roberta_no_end = load_classifier("data_content_roberta_no_end", len(lable_2_id_content))
model_content_roberta_no_start = load_classifier("data_content_roberta_no_start", len(lable_2_id_content))

def predict_label_id(model, text):
    """Tokenize one context string and return the argmax class id."""
    inputs = tokenizer([text], padding="max_length", max_length=512,
                       truncation=True, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():  # inference only: no gradients needed
        logits = model(**inputs).logits
    return torch.argmax(logits, dim=1).item()


def majority_vote(pred_ids, default=0):
    """Return the class id predicted by at least two of the three models,
    e.g. [1, 1, 2] -> 1; if no two models agree, fall back to `default`."""
    label_id, votes = Counter(pred_ids).most_common(1)[0]
    return label_id if votes >= 2 else default


def build_contexts(content_list, pos, paper_sen):
    """Build the three context views for the sentence at list position `pos`:
    centred (up to 7 neighbours on each side), left-only (previous 15) and
    right-only (next 15). Neighbours are truncated to 30 characters and the
    target sentence is wrapped in <Start>...<End> markers."""
    target = "<Start>" + paper_sen + "<End>"
    prev7 = [sen[:30] for _, sen in content_list[max(pos - 7, 0):pos]]
    next7 = [sen[:30] for _, sen in content_list[pos + 1:pos + 8]]
    prev15 = [sen[:30] for _, sen in content_list[max(pos - 15, 0):pos]]
    next15 = [sen[:30] for _, sen in content_list[pos + 1:pos + 16]]
    paper_zhong = "\n".join(prev7 + [target] + next7)
    paper_qian = "\n".join(prev15 + [target])
    paper_hou = "\n".join([target] + next15)
    return paper_zhong, paper_qian, paper_hou


def gen_zong_cls(content_list):
    """First pass: label every sentence as 标题 / 正文 / 无用类别 with a
    two-of-three vote over the three context views."""
    paper_quanwen_lable_list = []
    # Slice neighbours by list position, not by the sentence's document index:
    # the two differ once the list has been filtered.
    for pos, (index, paper_sen) in enumerate(content_list):
        paper_zhong, paper_qian, paper_hou = build_contexts(content_list, pos, paper_sen)
        preds = [
            predict_label_id(model_roberta_zong, paper_zhong),
            predict_label_id(model_roberta_zong_no_end, paper_qian),
            predict_label_id(model_roberta_zong_no_start, paper_hou),
        ]
        predicted_class_idx = majority_vote(preds, default=0)
        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_fenji[predicted_class_idx]])
    return paper_quanwen_lable_list

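# Each entry returned by the passes is [index, sentence, label]; e.g.
# gen_zong_cls([[0, "1.2.1 小组学习"]]) would yield [[0, "1.2.1 小组学习", "标题"]]
# (the label shown is illustrative; the actual value depends on the trained models).
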
def gen_title_cls(content_list):
    """Second pass: assign a heading level to every sentence labelled 标题.
    The context is the whole document (neighbours truncated to 30 characters),
    trimmed to at most 510 characters around the target sentence."""
    paper_quanwen_lable_list = []
    for pos, (index, paper_sen) in enumerate(content_list):
        target = "<Start>" + paper_sen + "<End>"
        before = [sen[:30] for _, sen in content_list[:pos]]
        after = [sen[:30] for _, sen in content_list[pos + 1:]]
        paper_zhong = "\n".join(before + [target] + after).strip("\n")

        if len(paper_zhong) > 510:
            # Too long for the model input (510 characters as a proxy for the
            # 512-token window): grow a window outwards from the target line,
            # one neighbour per side per step, and keep the largest context
            # that still fits.
            lines = paper_zhong.split("\n")
            start_index = next(i for i, line in enumerate(lines) if "<Start>" in line)
            left, right = start_index, start_index
            context = lines[start_index]
            while left > 0 or right < len(lines) - 1:
                pieces = [context]
                if left > 0:
                    left -= 1
                    pieces = [lines[left]] + pieces
                if right < len(lines) - 1:
                    right += 1
                    pieces = pieces + [lines[right]]
                candidate = "\n".join(pieces)
                if len(candidate) > 510:
                    break
                context = candidate
            paper_zhong = context

        predicted_class_idx = predict_label_id(model_title_roberta, paper_zhong)
        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_title[predicted_class_idx]])
    return paper_quanwen_lable_list

def gen_content_cls(content_list):
    """Third pass: assign a fine-grained body category (正文, 摘要, 图, 表, ...)
    to every sentence labelled 正文, using the same two-of-three vote as
    gen_zong_cls."""
    paper_quanwen_lable_list = []
    for pos, (index, paper_sen) in enumerate(content_list):
        paper_zhong, paper_qian, paper_hou = build_contexts(content_list, pos, paper_sen)
        preds = [
            predict_label_id(model_content_roberta, paper_zhong),
            predict_label_id(model_content_roberta_no_end, paper_qian),
            predict_label_id(model_content_roberta_no_start, paper_hou),
        ]
        predicted_class_idx = majority_vote(preds, default=0)
        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_content[predicted_class_idx]])
    return paper_quanwen_lable_list

def main(content: str):
    # Split the raw text into [index, sentence] pairs, one pair per line.
    paper_content_list = [[i, sen] for i, sen in enumerate(content.split("\n"))]

    # Pass 1: coarse label for every sentence (标题 / 正文 / 无用类别).
    zong_list = gen_zong_cls(paper_content_list)

    # Separate headings from body text; 无用类别 sentences are dropped here.
    title_data = []
    content_data = []
    for index, sentence, lable in zong_list:
        if lable == "标题":
            title_data.append([index, sentence])
        elif lable == "正文":
            content_data.append([index, sentence])

    # Pass 2: heading level for every heading.
    title_list = gen_title_cls(title_data)

    # Pass 3: fine-grained category for every body sentence.
    content_list = gen_content_cls(content_data)

    # Merge the two passes and restore the original document order.
    paper_content_list_new = sorted(title_list + content_list, key=lambda item: item[0])

    return [
        {"index": index, "sentence": sentence, "lable": lable}
        for index, sentence, lable in paper_content_list_new
    ]

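# Quick local check without the HTTP layer (input is illustrative):
#   print(main("中文摘要\n本文研究方言学习。\n1.2.1 小组学习\n我参加的是17组的小组学习"))
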
@app.route("/predict", methods=["POST"])
def search():
    print(request.remote_addr)
    content = request.json["content"]
    response = main(content)
    return jsonify(response)  # return the labelled sentences


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=28100, threaded=True, debug=False)
@@ -0,0 +1 @@
nohup python flask_api.py > main.log 2>&1 &
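# Example request once the service is up (payload is illustrative):
# curl -X POST http://localhost:28100/predict \
#      -H 'Content-Type: application/json' \
#      -d '{"content": "中文摘要\n1.2.1 小组学习"}'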