diff --git a/flask_api.py b/flask_api.py new file mode 100644 index 0000000..2636f0b --- /dev/null +++ b/flask_api.py @@ -0,0 +1,478 @@ +import json +import os +import re +import requests +import time +from flask import Flask, jsonify, Response, request +import pandas as pd + + +# flask配置 +app = Flask(__name__) +app.config["JSON_AS_ASCII"] = False +# os.environ["WANDB_DISABLED"] = "true" + +# 设置CUDA设备 +os.environ['CUDA_VISIBLE_DEVICES'] = '2' + +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from typing import Optional + +import datasets +import evaluate +import numpy as np +from datasets import load_dataset +from tqdm import tqdm +import transformers +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, +) +import torch +from tqdm import tqdm + +''' +请求格式: +{ + "content": "论文正文内容" +} + + +输出格式: +{ + "code": 200, + "paper-lable":[ + { + "index": 0, + "sentence" : "我参加的是17组的小组学习,主题是关于日本方言。我主要负责参", + lable: "正文" + }, + { + "index": 1, + "sentence" : "1.2.1 小组学习", + lable: "三级标题" + }, + ... + ] +} +''' + +# 检查GPU是否可用 +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') +print(f"使用设备: {device}") + + +lable_2_id_fenji = { + "标题": 0, + "正文": 1, + "无用类别": 2 +} + +id_2_lable_fenji = {} +for i in lable_2_id_fenji: + if lable_2_id_fenji[i] not in id_2_lable_fenji: + id_2_lable_fenji[lable_2_id_fenji[i]] = i + +lable_2_id_title = { + "一级标题": 0, + "二级标题": 1, + "三级标题": 2, + "中文摘要标题": 3, + "致谢标题": 4, + "英文摘要标题": 5, + "参考文献标题": 6 +} + +id_2_lable_title = {} +for i in lable_2_id_title: + if lable_2_id_title[i] not in id_2_lable_title: + id_2_lable_title[lable_2_id_title[i]] = i + + +lable_2_id_content = { + "正文": 0, + "英文摘要": 1, + "中文摘要": 2, + "中文关键词": 3, + "英文关键词": 4, + "图": 5, + "表": 6, + "参考文献": 7 +} + +id_2_lable_content = {} +for i in lable_2_id_content: + if lable_2_id_content[i] not in id_2_lable_content: + id_2_lable_content[lable_2_id_content[i]] = i + + +tokenizer = AutoTokenizer.from_pretrained( + "data_zong_roberta", + use_fast=True, + revision="main", + trust_remote_code=False, +) + +model_name = "data_zong_roberta" +config = AutoConfig.from_pretrained( + model_name, + num_labels=len(lable_2_id_fenji), + revision="main", + trust_remote_code=False +) +model_roberta_zong = AutoModelForSequenceClassification.from_pretrained( + model_name, + config=config, + revision="main", + trust_remote_code=False, + ignore_mismatched_sizes=False, +).to(device) + +model_name = "data_zong_roberta_no_start" +config = AutoConfig.from_pretrained( + model_name, + num_labels=len(lable_2_id_fenji), + revision="main", + trust_remote_code=False +) +model_roberta_zong_no_start = AutoModelForSequenceClassification.from_pretrained( + model_name, + config=config, + revision="main", + trust_remote_code=False, + ignore_mismatched_sizes=False, +).to(device) + +model_name = "data_zong_roberta_no_end" +config = AutoConfig.from_pretrained( + model_name, + num_labels=len(lable_2_id_fenji), + revision="main", + trust_remote_code=False +) +model_roberta_zong_no_end = AutoModelForSequenceClassification.from_pretrained( + model_name, + config=config, + revision="main", + trust_remote_code=False, + ignore_mismatched_sizes=False, +).to(device) + +model_name = "data_title_roberta" +config = AutoConfig.from_pretrained( + model_name, + num_labels=len(lable_2_id_title), + revision="main", + trust_remote_code=False +) +model_title_roberta = AutoModelForSequenceClassification.from_pretrained( + model_name, + config=config, + revision="main", + trust_remote_code=False, + ignore_mismatched_sizes=False, +).to(device) + +model_name = "data_content_roberta" +config = AutoConfig.from_pretrained( + model_name, + num_labels=len(lable_2_id_content), + revision="main", + trust_remote_code=False +) +model_content_roberta = AutoModelForSequenceClassification.from_pretrained( + model_name, + config=config, + revision="main", + trust_remote_code=False, + ignore_mismatched_sizes=False, +).to(device) + +model_name = "data_content_roberta_no_end" +config = AutoConfig.from_pretrained( + model_name, + num_labels=len(lable_2_id_content), + revision="main", + trust_remote_code=False +) +model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretrained( + model_name, + config=config, + revision="main", + trust_remote_code=False, + ignore_mismatched_sizes=False, +).to(device) + +model_name = "data_content_roberta_no_start" +config = AutoConfig.from_pretrained( + model_name, + num_labels=len(lable_2_id_content), + revision="main", + trust_remote_code=False +) +model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrained( + model_name, + config=config, + revision="main", + trust_remote_code=False, + ignore_mismatched_sizes=False, +).to(device) + +def gen_zong_cls(content_list): + + paper_quanwen_lable_list = [] + for index, paper_sen in content_list: + # 视野前后7句 + paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[max(index - 7, 0):index]] + paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:index + 8]] + # print(len(paper_start_list)) + # print(len(paper_end_list)) + paper_new_start = "\n".join(paper_start_list) + paper_new_end = "\n".join(paper_end_list) + paper_object_dangqian = "" + paper_sen + "" + paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end]) + + # 视野前15句 + paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[max(index - 15, 0):index]] + # print(len(paper_start_list)) + paper_new_start = "\n".join(paper_start_list) + paper_object_dangqian = "" + paper_sen + "" + paper_qian = "\n".join([paper_new_start, paper_object_dangqian]) + + # 视野后15句 + paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:index + 16]] + # print(len(paper_end_list)) + paper_new_end = "\n".join(paper_end_list) + paper_object_dangqian = "" + paper_sen + "" + paper_hou = "\n".join([paper_object_dangqian, paper_new_end]) + + # 目标句子在中间预测结果 + sentence_list = [paper_zhong] + # sentence_list = [data[1][0]] + result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") + result_on_device = {key: value.to(device) for key, value in result.items()} + logits = model_roberta_zong(**result_on_device) + predicted_class_idx_zhong = torch.argmax(logits[0], dim=1).item() + + sentence_list = [paper_qian] + # sentence_list = [data[1][0]] + result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") + result_on_device = {key: value.to(device) for key, value in result.items()} + logits = model_roberta_zong_no_end(**result_on_device) + predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item() + + sentence_list = [paper_hou] + # sentence_list = [data[1][0]] + result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") + result_on_device = {key: value.to(device) for key, value in result.items()} + logits = model_roberta_zong_no_start(**result_on_device) + predicted_class_idx_hou = torch.argmax(logits[0], dim=1).item() + + id_2_len = {} + for i in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]: + if i not in id_2_len: + id_2_len[i] = 1 + else: + id_2_len[i] += 1 + + queding = False + predicted_class_idx = "" + for i in id_2_len: + if id_2_len[i] >= 2: + queding = True + predicted_class_idx = i + break + + if queding == False: + predicted_class_idx = 0 + paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_fenji[predicted_class_idx]]) + return paper_quanwen_lable_list + + +def gen_title_cls(content_list): + paper_quanwen_lable_list = [] + for index, paper_sen in content_list: + + paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[0:index]] + paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:len(content_list)]] + # print(len(paper_start_list)) + # print(len(paper_end_list)) + paper_new_start = "\n".join(paper_start_list) + paper_new_end = "\n".join(paper_end_list) + paper_object_dangqian = "" + paper_sen + "" + paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end]) + paper_zhong = paper_zhong.strip("\n") + + if len(paper_zhong) > 510: + data_paper_list = str(paper_zhong).split("\n") + start_index = 0 + for i in range(len(data_paper_list)): + if "" in data_paper_list[i]: + start_index = i + break + left_end = 0 + right_end = len(data_paper_list) - 1 + left = start_index + right = start_index + left_end_bool = True + right_end_bool = True + old_sen = data_paper_list[start_index] + while True: + if left - 1 >= left_end: + left = left - 1 + else: + left_end_bool = False + if right + 1 <= right_end: + right = right + 1 + else: + right_end_bool = False + + new_sen_list = [old_sen] + if left_end_bool == True: + new_sen_list = [data_paper_list[left]] + new_sen_list + if right_end_bool == True: + new_sen_list = new_sen_list + [data_paper_list[right]] + + new_sen = "\n".join(new_sen_list) + if len(new_sen) > 510: + break + else: + old_sen = new_sen + paper_zhong = old_sen + + # 目标句子在中间预测结果 + sentence_list = [paper_zhong] + # sentence_list = [data[1][0]] + result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") + result_on_device = {key: value.to(device) for key, value in result.items()} + logits = model_title_roberta(**result_on_device) + predicted_class_idx = torch.argmax(logits[0], dim=1).item() + paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_title[predicted_class_idx]]) + + return paper_quanwen_lable_list + + +def gen_content_cls(content_list): + paper_quanwen_lable_list = [] + for index, paper_sen in content_list: + # 视野前后7句 + paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[max(index - 7, 0):index]] + paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:index + 8]] + # print(len(paper_start_list)) + # print(len(paper_end_list)) + paper_new_start = "\n".join(paper_start_list) + paper_new_end = "\n".join(paper_end_list) + paper_object_dangqian = "" + paper_sen + "" + paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end]) + + # 视野前15句 + paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[max(index - 15, 0):index]] + # print(len(paper_start_list)) + paper_new_start = "\n".join(paper_start_list) + paper_object_dangqian = "" + paper_sen + "" + paper_qian = "\n".join([paper_new_start, paper_object_dangqian]) + + # 视野后15句 + paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:index + 16]] + # print(len(paper_end_list)) + paper_new_end = "\n".join(paper_end_list) + paper_object_dangqian = "" + paper_sen + "" + paper_hou = "\n".join([paper_object_dangqian, paper_new_end]) + + # 目标句子在中间预测结果 + sentence_list = [paper_zhong] + # sentence_list = [data[1][0]] + result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") + result_on_device = {key: value.to(device) for key, value in result.items()} + logits = model_content_roberta(**result_on_device) + predicted_class_idx_zhong = torch.argmax(logits[0], dim=1).item() + + sentence_list = [paper_qian] + # sentence_list = [data[1][0]] + result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") + result_on_device = {key: value.to(device) for key, value in result.items()} + logits = model_content_roberta_no_end(**result_on_device) + predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item() + + sentence_list = [paper_hou] + # sentence_list = [data[1][0]] + result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt") + result_on_device = {key: value.to(device) for key, value in result.items()} + logits = model_content_roberta_no_start(**result_on_device) + predicted_class_idx_hou = torch.argmax(logits[0], dim=1).item() + + id_2_len = {} + for i in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]: + if i not in id_2_len: + id_2_len[i] = 1 + else: + id_2_len[i] += 1 + + queding = False + predicted_class_idx = "" + for i in id_2_len: + if id_2_len[i] >= 2: + queding = True + predicted_class_idx = i + break + + if queding == False: + predicted_class_idx = 0 + paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_content[predicted_class_idx]]) + return paper_quanwen_lable_list + + + +def main(content: str): + + # 先整理句子,把句子整理成模型需要的格式 [id, sen, lable] + paper_content_list = [[i,j] for i,j in enumerate(content.split("\n"))] + + # 先逐句把每句话是否是标题,是否是正文,是否是无用类别识别出来, + zong_list = gen_zong_cls(paper_content_list) + + # 把标题数据和正文数据,无用类别数据做区分 + title_data = [] + content_data = [] + + for data_dan in zong_list: + if data_dan[2] == "标题": + title_data.append([data_dan[0], data_dan[1]]) + if data_dan[2] == "正文": + content_data.append([data_dan[0], data_dan[1]]) + # 把所有的标题类型提取出来,对每个标题区分标题级别 + title_list = gen_title_cls(title_data) + + # 把所有的正文类别提取出来,逐个进行打标 + content_list = gen_content_cls(content_data) + + paper_content_list_new = title_list + content_list + # 综合排序 + paper_content_list_new = sorted(paper_content_list_new, key=lambda item: item[0]) + + paper_content_info_list = [] + + for data_dan_info in paper_content_list_new: + paper_content_info_list.append({ + "index": data_dan_info[0], + "sentence": data_dan_info[1], + "lable" : data_dan_info[2] + }) + + return paper_content_info_list + +@app.route("/predict", methods=["POST"]) +def search(): + print(request.remote_addr) + content = request.json["content"] + response = main(content) + return jsonify(response) # 返回结果 + + +if __name__ == "__main__": + app.run(host="0.0.0.0", port=28100, threaded=True, debug=False) diff --git a/run_api.sh b/run_api.sh new file mode 100644 index 0000000..b6b6a23 --- /dev/null +++ b/run_api.sh @@ -0,0 +1 @@ +nohup python flask_api.py > main.log 2>&1 &