"""Flask service that labels every sentence of a thesis with its structural
role (title level, body category, abstract, references, ...) using an
ensemble of fine-tuned RoBERTa sequence classifiers.

Request (POST /predict):
    { "content": "full paper text, one sentence per line" }

Response: a JSON list ordered by original line position:
    [ {"index": 0, "sentence": "...", "lable": "正文"}, ... ]
Sentences the coarse classifier marks 无用类别 (useless) are omitted.
"""
import os
from collections import Counter

from flask import Flask, jsonify, request

# Flask configuration: keep non-ASCII (Chinese) labels readable in responses.
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

# Pin the process to one GPU; must be set before torch initialises CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# Coarse three-way split: title / body / useless.
lable_2_id_fenji = {"标题": 0, "正文": 1, "无用类别": 2}
id_2_lable_fenji = {v: k for k, v in lable_2_id_fenji.items()}

# Fine-grained title levels.
lable_2_id_title = {
    "一级标题": 0,
    "二级标题": 1,
    "三级标题": 2,
    "中文摘要标题": 3,
    "致谢标题": 4,
    "英文摘要标题": 5,
    "参考文献标题": 6,
}
id_2_lable_title = {v: k for k, v in lable_2_id_title.items()}

# Fine-grained body-sentence categories.
lable_2_id_content = {
    "正文": 0,
    "英文摘要": 1,
    "中文摘要": 2,
    "中文关键词": 3,
    "英文关键词": 4,
    "图": 5,
    "表": 6,
    "参考文献": 7,
}
id_2_lable_content = {v: k for k, v in lable_2_id_content.items()}

# All checkpoints share one tokenizer (fine-tunes of the same base model).
tokenizer = AutoTokenizer.from_pretrained(
    "data_zong_roberta",
    use_fast=True,
    revision="main",
    trust_remote_code=False,
)


def _load_cls_model(model_name, num_labels):
    """Load a fine-tuned sequence-classification checkpoint onto `device`."""
    config = AutoConfig.from_pretrained(
        model_name,
        num_labels=num_labels,
        revision="main",
        trust_remote_code=False,
    )
    return AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
        revision="main",
        trust_remote_code=False,
        ignore_mismatched_sizes=False,
    ).to(device)


# Coarse classifiers: full context / no preceding context / no following context.
model_roberta_zong = _load_cls_model("data_zong_roberta", len(lable_2_id_fenji))
model_roberta_zong_no_start = _load_cls_model("data_zong_roberta_no_start", len(lable_2_id_fenji))
model_roberta_zong_no_end = _load_cls_model("data_zong_roberta_no_end", len(lable_2_id_fenji))
# Title-level classifier.
model_title_roberta = _load_cls_model("data_title_roberta", len(lable_2_id_title))
# Body-category classifiers, same three context variants as the coarse stage.
model_content_roberta = _load_cls_model("data_content_roberta", len(lable_2_id_content))
model_content_roberta_no_end = _load_cls_model("data_content_roberta_no_end", len(lable_2_id_content))
model_content_roberta_no_start = _load_cls_model("data_content_roberta_no_start", len(lable_2_id_content))


def _predict(model, text):
    """Tokenize `text` (padded/truncated to 512) and return the argmax class id."""
    encoded = tokenizer(
        [text],
        padding="max_length",
        max_length=512,
        truncation=True,
        return_tensors="pt",
    )
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        output = model(**encoded)
    return torch.argmax(output[0], dim=1).item()


def _majority_vote(votes, default=0):
    """Return the class id holding at least 2 of the 3 votes, else `default`."""
    winner, count = Counter(votes).most_common(1)[0]
    return winner if count >= 2 else default


def _context_views(content_list, pos, sentence):
    """Build the three context strings the voting models expect.

    Returns (middle, before-only, after-only) views of `sentence` at position
    `pos` of `content_list`; neighbour sentences are truncated to their first
    30 characters.
    """
    before7 = "\n".join(s[:30] for _, s in content_list[max(pos - 7, 0):pos])
    after7 = "\n".join(s[:30] for _, s in content_list[pos + 1:pos + 8])
    middle = "\n".join([before7, sentence, after7])

    before15 = "\n".join(s[:30] for _, s in content_list[max(pos - 15, 0):pos])
    before_view = "\n".join([before15, sentence])

    after15 = "\n".join(s[:30] for _, s in content_list[pos + 1:pos + 16])
    after_view = "\n".join([sentence, after15])
    return middle, before_view, after_view


def gen_zong_cls(content_list):
    """Coarsely label each (index, sentence) pair as 标题/正文/无用类别.

    Three models vote, each seeing a different context window; a three-way
    disagreement falls back to class 0 (标题).
    Returns [[index, sentence, label], ...].
    """
    labelled = []
    for pos, (index, sentence) in enumerate(content_list):
        middle, before_view, after_view = _context_views(content_list, pos, sentence)
        votes = [
            _predict(model_roberta_zong_no_end, before_view),
            _predict(model_roberta_zong_no_start, after_view),
            _predict(model_roberta_zong, middle),
        ]
        labelled.append([index, sentence, id_2_lable_fenji[_majority_vote(votes)]])
    return labelled


def _shrink_to_window(lines, center, limit=510):
    """Grow a window of lines outward from `center` and return the largest
    "\\n"-joined string that does not exceed `limit` characters."""
    left = right = center
    window = lines[center]
    while True:
        parts = [window]
        grew = False
        if left - 1 >= 0:
            left -= 1
            parts.insert(0, lines[left])
            grew = True
        if right + 1 <= len(lines) - 1:
            right += 1
            parts.append(lines[right])
            grew = True
        if not grew:  # whole document consumed; original code looped forever here
            return window
        candidate = "\n".join(parts)
        if len(candidate) > limit:
            return window
        window = candidate


def gen_title_cls(content_list):
    """Assign a title level to every (index, sentence) pair.

    The model sees the whole (truncated-neighbour) document; over-long
    contexts are shrunk to a <=510-char window centred on the target
    sentence.  Returns [[index, sentence, label], ...].
    """
    labelled = []
    for pos, (index, sentence) in enumerate(content_list):
        # Use the list position, not the stored document index: content_list
        # is a filtered sublist, so the original index-based slicing picked
        # the wrong (often empty) context.
        before = [s[:30] for _, s in content_list[:pos]]
        after = [s[:30] for _, s in content_list[pos + 1:]]
        context = "\n".join(["\n".join(before), sentence, "\n".join(after)])
        context = context.strip("\n")
        if len(context) > 510:
            # Rebuild the line list with the target position known, instead
            # of the original marker search (`"" in line` is always true),
            # which always grew the window from the document start.
            lines = before + [sentence] + after
            center = len(before)
            while lines and lines[0] == "":  # mirror the .strip("\n") above
                lines.pop(0)
                center -= 1
            while len(lines) > 1 and lines[-1] == "":
                lines.pop()
            center = min(max(center, 0), len(lines) - 1)
            context = _shrink_to_window(lines, center)
        predicted = _predict(model_title_roberta, context)
        labelled.append([index, sentence, id_2_lable_title[predicted]])
    return labelled


def gen_content_cls(content_list):
    """Label each body (index, sentence) pair with its fine-grained category.

    Same three-model voting scheme as `gen_zong_cls`; a three-way
    disagreement falls back to class 0 (正文).
    Returns [[index, sentence, label], ...].
    """
    labelled = []
    for pos, (index, sentence) in enumerate(content_list):
        middle, before_view, after_view = _context_views(content_list, pos, sentence)
        votes = [
            _predict(model_content_roberta_no_end, before_view),
            _predict(model_content_roberta_no_start, after_view),
            _predict(model_content_roberta, middle),
        ]
        labelled.append([index, sentence, id_2_lable_content[_majority_vote(votes)]])
    return labelled


def main(content: str):
    """Run the full labelling pipeline over a newline-separated document.

    Returns a list of {"index", "sentence", "lable"} dicts sorted by the
    sentence's original position; 无用类别 sentences are omitted.
    """
    # [ (original_index, sentence), ... ] — one entry per input line.
    paper_content_list = list(enumerate(content.split("\n")))
    # Stage 1: coarse title / body / useless split.
    zong_list = gen_zong_cls(paper_content_list)
    title_data = [[i, s] for i, s, label in zong_list if label == "标题"]
    content_data = [[i, s] for i, s, label in zong_list if label == "正文"]
    # Stage 2: refine titles and body sentences separately, then re-merge
    # in original document order.
    merged = gen_title_cls(title_data) + gen_content_cls(content_data)
    merged.sort(key=lambda item: item[0])
    return [
        {"index": idx, "sentence": sen, "lable": label}
        for idx, sen, label in merged
    ]


@app.route("/predict", methods=["POST"])
def search():
    """POST /predict — label the sentences of {"content": "..."}."""
    print(request.remote_addr)
    payload = request.get_json(silent=True) or {}
    if "content" not in payload:  # explicit 400 instead of an opaque 500
        return jsonify({"code": 400, "message": "missing 'content'"}), 400
    return jsonify(main(payload["content"]))


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=28100, threaded=True, debug=False)