Layout classification: identifying heading levels and body text
import os

from flask import Flask, jsonify, request

# Flask configuration
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

# Pin the CUDA device (set before torch is imported so only GPU 2 is visible)
os.environ['CUDA_VISIBLE_DEVICES'] = '2'

import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    AutoTokenizer,
)
'''
Request format:
{
    "content": "full paper text, one sentence per line"
}
Response format (a JSON array of labelled sentences):
[
    {
        "index": 0,
        "sentence": "我参加的是17组的小组学习,主题是关于日本方言。我主要负责参",
        "lable": "正文"
    },
    {
        "index": 1,
        "sentence": "1.2.1 小组学习",
        "lable": "三级标题"
    },
    ...
]
'''
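# Example client call (a minimal sketch, assuming the service runs locally on
# the port configured in app.run() at the bottom; the sample content is
# illustrative):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:28100/predict",
#       json={"content": "中文摘要\n本文研究了...\n1.2.1 小组学习"},
#   )
#   print(resp.json())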
# Use the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
lable_2_id_fenji = {
    "标题": 0,
    "正文": 1,
    "无用类别": 2
}
id_2_lable_fenji = {v: k for k, v in lable_2_id_fenji.items()}
lable_2_id_title = {
    "一级标题": 0,
    "二级标题": 1,
    "三级标题": 2,
    "中文摘要标题": 3,
    "致谢标题": 4,
    "英文摘要标题": 5,
    "参考文献标题": 6,
    "四级标题": 7,
    "非标题类型": 8
}
id_2_lable_title = {v: k for k, v in lable_2_id_title.items()}
# NOTE: ids 5 and 6 were two duplicate empty-string keys, which silently
# collapsed into one dict entry and left num_labels one short; the placeholder
# names below are assumptions that keep the full id range intact
lable_2_id_content = {
    "正文": 0,
    "英文摘要": 1,
    "中文摘要": 2,
    "中文关键词": 3,
    "英文关键词": 4,
    "未使用类别5": 5,
    "未使用类别6": 6,
    "参考文献": 7
}
id_2_lable_content = {v: k for k, v in lable_2_id_content.items()}
lable_2_id_title_no_title = {
    "正文": 0,
    "标题": 1
}
id_2_lable_title_no_title = {v: k for k, v in lable_2_id_title_no_title.items()}
tokenizer = AutoTokenizer.from_pretrained(
    "data_zong_shout_3",
    use_fast=True,
    revision="main",
    trust_remote_code=False,
)

def load_cls_model(model_name, num_labels):
    """Load a sequence-classification checkpoint and move it to the device."""
    config = AutoConfig.from_pretrained(
        model_name,
        num_labels=num_labels,
        revision="main",
        trust_remote_code=False,
    )
    return AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=config,
        revision="main",
        trust_remote_code=False,
        ignore_mismatched_sizes=False,
    ).to(device)

# Coarse heading / body / useless classifier and its variants trained without
# leading or trailing context
model_roberta_zong = load_cls_model("data_zong_shout_3", len(lable_2_id_fenji))
model_roberta_zong_no_start = load_cls_model("data_zong_no_start_shout_3", len(lable_2_id_fenji))
model_roberta_zong_no_end = load_cls_model("data_zong_no_end_shout_3", len(lable_2_id_fenji))
# Body-text classifiers
model_content_roberta = load_cls_model("data_content_roberta", len(lable_2_id_content))
model_content_roberta_no_end = load_cls_model("data_content_roberta_no_end", len(lable_2_id_content))
model_content_roberta_no_start = load_cls_model("data_content_roberta_no_start", len(lable_2_id_content))
# Heading-level classifier and the binary heading / non-heading classifier
model_title_roberta = load_cls_model("data_title_roberta_2", len(lable_2_id_title))
model_data_title_no_title_roberta = load_cls_model("data_title_no_title_roberta_2", len(lable_2_id_title_no_title))
# Per-character NER model for heading levels
model_name = "data_title_roberta_ner_2"
tokenizer_ner = AutoTokenizer.from_pretrained(model_name)
model_data_title_roberta_ner = AutoModelForTokenClassification.from_pretrained(model_name)
model_data_title_roberta_ner.eval().to(device)
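
def classify_text(model, text):
    """Shared inference step, factored out of the repeated
    tokenize -> move-to-device -> argmax pattern used by the functions below
    (behavior unchanged apart from running under torch.no_grad())."""
    inputs = tokenizer([text], padding="max_length", max_length=512, truncation=True, return_tensors="pt")
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs)
    return torch.argmax(logits[0], dim=1).item()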
def gen_zong_cls(content_list):
    paper_quanwen_lable_list = []

    def build_window(index, paper_sen, left_end, right_end, need_left, need_right):
        """Greedily add truncated neighbour sentences around the target
        sentence until the window exceeds 510 characters or a required side
        runs out. Returns [window_text, number_of_lines_in_window]."""
        left = index
        right = index
        left_ok = True
        right_ok = True
        old_sen = "<Start>" + paper_sen[:30] + "<End>"
        while True:
            if left - 1 >= left_end:
                left -= 1
            else:
                left_ok = False
            if right + 1 <= right_end:
                right += 1
            else:
                right_ok = False
            new_sen_list = [old_sen]
            if left_ok:
                new_sen_list = [content_list[left][1][:30]] + new_sen_list
            if right_ok:
                new_sen_list = new_sen_list + [content_list[right][1][:30]]
            new_sen = "\n".join(new_sen_list)
            if len(new_sen) > 510 or (need_left and not left_ok) or (need_right and not right_ok):
                break
            old_sen = new_sen
        return [old_sen, len(old_sen.split("\n"))]

    for index, paper_sen in content_list:
        # Window with context on both sides
        sentence_zong_zhong = build_window(index, paper_sen, 0, len(content_list) - 1, True, True)
        # Window with no following context (extends left only)
        sentence_zong_no_end = build_window(index, paper_sen, 0, index, True, False)
        # Window with no preceding context (extends right only)
        sentence_zong_no_start = build_window(index, paper_sen, index, len(content_list) - 1, False, True)
        # Each context variant votes, weighted by how many lines its window
        # covered; the original tallied the middle-window prediction three
        # times by mistake
        res_score = {}
        for model, (window, weight) in [
            (model_roberta_zong, sentence_zong_zhong),
            (model_roberta_zong_no_end, sentence_zong_no_end),
            (model_roberta_zong_no_start, sentence_zong_no_start),
        ]:
            predicted = classify_text(model, window)
            res_score[predicted] = res_score.get(predicted, 0) + weight
        res_score_list = sorted(res_score.items(), key=lambda item: item[1], reverse=True)
        predicted_class_idx = res_score_list[0][0]
        # Heading rule: anything longer than 60 characters is treated as body text
        if predicted_class_idx == 0 and len(paper_sen) > 60:
            predicted_class_idx = 1
        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_fenji[predicted_class_idx]])
    return paper_quanwen_lable_list
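
# Illustrative shape of the result (labels depend on the trained models):
#   gen_zong_cls([[0, "摘要"], [1, "本文研究了..."]])
#   -> [[0, "摘要", "标题"], [1, "本文研究了...", "正文"]]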
def gen_title_cls(content_list):
    """Classifier-based heading-level labelling (currently unused in main();
    the NER-based pipeline is used instead)."""
    paper_quanwen_lable_list = []
    for index, paper_sen in content_list:
        paper_start_list = [sen[:30] for _, sen in content_list[0:index]]
        paper_end_list = [sen[:30] for _, sen in content_list[index + 1:len(content_list)]]
        paper_new_start = "\n".join(paper_start_list)
        paper_new_end = "\n".join(paper_end_list)
        paper_object_dangqian = "<Start>" + paper_sen + "<End>"
        paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])
        paper_zhong = paper_zhong.strip("\n")
        if len(paper_zhong) > 510:
            # Too long: grow a window outwards from the <Start> line instead
            data_paper_list = str(paper_zhong).split("\n")
            start_index = 0
            for i in range(len(data_paper_list)):
                if "<Start>" in data_paper_list[i]:
                    start_index = i
                    break
            left_end = 0
            right_end = len(data_paper_list) - 1
            left = start_index
            right = start_index
            left_end_bool = True
            right_end_bool = True
            old_sen = data_paper_list[start_index]
            while True:
                if left - 1 >= left_end:
                    left -= 1
                else:
                    left_end_bool = False
                if right + 1 <= right_end:
                    right += 1
                else:
                    right_end_bool = False
                new_sen_list = [old_sen]
                if left_end_bool:
                    new_sen_list = [data_paper_list[left]] + new_sen_list
                if right_end_bool:
                    new_sen_list = new_sen_list + [data_paper_list[right]]
                new_sen = "\n".join(new_sen_list)
                if len(new_sen) > 510:
                    break
                old_sen = new_sen
            paper_zhong = old_sen
        predicted_class_idx = classify_text(model_title_roberta, paper_zhong)
        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_title[predicted_class_idx]])
    return paper_quanwen_lable_list
def gen_content_cls(content_list):
    paper_quanwen_lable_list = []
    for index, paper_sen in content_list:
        paper_object_dangqian = "<Start>" + paper_sen + "<End>"
        # Window: 7 sentences before and after
        paper_start_list = [sen[:30] for _, sen in content_list[max(index - 7, 0):index]]
        paper_end_list = [sen[:30] for _, sen in content_list[index + 1:index + 8]]
        paper_new_start = "\n".join(paper_start_list)
        paper_new_end = "\n".join(paper_end_list)
        paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])
        # Window: 15 sentences before
        paper_start_list = [sen[:30] for _, sen in content_list[max(index - 15, 0):index]]
        paper_new_start = "\n".join(paper_start_list)
        paper_qian = "\n".join([paper_new_start, paper_object_dangqian])
        # Window: 15 sentences after
        paper_end_list = [sen[:30] for _, sen in content_list[index + 1:index + 16]]
        paper_new_end = "\n".join(paper_end_list)
        paper_hou = "\n".join([paper_object_dangqian, paper_new_end])
        # One prediction per window
        predicted_class_idx_zhong = classify_text(model_content_roberta, paper_zhong)
        predicted_class_idx_qian = classify_text(model_content_roberta_no_end, paper_qian)
        predicted_class_idx_hou = classify_text(model_content_roberta_no_start, paper_hou)
        # Majority vote: two agreeing windows win, otherwise fall back to id 0 ("正文")
        id_2_len = {}
        for i in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]:
            id_2_len[i] = id_2_len.get(i, 0) + 1
        predicted_class_idx = 0
        for i in id_2_len:
            if id_2_len[i] >= 2:
                predicted_class_idx = i
                break
        paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_content[predicted_class_idx]])
    return paper_quanwen_lable_list
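
# Illustrative vote (values made up): window predictions [1, 1, 2] -> label id 1
# wins (two of three agree); [0, 1, 2] -> falls back to id 0 ("正文").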
def split_lists_recursive(a, b, a_soc, b_soc, target_size=510, result_a=None, result_b=None):
    """
    Recursively split two aligned lists into chunks while keeping their
    one-to-one correspondence.

    Chunks are cut at "[SEP]" boundaries so each stays close to target_size;
    the final chunk is padded backwards from the source lists so it still has
    target_size elements.

    Parameters:
        a: first list (characters, with "[SEP]" between sentences)
        b: second list, aligned with a (sentence indices, with "-100" between sentences)
        a_soc, b_soc: the full, unsplit source lists (used to pad the last chunk)
        target_size: target chunk size
        result_a, result_b: accumulators used by the recursion
    """
    if result_a is None:
        result_a = []
    if result_b is None:
        result_b = []
    # The two lists must stay aligned
    if len(a) != len(b):
        raise ValueError("The two lists must have the same length")
    total_elements = len(a)
    # Base case: the remainder fits in one chunk. Pad it backwards from the
    # source lists to target_size, then skip forward past the first "-100"
    # boundary so the chunk starts on a sentence boundary.
    if total_elements <= target_size:
        a_obj = a_soc[-target_size:]
        b_obj = b_soc[-target_size:]
        start_i = 0
        while True:
            if result_a == []:
                break
            if start_i == len(a_obj):
                break
            if b_obj[start_i] == "-100":
                start_i += 1
                break
            start_i += 1
        if a != []:
            result_a.append(a_obj[start_i:])
            result_b.append(b_obj[start_i:])
        return result_a, result_b
    # Cut backwards from target_size to the nearest "[SEP]"
    target_size_new = target_size
    while True:
        if a[target_size_new] == "[SEP]":
            break
        if target_size_new == 0:
            # No separator within reach (a single sentence longer than
            # target_size): fall back to a hard cut at target_size
            target_size_new = target_size
            break
        target_size_new -= 1
    a_current_chunk = a[:target_size_new]
    b_current_chunk = b[:target_size_new]
    # Advance past the separator so the remainder starts on the next sentence
    # (the original checked for a "B" BIO prefix here, which never matches the
    # integer sentence indices actually stored in b; skipping "-100" separators
    # is the working equivalent)
    while True:
        if target_size_new == len(a):
            break
        if b[target_size_new] != "-100":
            break
        target_size_new += 1
    a_remaining = a[target_size_new:]
    b_remaining = b[target_size_new:]
    if a_current_chunk != []:
        result_a.append(a_current_chunk)
        result_b.append(b_current_chunk)
    # Recurse on the remainder
    return split_lists_recursive(a_remaining, b_remaining, a_soc, b_soc, target_size, result_a, result_b)
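
# Worked example (tiny target_size for illustration):
#   a = list("绪论") + ["[SEP]"] + list("研究背景")   # characters + separator
#   b = [3, 3, "-100", 5, 5, 5, 5]                    # aligned sentence indices
#   split_lists_recursive(a, b, a, b, target_size=8)
#   -> everything fits in one chunk, so the result is ([a], [b])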
def ner_predict(tokens):
    # Tokenize with the NER model's own tokenizer (the classification
    # tokenizer was passed here before, leaving tokenizer_ner unused)
    inputs = tokenizer_ner(
        tokens,
        is_split_into_words=True,
        return_tensors="pt"
    ).to(device)
    with torch.no_grad():
        outputs = model_data_title_roberta_ner(**inputs)
    logits = outputs.logits
    preds = logits.argmax(dim=-1)[0].tolist()
    id2label = model_data_title_roberta_ner.config.id2label
    word_ids = inputs.word_ids()
    results = []
    prev_word_id = None
    # Keep one prediction per original token (skip special tokens and
    # continuation sub-words)
    for pred, word_id in zip(preds, word_ids):
        if word_id is None or word_id == prev_word_id:
            continue
        results.append((tokens[word_id], id2label[pred]))
        prev_word_id = word_id
    return results
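
# Illustrative output (actual tag names come from the NER model's id2label):
#   ner_predict(list("1.2.1小组学习"))
#   -> [("1", "B-三级标题"), (".", "I-三级标题"), ("2", "I-三级标题"), ...]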
def main(content: str):
    # Arrange the input into the model format: [index, sentence]
    paper_content_list = [[i, j] for i, j in enumerate(content.split("\n"))]
    # First pass: classify every sentence as heading, body text, or useless
    print("First pass: classify every sentence as heading / body text / useless")
    zong_list = gen_zong_cls(paper_content_list)
    # Split the results into heading data and body-text data
    title_data = []
    content_data = []
    for data_dan in zong_list:
        if data_dan[2] == "标题":
            title_data.append([data_dan[0], data_dan[1]])
        if data_dan[2] == "正文":
            content_data.append([data_dan[0], data_dan[1]])
    # Second pass: assign a heading level to every heading
    print("Second pass: assign a heading level to every heading")
    data_dan_sen = [i[1] for i in title_data]
    data_dan_sen_index = [i[0] for i in title_data]
    # Flatten the headings to characters, aligned with per-character sentence
    # indices; "-100" / "\n" mark sentence boundaries
    data_dan_sen_index_new = []
    for i, j in zip(data_dan_sen_index, data_dan_sen):
        data_dan_sen_index_new.extend([i] * len(j))
        data_dan_sen_index_new.extend(["-100"])
    data_dan_sen_new = []
    for i in data_dan_sen:
        data_dan_sen_new.extend(list(i))
        data_dan_sen_new.extend(["\n"])
    data_dan_sen_index_new = data_dan_sen_index_new[:-1]
    data_dan_sen_new = data_dan_sen_new[:-1]
    data_dan_sen_new = ["[SEP]" if item == "\n" else item for item in data_dan_sen_new]
    a_return1, b_return1 = split_lists_recursive(
        data_dan_sen_new, data_dan_sen_index_new,
        data_dan_sen_new, data_dan_sen_index_new,
        target_size=510,
    )
    data_zong_train_list = []
    for i, j in zip(a_return1, b_return1):
        data_zong_train_list.append({
            "tokens": i,
            "tokens_index": j
        })
    title_list = []
    for i in data_zong_train_list:
        dan_data = ner_predict(i["tokens"])
        # Re-assemble per-character predictions into (sentence, labels) pairs,
        # splitting at the "[SEP]" boundaries
        dan_data_new = []
        linshi_label = []
        linshi_str = []
        for j in dan_data:
            if j[0] != "[SEP]":
                linshi_label.append(j[1][2:])  # strip the "B-"/"I-" BIO prefix
                linshi_str.append(j[0])
            else:
                linshi_label = list(set(linshi_label))
                linshi_str = "".join(linshi_str)
                dan_data_new.append([linshi_str, linshi_label])
                linshi_label = []
                linshi_str = []
        if linshi_str != []:
            linshi_str = "".join(linshi_str)
            linshi_label = list(set(linshi_label))
            dan_data_new.append([linshi_str, linshi_label])
        # Recover one sentence index per sentence in this chunk
        data_dan_sen_index_new = []
        linshi = []
        for ii in i["tokens_index"]:
            if ii == "-100":
                data_dan_sen_index_new.append(list(set(linshi))[0])
                linshi = []
            else:
                linshi.append(ii)
        if linshi != []:
            data_dan_sen_index_new.append(list(set(linshi))[0])
        if len(dan_data_new) == len(data_dan_sen_index_new):
            for ii, jj in zip(data_dan_sen_index_new, dan_data_new):
                title_list.append([ii, jj[0], jj[1][0]])
    title_data_dict = {}
    for i in title_list:
        if i[0] not in title_data_dict:
            title_data_dict[i[0]] = [[i[1], i[2]]]
        else:
            title_data_dict[i[0]] += [[i[1], i[2]]]
    print(title_data_dict)
    # title_list = gen_title_cls(title_data)  # older classifier-based variant
    # Third pass: label every body-text sentence
    print("Third pass: label every body-text sentence")
    content_list = gen_content_cls(content_data)
    # Merge and restore the original sentence order
    print("Merge and restore the original sentence order")
    paper_content_list_new = sorted(title_list + content_list, key=lambda item: item[0])
    paper_content_info_list = []
    for data_dan_info in paper_content_list_new:
        paper_content_info_list.append({
            "index": data_dan_info[0],
            "sentence": data_dan_info[1],
            "lable": data_dan_info[2]
        })
    return paper_content_info_list
@app.route("/predict", methods=["POST"])
def predict():
    print(request.remote_addr)
    content = request.json["content"]
    response = main(content)
    return jsonify(response)  # return the labelled sentences

if __name__ == "__main__":
    app.run(host="0.0.0.0", port=28100, threaded=True, debug=False)