Layout recognition: heading levels and body text
import os

# Select the GPU before torch initialises CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import torch
from flask import Flask, jsonify, request
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# Flask configuration
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
'''
Request format:
{
    "content": "full text of the paper"
}
Response format:
{
    "code": 200,
    "paper-lable": [
        {
            "index": 0,
            "sentence": "我参加的是17组的小组学习,主题是关于日本方言。我主要负责参",
            "lable": "正文"
        },
        {
            "index": 1,
            "sentence": "1.2.1 小组学习",
            "lable": "三级标题"
        },
        ...
    ]
}
'''
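# Example request (a sketch, assuming the service is reachable on the host/port
# configured in app.run() at the bottom of this file; the sample text is illustrative):
#
#   curl -X POST http://localhost:28100/predict \
#        -H "Content-Type: application/json" \
#        -d '{"content": "1.2.1 小组学习\n我参加的是17组的小组学习。"}'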
# Use the GPU if available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Coarse labels: title / body / useless.
lable_2_id_fenji = {
    "标题": 0,
    "正文": 1,
    "无用类别": 2
}
id_2_lable_fenji = {v: k for k, v in lable_2_id_fenji.items()}

# Heading-level labels.
lable_2_id_title = {
    "一级标题": 0,
    "二级标题": 1,
    "三级标题": 2,
    "中文摘要标题": 3,
    "致谢标题": 4,
    "英文摘要标题": 5,
    "参考文献标题": 6
}
id_2_lable_title = {v: k for k, v in lable_2_id_title.items()}

# Body-content labels. The names of the labels for ids 5 and 6 are missing from
# this copy of the source; note that the duplicated empty-string key collapses to
# a single entry, so len(lable_2_id_content) is 7 even though ids run 0..7, and
# this must match the num_labels the checkpoints were trained with.
lable_2_id_content = {
    "正文": 0,
    "英文摘要": 1,
    "中文摘要": 2,
    "中文关键词": 3,
    "英文关键词": 4,
    "": 5,
    "": 6,
    "参考文献": 7
}
id_2_lable_content = {v: k for k, v in lable_2_id_content.items()}
tokenizer = AutoTokenizer.from_pretrained(
    "data_zong_roberta",
    use_fast=True,
    revision="main",
    trust_remote_code=False,
)

def load_classifier(model_name, num_labels):
    """Load a local sequence-classification checkpoint and move it to the device.

    All seven classifiers below are loaded the same way; only the checkpoint
    directory and the number of labels differ."""
    config = AutoConfig.from_pretrained(
        model_name,
        num_labels=num_labels,
        revision="main",
        trust_remote_code=False,
    )
    return AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
        revision="main",
        trust_remote_code=False,
        ignore_mismatched_sizes=False,
    ).to(device)

# Coarse (title / body / useless) classifiers: one per context view.
model_roberta_zong = load_classifier("data_zong_roberta", len(lable_2_id_fenji))
model_roberta_zong_no_start = load_classifier("data_zong_roberta_no_start", len(lable_2_id_fenji))
model_roberta_zong_no_end = load_classifier("data_zong_roberta_no_end", len(lable_2_id_fenji))
# Heading-level classifier.
model_title_roberta = load_classifier("data_title_roberta", len(lable_2_id_title))
# Body-content classifiers: one per context view.
model_content_roberta = load_classifier("data_content_roberta", len(lable_2_id_content))
model_content_roberta_no_end = load_classifier("data_content_roberta_no_end", len(lable_2_id_content))
model_content_roberta_no_start = load_classifier("data_content_roberta_no_start", len(lable_2_id_content))
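# The two multi-view classifiers below (the coarse pass and the body-content pass)
# repeat the same three steps for every sentence: build context views around the
# target, run one model per view, and take a majority vote. These helpers factor
# that shared logic out; they implement exactly the steps the passes perform.

def predict_class(model, text):
    """Tokenize one text, run it through `model`, and return the argmax class id."""
    encoded = tokenizer([text], padding="max_length", max_length=512,
                        truncation=True, return_tensors="pt")
    encoded = {key: value.to(device) for key, value in encoded.items()}
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**encoded)
    return torch.argmax(outputs.logits, dim=1).item()

def build_context_views(content_list, pos, paper_sen):
    """Build the three context views around position `pos` in `content_list`:
    middle (7 sentences on each side), front (15 sentences before), and back
    (15 sentences after). Context sentences are truncated to their first 30
    characters; the target sentence is wrapped in <Start>/<End> markers."""
    target = "<Start>" + paper_sen + "<End>"
    before7 = "\n".join(sen[:30] for _, sen in content_list[max(pos - 7, 0):pos])
    after7 = "\n".join(sen[:30] for _, sen in content_list[pos + 1:pos + 8])
    before15 = "\n".join(sen[:30] for _, sen in content_list[max(pos - 15, 0):pos])
    after15 = "\n".join(sen[:30] for _, sen in content_list[pos + 1:pos + 16])
    view_middle = "\n".join([before7, target, after7])
    view_front = "\n".join([before15, target])
    view_back = "\n".join([target, after15])
    return view_middle, view_front, view_back

def majority_vote(class_ids, default=0):
    """Return the class id predicted by at least two of the views; fall back to
    `default` when all views disagree."""
    counts = {}
    for class_id in class_ids:
        counts[class_id] = counts.get(class_id, 0) + 1
    for class_id, count in counts.items():
        if count >= 2:
            return class_id
    return default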
def gen_zong_cls(content_list):
    """Coarse pass: label each sentence as 标题 (title), 正文 (body) or 无用类别 (useless).

    Each sentence is classified three times, from a middle, a front-only and a
    back-only context view, and the results are combined by majority vote."""
    paper_quanwen_lable_list = []
    for index, paper_sen in content_list:
        # `content_list` is the full paper here, so the paper-wide index is also
        # the position in the list.
        view_middle, view_front, view_back = build_context_views(content_list, index, paper_sen)
        pred_middle = predict_class(model_roberta_zong, view_middle)
        pred_front = predict_class(model_roberta_zong_no_end, view_front)
        pred_back = predict_class(model_roberta_zong_no_start, view_back)
        # Majority vote; on a three-way disagreement fall back to class 0 (标题).
        predicted_class_idx = majority_vote([pred_front, pred_back, pred_middle], default=0)
        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_fenji[predicted_class_idx]])
    return paper_quanwen_lable_list
def gen_title_cls(content_list):
    """Second pass over the title sentences only: assign a heading level to each."""
    paper_quanwen_lable_list = []
    # `index` is the sentence's position in the full paper, while `content_list`
    # holds only the title sentences, so iterate with enumerate() and slice the
    # context on the local position.
    for pos, (index, paper_sen) in enumerate(content_list):
        # Context is every other title in the document, truncated to 30 characters,
        # with the target title wrapped in <Start>/<End> markers.
        paper_new_start = "\n".join(sen[:30] for _, sen in content_list[:pos])
        paper_new_end = "\n".join(sen[:30] for _, sen in content_list[pos + 1:])
        paper_object_dangqian = "<Start>" + paper_sen + "<End>"
        paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])
        paper_zhong = paper_zhong.strip("\n")
        if len(paper_zhong) > 510:
            # Too long for the model: grow a window of lines around the target,
            # one line on each side per step, and keep the largest version that
            # still fits within 510 characters.
            data_paper_list = paper_zhong.split("\n")
            start_index = 0
            for i in range(len(data_paper_list)):
                if "<Start>" in data_paper_list[i]:
                    start_index = i
                    break
            left = right = start_index
            old_sen = data_paper_list[start_index]
            while True:
                left_ok = left - 1 >= 0
                right_ok = right + 1 <= len(data_paper_list) - 1
                if left_ok:
                    left -= 1
                if right_ok:
                    right += 1
                new_sen_list = [old_sen]
                if left_ok:
                    new_sen_list = [data_paper_list[left]] + new_sen_list
                if right_ok:
                    new_sen_list = new_sen_list + [data_paper_list[right]]
                new_sen = "\n".join(new_sen_list)
                if len(new_sen) > 510:
                    break
                old_sen = new_sen
            paper_zhong = old_sen
        # A single model handles heading levels (no multi-view vote here).
        predicted_class_idx = predict_class(model_title_roberta, paper_zhong)
        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_title[predicted_class_idx]])
    return paper_quanwen_lable_list
def gen_content_cls(content_list):
    """Second pass over the body sentences only: assign a fine-grained content label
    (abstract, keywords, references, plain body text, ...) by the same three-view
    majority vote as the coarse pass."""
    paper_quanwen_lable_list = []
    # `index` is the sentence's position in the full paper, while `content_list`
    # holds only the body sentences, so slice the context on the local position.
    for pos, (index, paper_sen) in enumerate(content_list):
        view_middle, view_front, view_back = build_context_views(content_list, pos, paper_sen)
        pred_middle = predict_class(model_content_roberta, view_middle)
        pred_front = predict_class(model_content_roberta_no_end, view_front)
        pred_back = predict_class(model_content_roberta_no_start, view_back)
        # Majority vote; on a three-way disagreement fall back to class 0 (正文).
        predicted_class_idx = majority_vote([pred_front, pred_back, pred_middle], default=0)
        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_content[predicted_class_idx]])
    return paper_quanwen_lable_list
def main(content: str):
    # Split the text into numbered sentences in the format the models expect: [index, sentence].
    paper_content_list = [[i, j] for i, j in enumerate(content.split("\n"))]
    # Coarse pass: decide for every sentence whether it is a title, body text, or useless.
    zong_list = gen_zong_cls(paper_content_list)
    # Separate titles from body text; useless sentences are dropped here.
    title_data = []
    content_data = []
    for data_dan in zong_list:
        if data_dan[2] == "标题":
            title_data.append([data_dan[0], data_dan[1]])
        if data_dan[2] == "正文":
            content_data.append([data_dan[0], data_dan[1]])
    # Fine pass over the titles: assign each one a heading level.
    title_list = gen_title_cls(title_data)
    # Fine pass over the body sentences: assign each one a content label.
    content_list = gen_content_cls(content_data)
    # Merge both passes and restore the original sentence order.
    paper_content_list_new = sorted(title_list + content_list, key=lambda item: item[0])
    paper_content_info_list = []
    for data_dan_info in paper_content_list_new:
        paper_content_info_list.append({
            "index": data_dan_info[0],
            "sentence": data_dan_info[1],
            "lable": data_dan_info[2]
        })
    return paper_content_info_list
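# Direct (non-HTTP) usage sketch, e.g. for a quick local test without the server
# (the sample text is illustrative):
#
#   sample = "中文摘要\n1.2.1 小组学习\n我参加的是17组的小组学习。"
#   for item in main(sample):
#       print(item["index"], item["lable"], item["sentence"][:20])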
@app.route("/predict", methods=["POST"])
def search():
    print(request.remote_addr)
    content = request.json["content"]
    response = main(content)
    # Wrap the labels in the response envelope documented at the top of the file.
    return jsonify({"code": 200, "paper-lable": response})
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=28100, threaded=True, debug=False)