
First official release

master
majiahui@haimaqingfan.com · 2 days ago
commit cbb165dfa9

flask_api.py · 444 changed lines

@@ -31,6 +31,7 @@ import transformers
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
+AutoModelForTokenClassification,
AutoTokenizer,
)
import torch
@@ -78,6 +79,15 @@ for i in lable_2_id_fenji:
if lable_2_id_fenji[i] not in id_2_lable_fenji:
id_2_lable_fenji[lable_2_id_fenji[i]] = i
# lable_2_id_title = {
# "一级标题": 0,
# "二级标题": 1,
# "三级标题": 2,
# "中文摘要标题": 3,
# "致谢标题": 4,
# "英文摘要标题": 5,
# "参考文献标题": 6
# }
lable_2_id_title = {
"一级标题": 0,
"二级标题": 1,
@@ -85,7 +95,9 @@ lable_2_id_title = {
"中文摘要标题": 3,
"致谢标题": 4,
"英文摘要标题": 5,
"参考文献标题": 6
"参考文献标题": 6,
"四级标题": 7,
"非标题类型": 8
}
id_2_lable_title = {}
@@ -105,20 +117,31 @@ lable_2_id_content = {
"参考文献": 7
}
id_2_lable_content = {}
for i in lable_2_id_content:
if lable_2_id_content[i] not in id_2_lable_content:
id_2_lable_content[lable_2_id_content[i]] = i
lable_2_id_title_no_title = {
"正文": 0,
"标题": 1
}
id_2_lable_title_no_title = {}
for i in lable_2_id_title_no_title:
if lable_2_id_title_no_title[i] not in id_2_lable_title_no_title:
id_2_lable_title_no_title[lable_2_id_title_no_title[i]] = i
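# Four label spaces are used below: lable_2_id_fenji (coarse sentence class,
# defined above), lable_2_id_title (heading levels), lable_2_id_content
# (body-text types), and the binary lable_2_id_title_no_title map; each is
# inverted into a matching id -> label dict for decoding model output.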
tokenizer = AutoTokenizer.from_pretrained(
"data_zong_roberta",
"data_zong_shout_3",
use_fast=True,
revision="main",
trust_remote_code=False,
)
model_name = "data_zong_roberta"
model_name = "data_zong_shout_3"
config = AutoConfig.from_pretrained(
model_name,
num_labels=len(lable_2_id_fenji),
@@ -133,7 +156,7 @@ model_roberta_zong = AutoModelForSequenceClassification.from_pretrained(
ignore_mismatched_sizes=False,
).to(device)
model_name = "data_zong_roberta_no_start"
model_name = "data_zong_no_start_shout_3"
config = AutoConfig.from_pretrained(
model_name,
num_labels=len(lable_2_id_fenji),
@@ -148,7 +171,7 @@ model_roberta_zong_no_start = AutoModelForSequenceClassification.from_pretrained
ignore_mismatched_sizes=False,
).to(device)
model_name = "data_zong_roberta_no_end"
model_name = "data_zong_no_end_shout_3"
config = AutoConfig.from_pretrained(
model_name,
num_labels=len(lable_2_id_fenji),
@@ -163,14 +186,29 @@ model_roberta_zong_no_end = AutoModelForSequenceClassification.from_pretrained(
ignore_mismatched_sizes=False,
).to(device)
model_name = "data_title_roberta"
# model_name = "data_title_roberta"
# config = AutoConfig.from_pretrained(
# model_name,
# num_labels=len(lable_2_id_title),
# revision="main",
# trust_remote_code=False
# )
# model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
# model_name,
# config=config,
# revision="main",
# trust_remote_code=False,
# ignore_mismatched_sizes=False,
# ).to(device)
model_name = "data_content_roberta"
config = AutoConfig.from_pretrained(
model_name,
-num_labels=len(lable_2_id_title),
+num_labels=len(lable_2_id_content),
revision="main",
trust_remote_code=False
)
-model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
+model_content_roberta = AutoModelForSequenceClassification.from_pretrained(
model_name,
config=config,
revision="main",
@@ -178,14 +216,14 @@ model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
ignore_mismatched_sizes=False,
).to(device)
model_name = "data_content_roberta"
model_name = "data_content_roberta_no_end"
config = AutoConfig.from_pretrained(
model_name,
num_labels=len(lable_2_id_content),
revision="main",
trust_remote_code=False
)
-model_content_roberta = AutoModelForSequenceClassification.from_pretrained(
+model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretrained(
model_name,
config=config,
revision="main",
@@ -193,14 +231,14 @@ model_content_roberta = AutoModelForSequenceClassification.from_pretrained(
ignore_mismatched_sizes=False,
).to(device)
model_name = "data_content_roberta_no_end"
model_name = "data_content_roberta_no_start"
config = AutoConfig.from_pretrained(
model_name,
num_labels=len(lable_2_id_content),
revision="main",
trust_remote_code=False
)
-model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretrained(
+model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrained(
model_name,
config=config,
revision="main",
@@ -208,14 +246,14 @@ model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretraine
ignore_mismatched_sizes=False,
).to(device)
model_name = "data_content_roberta_no_start"
model_name = "data_title_roberta_2"
config = AutoConfig.from_pretrained(
model_name,
-num_labels=len(lable_2_id_content),
+num_labels=len(lable_2_id_title),
revision="main",
trust_remote_code=False
)
-model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrained(
+model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
model_name,
config=config,
revision="main",
@@ -223,73 +261,184 @@ model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrai
ignore_mismatched_sizes=False,
).to(device)
model_name = "data_title_no_title_roberta_2"
config = AutoConfig.from_pretrained(
model_name,
num_labels=len(lable_2_id_title_no_title),
revision="main",
trust_remote_code=False
)
model_data_title_no_title_roberta = AutoModelForSequenceClassification.from_pretrained(
model_name,
config=config,
revision="main",
trust_remote_code=False,
ignore_mismatched_sizes=False,
).to(device)
model_name = "data_title_roberta_ner_2"
tokenizer_ner = AutoTokenizer.from_pretrained(model_name)
model_data_title_roberta_ner = AutoModelForTokenClassification.from_pretrained(model_name)
model_data_title_roberta_ner.eval().to(device)
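# gen_zong_cls: coarse sentence-level classification. For every line it builds
# three context windows (target sentence in the middle, no following context,
# no preceding context), scores each window with the matching RoBERTa
# classifier, and combines the three predictions with a length-weighted vote.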
def gen_zong_cls(content_list):
paper_quanwen_lable_list = []
for index, paper_sen in content_list:
# context window: 7 sentences before and after
-paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[max(index - 7, 0):index]]
-paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:index + 8]]
+# paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[max(index - 7, 0):index]]
+# paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:index + 8]]
# print(len(paper_start_list))
# print(len(paper_end_list))
-paper_new_start = "\n".join(paper_start_list)
-paper_new_end = "\n".join(paper_end_list)
-paper_object_dangqian = "<Start>" + paper_sen + "<End>"
-paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])
+# paper_new_start = "\n".join(paper_start_list)
+# paper_new_end = "\n".join(paper_end_list)
+# paper_object_dangqian = "<Start>" + paper_sen + "<End>"
+# paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])
# context window: 15 preceding sentences
-paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[max(index - 15, 0):index]]
-# print(len(paper_start_list))
-paper_new_start = "\n".join(paper_start_list)
-paper_object_dangqian = "<Start>" + paper_sen + "<End>"
-paper_qian = "\n".join([paper_new_start, paper_object_dangqian])
start_index = index
left_end = 0
right_end = len(content_list) - 1
left = start_index
right = start_index
left_end_bool = True
right_end_bool = True
old_sen = "<Start>" + paper_sen[:30] + "<End>"
while True:
if left - 1 >= left_end:
left = left - 1
else:
left_end_bool = False
if right + 1 <= right_end:
right = right + 1
else:
right_end_bool = False
new_sen_list = [old_sen]
if left_end_bool == True:
new_sen_list = [content_list[left][1][:30]] + new_sen_list
if right_end_bool == True:
new_sen_list = new_sen_list + [content_list[right][1][:30]]
new_sen = "\n".join(new_sen_list)
if len(new_sen) > 510 or left_end_bool == False or right_end_bool == False:
break
else:
old_sen = new_sen
len_sen = len(old_sen.split("\n"))
sentence_zong_zhong = [old_sen, len_sen]
# context window: 15 following sentences
-paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:index + 16]]
-# print(len(paper_end_list))
-paper_new_end = "\n".join(paper_end_list)
-paper_object_dangqian = "<Start>" + paper_sen + "<End>"
-paper_hou = "\n".join([paper_object_dangqian, paper_new_end])
# no following content: the window can only grow to the left
start_index = index
left_end = 0
right_end = start_index
left = start_index
right = start_index
left_end_bool = True
right_end_bool = True
old_sen = "<Start>" + paper_sen[:30] + "<End>"
while True:
if left - 1 >= left_end:
left = left - 1
else:
left_end_bool = False
if right + 1 <= right_end:
right = right + 1
else:
right_end_bool = False
new_sen_list = [old_sen]
if left_end_bool == True:
new_sen_list = [content_list[left][1][:30]] + new_sen_list
if right_end_bool == True:
new_sen_list = new_sen_list + [content_list[right][1][:30]]
new_sen = "\n".join(new_sen_list)
if len(new_sen) > 510 or left_end_bool == False:
break
else:
old_sen = new_sen
len_sen = len(old_sen.split("\n"))
sentence_zong_no_end = [old_sen, len_sen]
# no preceding content: the window can only grow to the right
start_index = index
left_end = start_index
right_end = len(content_list) - 1
left = start_index
right = start_index
left_end_bool = True
right_end_bool = True
old_sen = "<Start>" + paper_sen[:30] + "<End>"
while True:
if left - 1 >= left_end:
left = left - 1
else:
left_end_bool = False
if right + 1 <= right_end:
right = right + 1
else:
right_end_bool = False
new_sen_list = [old_sen]
if left_end_bool == True:
new_sen_list = [content_list[left][1][:30]] + new_sen_list
if right_end_bool == True:
new_sen_list = new_sen_list + [content_list[right][1][:30]]
new_sen = "\n".join(new_sen_list)
if len(new_sen) > 510 or right_end_bool == False:
break
else:
old_sen = new_sen
len_sen = len(old_sen.split("\n"))
sentence_zong_no_start = [old_sen, len_sen]
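# res_score: length-weighted vote over the three window predictions; each
# predicted class id accumulates the line count of the window that produced
# it, so predictions backed by more context carry more weight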
res_score = {}
# prediction with the target sentence in the middle of the window
-sentence_list = [paper_zhong]
+sentence_list = [sentence_zong_zhong[0]]
# sentence_list = [data[1][0]]
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
result_on_device = {key: value.to(device) for key, value in result.items()}
logits = model_roberta_zong(**result_on_device)
predicted_class_idx_zhong = torch.argmax(logits[0], dim=1).item()
if predicted_class_idx_zhong not in res_score:
res_score[predicted_class_idx_zhong] = sentence_zong_zhong[1]
else:
res_score[predicted_class_idx_zhong] += sentence_zong_zhong[1]
-sentence_list = [paper_qian]
+sentence_list = [sentence_zong_no_end[0]]
# sentence_list = [data[1][0]]
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
result_on_device = {key: value.to(device) for key, value in result.items()}
logits = model_roberta_zong_no_end(**result_on_device)
predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item()
# credit the no-following-context prediction under its own class id
if predicted_class_idx_qian not in res_score:
res_score[predicted_class_idx_qian] = sentence_zong_no_end[1]
else:
res_score[predicted_class_idx_qian] += sentence_zong_no_end[1]
-sentence_list = [paper_hou]
+sentence_list = [sentence_zong_no_start[0]]
# sentence_list = [data[1][0]]
result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
result_on_device = {key: value.to(device) for key, value in result.items()}
logits = model_roberta_zong_no_start(**result_on_device)
predicted_class_idx_hou = torch.argmax(logits[0], dim=1).item()
-id_2_len = {}
-for i in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]:
-if i not in id_2_len:
-id_2_len[i] = 1
-else:
-id_2_len[i] += 1
-queding = False
-predicted_class_idx = ""
-for i in id_2_len:
-if id_2_len[i] >= 2:
-queding = True
-predicted_class_idx = i
-break
-if queding == False:
-predicted_class_idx = 0
+# credit the no-preceding-context prediction under its own class id
+if predicted_class_idx_hou not in res_score:
+res_score[predicted_class_idx_hou] = sentence_zong_no_start[1]
+else:
+res_score[predicted_class_idx_hou] += sentence_zong_no_start[1]
+res_score_list = sorted(res_score.items(), key=lambda item: item[1], reverse=True)
+predicted_class_idx = res_score_list[0][0]
+# additional heading rule: decide by sentence length
+if predicted_class_idx == 0 and len(paper_sen) > 60:
+predicted_class_idx = 1
paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_fenji[predicted_class_idx]])
return paper_quanwen_lable_list
@@ -426,7 +575,103 @@ def gen_content_cls(content_list):
paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_content[predicted_class_idx]])
return paper_quanwen_lable_list
def split_lists_recursive(a, b, a_soc, b_soc, target_size=510, result_a=None, result_b=None):
"""
递归地同时分割两个列表保持一一对应关系
每个块尽量接近目标大小最后一个块确保有target_size个元素
Parameters:
a: 第一个列表
b: 第二个列表与a一一对应
target_size: 目标块大小
result_a: 递归使用的a的中间结果
result_b: 递归使用的b的中间结果
"""
if result_a is None:
result_a = []
if result_b is None:
result_b = []
# validate that the two lists have the same length
if len(a) != len(b):
raise ValueError("the two lists must have the same length")
total_elements = len(a)
# base case: the remaining elements fit within target_size
if total_elements <= target_size:
start = 0 - target_size
a_obj = a_soc[start:]
b_obj = b_soc[start:]
start_i = 0
# drop the leading fragment that belongs to the previous chunk: skip up to
# and including the first "-100" separator (no-op on the first call)
while True:
if result_a == []:
break
if start_i == len(a_obj):
break
if b_obj[start_i] == "-100":
start_i += 1
break
start_i += 1
if a != []:
result_a.append(a_obj[start_i:])
result_b.append(b_obj[start_i:])
return result_a, result_b
target_size_new = target_size
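# back up from target_size to the nearest "[SEP]" so the cut falls on a
# sentence boundary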
while True:
if a[target_size_new] == "[SEP]":
break
if target_size_new == 0:
break
target_size_new -= 1
a_current_chunk = a[:target_size_new]
b_current_chunk = b[:target_size_new]
# remaining part
# target_size = current_chunk_size
# advance past the "[SEP]"/"-100" separator so the next chunk starts at the
# first character of the following sentence
while True:
if target_size_new == len(a):
break
if a[target_size_new] != "[SEP]" and b[target_size_new] != "-100":
break
target_size_new += 1
a_remaining = a[target_size_new:]
b_remaining = b[target_size_new:]
if a_current_chunk != []:
result_a.append(a_current_chunk)
result_b.append(b_current_chunk)
# recurse on the remaining part
return split_lists_recursive(a_remaining, b_remaining, a_soc, b_soc, target_size, result_a, result_b)
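# ner_predict: run the token-classification (NER) model over one pre-split
# character chunk and keep one (character, label) pair per input position,
# using word_ids() to skip special tokens and repeated sub-tokens.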
def ner_predict(tokens):
# tokenize with the NER model's own tokenizer; truncate to the 512-token limit
inputs = tokenizer_ner(
tokens,
is_split_into_words=True,
truncation=True,
max_length=512,
return_tensors="pt"
).to(device)
with torch.no_grad():
outputs = model_data_title_roberta_ner(**inputs)
logits = outputs.logits
preds = logits.argmax(dim=-1)[0].tolist()
id2label = model_data_title_roberta_ner.config.id2label
word_ids = inputs.word_ids()
results = []
prev_word_id = None
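# keep only the first sub-token of each input character: word_ids() maps each
# model token back to its source position, and None marks special tokens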
for pred, word_id in zip(preds, word_ids):
if word_id is None or word_id == prev_word_id:
continue
results.append((tokens[word_id], id2label[pred]))
prev_word_id = word_id
return results
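# main pipeline: split the input into lines, coarsely classify each line,
# route headings through the character-level NER model to recover heading
# levels, tag body lines with gen_content_cls, then merge and re-sort
# everything by original line index.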
def main(content: str):
@@ -434,6 +679,7 @@ def main(content: str):
paper_content_list = [[i,j] for i,j in enumerate(content.split("\n"))]
# first classify, line by line, whether each sentence is a heading, body text, or useless
+print("classifying each sentence as heading / body text / useless")
zong_list = gen_zong_cls(paper_content_list)
# separate the heading data, body-text data, and useless-category data
@@ -445,14 +691,106 @@ def main(content: str):
title_data.append([data_dan[0], data_dan[1]])
if data_dan[2] == "正文":
content_data.append([data_dan[0], data_dan[1]])
# extract all the headings and classify the level of each heading
print("extracting all headings and classifying the level of each heading")
data_dan_sen = [i[1] for i in title_data]
data_dan_sen_index = [i[0] for i in title_data]
data_dan_sen_index_new = []
for i, j in zip(data_dan_sen_index, data_dan_sen):
linshi = [i] * len(j)
data_dan_sen_index_new.extend(linshi)
data_dan_sen_index_new.extend(["-100"])
data_dan_sen_new = []
for i in data_dan_sen:
linshi = list(i)
data_dan_sen_new.extend(linshi)
data_dan_sen_new.extend(["\n"])
data_dan_sen_index_new = data_dan_sen_index_new[:-1]
data_dan_sen_new = data_dan_sen_new[:-1]
data_dan_sen_new = ["[SEP]" if item == "\n" else item for item in data_dan_sen_new]
a_return1, b_return1 = split_lists_recursive(data_dan_sen_new, data_dan_sen_index_new, data_dan_sen_new, data_dan_sen_index_new,
target_size=510)
data_zong_train_list = []
for i, j in zip(a_return1, b_return1):
data_zong_train_list.append({
"tokens": i,
"tokens_index": j
})
title_list = []
for i in data_zong_train_list:
dan_data = ner_predict(i["tokens"])
dan_data_new = []
linshi_label = []
linshi_str = []
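# regroup the per-character NER output into headings: "[SEP]" closes the
# current heading, whose characters are joined back together and whose
# per-character labels (B-/I- prefix stripped) are deduplicated via set()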
for j in dan_data:
if j[0] != "[SEP]":
label = j[1][2:]
linshi_label.append(label)
linshi_str.append(j[0])
else:
linshi_label = list(set(linshi_label))
linshi_str = "".join(linshi_str)
dan_data_new.append([linshi_str, linshi_label])
if len(linshi_label) != 1:
baocuo = True
linshi_label = []
linshi_str = []
if linshi_str != []:
linshi_str = "".join(linshi_str)
linshi_label = list(set(linshi_label))
dan_data_new.append([linshi_str, linshi_label])
linshi_label = []
linshi_str = []
# data_dan_sen_index_new = [set(ii)[0] for ii in "".join(i["tokens_index"]).split("-100")]
data_dan_sen_index_new = []
linshi = []
for ii in i["tokens_index"]:
if ii == "-100":
data_dan_sen_index_new.append(list(set(linshi))[0])
linshi = []
else:
linshi.append(ii)
if linshi != []:
data_dan_sen_index_new.append(list(set(linshi))[0])
if len(dan_data_new) == len(data_dan_sen_index_new):
for ii, jj in zip(data_dan_sen_index_new, dan_data_new):
sen = jj[0]
label = jj[1][0]
title_list.append([ii, sen, label])
title_data_dict = {}
for i in title_list:
if i[0] not in title_data_dict:
title_data_dict[i[0]] = [[i[1], i[2]]]
else:
title_data_dict[i[0]] += [[i[1], i[2]]]
print(title_data_dict)
-# extract all the headings and classify the level of each heading
-print("extracting all headings and classifying the level of each heading")
-title_list = gen_title_cls(title_data)
+# title_list = gen_title_cls(title_data)
# extract all the body-text lines and tag each one
print("extracting all body-text lines and tagging each one")
content_list = gen_content_cls(content_data)
paper_content_list_new = title_list + content_list
# merge and sort back into document order
print("sorting merged results back into document order")
paper_content_list_new = sorted(paper_content_list_new, key=lambda item: item[0])
paper_content_info_list = []
@@ -467,7 +805,7 @@ def main(content: str):
return paper_content_info_list
@app.route("/predict", methods=["POST"])
-def search():
+def predict():
print(request.remote_addr)
content = request.json["content"]
response = main(content)
