From cbb165dfa9987fc3cc751561d8657ea6339c99f9 Mon Sep 17 00:00:00 2001
From: "majiahui@haimaqingfan.com" <majiahui@haimaqingfan.com>
Date: Mon, 12 Jan 2026 16:02:53 +0800
Subject: [PATCH] First official release
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 flask_api.py | 446 +++++++++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 392 insertions(+), 54 deletions(-)

diff --git a/flask_api.py b/flask_api.py
index 2636f0b..9943fd8 100644
--- a/flask_api.py
+++ b/flask_api.py
@@ -31,6 +31,7 @@ import transformers
 from transformers import (
     AutoConfig,
     AutoModelForSequenceClassification,
+    AutoModelForTokenClassification,
     AutoTokenizer,
 )
 import torch
@@ -78,6 +79,15 @@ for i in lable_2_id_fenji:
     if lable_2_id_fenji[i] not in id_2_lable_fenji:
         id_2_lable_fenji[lable_2_id_fenji[i]] = i
 
+# Previous 7-label title map, kept for reference:
+# lable_2_id_title = {
+#     "一级标题": 0,
+#     "二级标题": 1,
+#     "三级标题": 2,
+#     "中文摘要标题": 3,
+#     "致谢标题": 4,
+#     "英文摘要标题": 5,
+#     "参考文献标题": 6
+# }
 lable_2_id_title = {
     "一级标题": 0,
     "二级标题": 1,
@@ -85,7 +95,9 @@ lable_2_id_title = {
     "中文摘要标题": 3,
     "致谢标题": 4,
     "英文摘要标题": 5,
-    "参考文献标题": 6
+    "参考文献标题": 6,
+    "四级标题": 7,
+    "非标题类型": 8
 }
 
 id_2_lable_title = {}
@@ -105,20 +117,31 @@ lable_2_id_content = {
     "参考文献": 7
 }
 
+
 id_2_lable_content = {}
 for i in lable_2_id_content:
     if lable_2_id_content[i] not in id_2_lable_content:
         id_2_lable_content[lable_2_id_content[i]] = i
 
+lable_2_id_title_no_title = {
+    "正文": 0,
+    "标题": 1
+}
+
+id_2_lable_title_no_title = {}
+for i in lable_2_id_title_no_title:
+    if lable_2_id_title_no_title[i] not in id_2_lable_title_no_title:
+        id_2_lable_title_no_title[lable_2_id_title_no_title[i]] = i
+
 tokenizer = AutoTokenizer.from_pretrained(
-    "data_zong_roberta",
+    "data_zong_shout_3",
     use_fast=True,
     revision="main",
     trust_remote_code=False,
 )
 
-model_name = "data_zong_roberta"
+model_name = "data_zong_shout_3"
 config = AutoConfig.from_pretrained(
     model_name,
     num_labels=len(lable_2_id_fenji),
@@ -133,7 +156,7 @@ model_roberta_zong = AutoModelForSequenceClassification.from_pretrained(
     ignore_mismatched_sizes=False,
 ).to(device)
 
-model_name = "data_zong_roberta_no_start"
+model_name = "data_zong_no_start_shout_3"
 config = AutoConfig.from_pretrained(
     model_name,
     num_labels=len(lable_2_id_fenji),
@@ -148,7 +171,7 @@ model_roberta_zong_no_start = AutoModelForSequenceClassification.from_pretrained
     ignore_mismatched_sizes=False,
 ).to(device)
 
-model_name = "data_zong_roberta_no_end"
+model_name = "data_zong_no_end_shout_3"
 config = AutoConfig.from_pretrained(
     model_name,
     num_labels=len(lable_2_id_fenji),
@@ -163,14 +186,29 @@ model_roberta_zong_no_end = AutoModelForSequenceClassification.from_pretrained(
     ignore_mismatched_sizes=False,
 ).to(device)
 
-model_name = "data_title_roberta"
+# Old title classifier, commented out in this version:
+# model_name = "data_title_roberta"
+# config = AutoConfig.from_pretrained(
+#     model_name,
+#     num_labels=len(lable_2_id_title),
+#     revision="main",
+#     trust_remote_code=False
+# )
+# model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
+#     model_name,
+#     config=config,
+#     revision="main",
+#     trust_remote_code=False,
+#     ignore_mismatched_sizes=False,
+# ).to(device)
+
+model_name = "data_content_roberta"
 config = AutoConfig.from_pretrained(
     model_name,
-    num_labels=len(lable_2_id_title),
+    num_labels=len(lable_2_id_content),
     revision="main",
     trust_remote_code=False
 )
-model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
+model_content_roberta = AutoModelForSequenceClassification.from_pretrained(
     model_name,
     config=config,
     revision="main",
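# Editor's sketch (not part of the patch): the loading pattern repeated above,
# factored into one helper. The helper name and default device are assumptions;
# flask_api.py loads each local fine-tuned checkpoint inline instead.
from transformers import AutoConfig, AutoModelForSequenceClassification

def load_cls_model(model_path, label_2_id, device="cuda"):
    # num_labels must match the fine-tuned head for the weights to load cleanly
    config = AutoConfig.from_pretrained(model_path, num_labels=len(label_2_id))
    model = AutoModelForSequenceClassification.from_pretrained(model_path, config=config)
    return model.to(device).eval()

# e.g.: model_roberta_zong = load_cls_model("data_zong_shout_3", lable_2_id_fenji)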
@@ -178,14 +216,14 @@ model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
     ignore_mismatched_sizes=False,
 ).to(device)
 
-model_name = "data_content_roberta"
+model_name = "data_content_roberta_no_end"
 config = AutoConfig.from_pretrained(
     model_name,
     num_labels=len(lable_2_id_content),
     revision="main",
     trust_remote_code=False
 )
-model_content_roberta = AutoModelForSequenceClassification.from_pretrained(
+model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretrained(
     model_name,
     config=config,
     revision="main",
@@ -193,14 +231,14 @@ model_content_roberta = AutoModelForSequenceClassification.from_pretrained(
     ignore_mismatched_sizes=False,
 ).to(device)
 
-model_name = "data_content_roberta_no_end"
+model_name = "data_content_roberta_no_start"
 config = AutoConfig.from_pretrained(
     model_name,
     num_labels=len(lable_2_id_content),
     revision="main",
     trust_remote_code=False
 )
-model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretrained(
+model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrained(
     model_name,
     config=config,
     revision="main",
@@ -208,14 +246,14 @@ model_content_roberta_no_end = AutoModelForSequenceClassification.from_pretraine
     ignore_mismatched_sizes=False,
 ).to(device)
 
-model_name = "data_content_roberta_no_start"
+model_name = "data_title_roberta_2"
 config = AutoConfig.from_pretrained(
     model_name,
-    num_labels=len(lable_2_id_content),
+    num_labels=len(lable_2_id_title),
     revision="main",
     trust_remote_code=False
 )
-model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrained(
+model_title_roberta = AutoModelForSequenceClassification.from_pretrained(
     model_name,
     config=config,
     revision="main",
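# Editor's sketch (not part of the patch): how any one of the sequence classifiers
# above is queried, mirroring the tokenize -> logits -> argmax -> label-map steps
# used in gen_zong_cls and gen_content_cls below. `classify` is an illustrative
# name; `tokenizer`, `device`, and `torch` are the globals already defined in flask_api.py.
def classify(sentence, model, id_2_lable):
    batch = tokenizer([sentence], padding="max_length", max_length=512,
                      truncation=True, return_tensors="pt")
    batch = {key: value.to(device) for key, value in batch.items()}
    with torch.no_grad():
        logits = model(**batch)
    # logits[0] is the (1, num_labels) score tensor of the model output
    predicted_class_idx = torch.argmax(logits[0], dim=1).item()
    return id_2_lable[predicted_class_idx]

# e.g.: classify("1.1 研究背景", model_title_roberta, id_2_lable_title)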
@@ -223,73 +261,184 @@ model_content_roberta_no_start = AutoModelForSequenceClassification.from_pretrai
     ignore_mismatched_sizes=False,
 ).to(device)
 
+model_name = "data_title_no_title_roberta_2"
+config = AutoConfig.from_pretrained(
+    model_name,
+    num_labels=len(lable_2_id_title_no_title),
+    revision="main",
+    trust_remote_code=False
+)
+model_data_title_no_title_roberta = AutoModelForSequenceClassification.from_pretrained(
+    model_name,
+    config=config,
+    revision="main",
+    trust_remote_code=False,
+    ignore_mismatched_sizes=False,
+).to(device)
+
+# Token-classification (NER) model used to split and label title spans
+model_name = "data_title_roberta_ner_2"
+tokenizer_ner = AutoTokenizer.from_pretrained(model_name)
+model_data_title_roberta_ner = AutoModelForTokenClassification.from_pretrained(model_name)
+model_data_title_roberta_ner.eval().to(device)
+
+
 def gen_zong_cls(content_list):
     paper_quanwen_lable_list = []
     for index, paper_sen in content_list:
         # Context window: 7 sentences before and after (old fixed window)
-        paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[max(index - 7, 0):index]]
-        paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:index + 8]]
+        # paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[max(index - 7, 0):index]]
+        # paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:index + 8]]
         # print(len(paper_start_list))
         # print(len(paper_end_list))
-        paper_new_start = "\n".join(paper_start_list)
-        paper_new_end = "\n".join(paper_end_list)
-        paper_object_dangqian = "" + paper_sen + ""
-        paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])
+        # paper_new_start = "\n".join(paper_start_list)
+        # paper_new_end = "\n".join(paper_end_list)
+        # paper_object_dangqian = "" + paper_sen + ""
+        # paper_zhong = "\n".join([paper_new_start, paper_object_dangqian, paper_new_end])
+
+        # Context on both sides: expand the window one sentence per side
+        # until the joined text would exceed 510 characters
+        start_index = index
+        left_end = 0
+        right_end = len(content_list) - 1
+        left = start_index
+        right = start_index
+        left_end_bool = True
+        right_end_bool = True
+        old_sen = "" + paper_sen[:30] + ""
+        while True:
+            if left - 1 >= left_end:
+                left = left - 1
+            else:
+                left_end_bool = False
+            if right + 1 <= right_end:
+                right = right + 1
+            else:
+                right_end_bool = False
 
-        # 15 preceding sentences of context
-        paper_start_list = [paper_sen[:30] for _, paper_sen in content_list[max(index - 15, 0):index]]
-        # print(len(paper_start_list))
-        paper_new_start = "\n".join(paper_start_list)
-        paper_object_dangqian = "" + paper_sen + ""
-        paper_qian = "\n".join([paper_new_start, paper_object_dangqian])
+            new_sen_list = [old_sen]
+            if left_end_bool == True:
+                new_sen_list = [content_list[left][1][:30]] + new_sen_list
+            if right_end_bool == True:
+                new_sen_list = new_sen_list + [content_list[right][1][:30]]
 
-        # 15 following sentences of context
-        paper_end_list = [paper_sen[:30] for _, paper_sen in content_list[index + 1:index + 16]]
-        # print(len(paper_end_list))
-        paper_new_end = "\n".join(paper_end_list)
-        paper_object_dangqian = "" + paper_sen + ""
-        paper_hou = "\n".join([paper_object_dangqian, paper_new_end])
+            new_sen = "\n".join(new_sen_list)
+            if len(new_sen) > 510 or left_end_bool == False or right_end_bool == False:
+                break
+            else:
+                old_sen = new_sen
+        len_sen = len(old_sen.split("\n"))
+        sentence_zong_zhong = [old_sen, len_sen]
+
+        # No following context (document tail): expand to the left only
+        start_index = index
+        left_end = 0
+        right_end = start_index
+        left = start_index
+        right = start_index
+        left_end_bool = True
+        right_end_bool = True
+        old_sen = "" + paper_sen[:30] + ""
+        while True:
+            if left - 1 >= left_end:
+                left = left - 1
+            else:
+                left_end_bool = False
+            if right + 1 <= right_end:
+                right = right + 1
+            else:
+                right_end_bool = False
+
+            new_sen_list = [old_sen]
+            if left_end_bool == True:
+                new_sen_list = [content_list[left][1][:30]] + new_sen_list
+            if right_end_bool == True:
+                new_sen_list = new_sen_list + [content_list[right][1][:30]]
+
+            new_sen = "\n".join(new_sen_list)
+            if len(new_sen) > 510 or left_end_bool == False:
+                break
+            else:
+                old_sen = new_sen
+        len_sen = len(old_sen.split("\n"))
+        sentence_zong_no_end = [old_sen, len_sen]
+
+        # No preceding context (document head): expand to the right only
+        start_index = index
+        left_end = start_index
+        right_end = len(content_list) - 1
+        left = start_index
+        right = start_index
+        left_end_bool = True
+        right_end_bool = True
+        old_sen = "" + paper_sen[:30] + ""
+        while True:
+            if left - 1 >= left_end:
+                left = left - 1
+            else:
+                left_end_bool = False
+            if right + 1 <= right_end:
+                right = right + 1
+            else:
+                right_end_bool = False
+            new_sen_list = [old_sen]
+            if left_end_bool == True:
+                new_sen_list = [content_list[left][1][:30]] + new_sen_list
+            if right_end_bool == True:
+                new_sen_list = new_sen_list + [content_list[right][1][:30]]
+
+            new_sen = "\n".join(new_sen_list)
+            if len(new_sen) > 510 or right_end_bool == False:
+                break
+            else:
+                old_sen = new_sen
+
+        len_sen = len(old_sen.split("\n"))
+        sentence_zong_no_start = [old_sen, len_sen]
+
+        res_score = {}
 
         # Prediction with the target sentence in the middle of its context
-        sentence_list = [paper_zhong]
+        sentence_list = [sentence_zong_zhong[0]]
         # sentence_list = [data[1][0]]
         result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
         result_on_device = {key: value.to(device) for key, value in result.items()}
         logits = model_roberta_zong(**result_on_device)
         predicted_class_idx_zhong = torch.argmax(logits[0], dim=1).item()
+        # Weight each model's vote by the number of sentences its window could see
+        if predicted_class_idx_zhong not in res_score:
+            res_score[predicted_class_idx_zhong] = sentence_zong_zhong[1]
+        else:
+            res_score[predicted_class_idx_zhong] += sentence_zong_zhong[1]
 
-        sentence_list = [paper_qian]
+        sentence_list = [sentence_zong_no_end[0]]
         # sentence_list = [data[1][0]]
         result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
         result_on_device = {key: value.to(device) for key, value in result.items()}
         logits = model_roberta_zong_no_end(**result_on_device)
         predicted_class_idx_qian = torch.argmax(logits[0], dim=1).item()
+        if predicted_class_idx_qian not in res_score:
+            res_score[predicted_class_idx_qian] = sentence_zong_no_end[1]
+        else:
+            res_score[predicted_class_idx_qian] += sentence_zong_no_end[1]
 
-        sentence_list = [paper_hou]
+        sentence_list = [sentence_zong_no_start[0]]
         # sentence_list = [data[1][0]]
         result = tokenizer(sentence_list, padding="max_length", max_length=512, truncation=True, return_tensors="pt")
         result_on_device = {key: value.to(device) for key, value in result.items()}
         logits = model_roberta_zong_no_start(**result_on_device)
         predicted_class_idx_hou = torch.argmax(logits[0], dim=1).item()
+        if predicted_class_idx_hou not in res_score:
+            res_score[predicted_class_idx_hou] = sentence_zong_no_start[1]
+        else:
+            res_score[predicted_class_idx_hou] += sentence_zong_no_start[1]
 
-        id_2_len = {}
-        for i in [predicted_class_idx_qian, predicted_class_idx_hou, predicted_class_idx_zhong]:
-            if i not in id_2_len:
-                id_2_len[i] = 1
-            else:
-                id_2_len[i] += 1
+        # The class with the highest weighted vote wins
+        res_score_list = sorted(res_score.items(), key=lambda item: item[1], reverse=True)
+        predicted_class_idx = res_score_list[0][0]
 
-        queding = False
-        predicted_class_idx = ""
-        for i in id_2_len:
-            if id_2_len[i] >= 2:
-                queding = True
-                predicted_class_idx = i
-                break
+        # Added title rule: sentences longer than 60 characters are moved from class 0 to class 1
+        if predicted_class_idx == 0 and len(paper_sen) > 60:
+            predicted_class_idx = 1
 
-        if queding == False:
-            predicted_class_idx = 0
         paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_fenji[predicted_class_idx]])
     return paper_quanwen_lable_list
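# Editor's sketch (not part of the patch): the expanding-window idea used three
# times in gen_zong_cls above, as one self-contained helper. It grows the context
# one truncated sentence per side until the joined text would pass max_chars or a
# document edge is hit; unlike the patch, this variant keeps growing on the side
# that still has sentences. `build_window` and its parameters are illustrative names.
def build_window(sentences, index, max_chars=510, clip=30):
    left = right = index
    window = sentences[index][:clip]
    while left > 0 or right < len(sentences) - 1:
        new_left = max(left - 1, 0)
        new_right = min(right + 1, len(sentences) - 1)
        parts = (
            ([sentences[new_left][:clip]] if new_left < left else [])
            + [window]
            + ([sentences[new_right][:clip]] if new_right > right else [])
        )
        candidate = "\n".join(parts)
        if len(candidate) > max_chars:
            break
        window, left, right = candidate, new_left, new_right
    return window

# e.g.: build_window([sen for _, sen in content_list], index) for the "middle" view;
# passing sentences[:index + 1] gives the "no following context" view.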
@@ -426,7 +575,103 @@ def gen_content_cls(content_list):
         paper_quanwen_lable_list.append([index, paper_sen, id_2_lable_content[predicted_class_idx]])
     return paper_quanwen_lable_list
 
-
+def split_lists_recursive(a, b, a_soc, b_soc, target_size=510, result_a=None, result_b=None):
+    """
+    Recursively split two parallel lists while keeping their one-to-one alignment.
+    Each chunk stays close to the target size, and the final chunk is re-read from
+    the source lists so that it still carries target_size elements of context.
+
+    Parameters:
+        a: first list
+        b: second list, aligned one-to-one with a
+        a_soc: full (unsplit) source version of a, used to pad out the final chunk
+        b_soc: full (unsplit) source version of b
+        target_size: target chunk size
+        result_a: intermediate chunks of a (used by the recursion)
+        result_b: intermediate chunks of b (used by the recursion)
+    """
+    if result_a is None:
+        result_a = []
+    if result_b is None:
+        result_b = []
+
+    # The two lists must stay aligned
+    if len(a) != len(b):
+        raise ValueError("The two lists must have the same length")
+
+    total_elements = len(a)
+
+    # Base case: the remainder fits into a single chunk; re-read the tail of the
+    # source lists and trim the leading partial sentence (everything up to the
+    # first "-100" boundary marker)
+    if total_elements <= target_size:
+        start = 0 - target_size
+        a_obj = a_soc[start:]
+        b_obj = b_soc[start:]
+        start_i = 0
+        while True:
+            if result_a == []:
+                break
+            if start_i == len(a_obj):
+                break
+            if b_obj[start_i] == "-100":
+                start_i += 1
+                break
+            start_i += 1
+        if a != []:
+            result_a.append(a_obj[start_i:])
+            result_b.append(b_obj[start_i:])
+        return result_a, result_b
+
+    # Walk backwards from target_size to the nearest sentence boundary ("[SEP]")
+    # so that no sentence is cut in half
+    target_size_new = target_size
+    while True:
+        if a[target_size_new] == "[SEP]":
+            break
+        if target_size_new == 0:
+            break
+        target_size_new -= 1
+    a_current_chunk = a[:target_size_new]
+    b_current_chunk = b[:target_size_new]
+
+    # Skip past the "-100" boundary markers so the remainder starts at the first
+    # element of the next sentence
+    while True:
+        if target_size_new == len(a):
+            break
+        if b[target_size_new] != "-100":
+            break
+        target_size_new += 1
+    a_remaining = a[target_size_new:]
+    b_remaining = b[target_size_new:]
+
+    if a_current_chunk != []:
+        result_a.append(a_current_chunk)
+        result_b.append(b_current_chunk)
+
+    # Recurse on the remainder
+    return split_lists_recursive(a_remaining, b_remaining, a_soc, b_soc, target_size, result_a, result_b)
+
+def ner_predict(tokens):
+    # Use the NER tokenizer so word_ids() lines up with the token-classification model
+    inputs = tokenizer_ner(
+        tokens,
+        is_split_into_words=True,
+        return_tensors="pt"
+    ).to(device)
+
+    with torch.no_grad():
+        outputs = model_data_title_roberta_ner(**inputs)
+
+    logits = outputs.logits
+    preds = logits.argmax(dim=-1)[0].tolist()
+
+    id2label = model_data_title_roberta_ner.config.id2label
+
+    word_ids = inputs.word_ids()
+    results = []
+    prev_word_id = None
+
+    # Keep one prediction per input token (skip special tokens and extra sub-words)
+    for pred, word_id in zip(preds, word_ids):
+        if word_id is None or word_id == prev_word_id:
+            continue
+        results.append((tokens[word_id], id2label[pred]))
+        prev_word_id = word_id
+
+    return results
 
 def main(content: str):
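# Editor's note (not part of the patch): a toy trace of split_lists_recursive.
# `a` holds characters with "[SEP]" between sentences; `b` holds the owning
# sentence index per character, with "-100" at each boundary. With target_size=6
# the cut lands on the "[SEP]" so no sentence is torn apart, and the tail chunk is
# re-read from the source lists. The values below are illustrative only.
chars = list("abc") + ["[SEP]"] + list("de") + ["[SEP]"] + list("fgh")
idxs = [0, 0, 0, "-100", 1, 1, "-100", 2, 2, 2]
chunks_a, chunks_b = split_lists_recursive(chars, idxs, chars, idxs, target_size=6)
# chunks_a == [['a', 'b', 'c', '[SEP]', 'd', 'e'], ['f', 'g', 'h']]
# chunks_b == [[0, 0, 0, '-100', 1, 1], [2, 2, 2]]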
@@ -434,6 +679,7 @@ def main(content: str):
     paper_content_list = [[i,j] for i,j in enumerate(content.split("\n"))]
 
     # First pass: classify each sentence as title, body text, or a useless category
+    print("First pass: classifying each sentence as title / body text / useless")
     zong_list = gen_zong_cls(paper_content_list)
 
     # Split the results into title data, body-text data, and useless-category data
@@ -445,14 +691,106 @@ def main(content: str):
             title_data.append([data_dan[0], data_dan[1]])
         if data_dan[2] == "正文":
             content_data.append([data_dan[0], data_dan[1]])
+
+    # Extract all the titles and determine the level of each one
+    print("Extracting titles and determining the level of each one")
+
+    data_dan_sen = [i[1] for i in title_data]
+    data_dan_sen_index = [i[0] for i in title_data]
+    # One owning-sentence index per character, with "-100" between sentences
+    data_dan_sen_index_new = []
+    for i, j in zip(data_dan_sen_index, data_dan_sen):
+        linshi = [i] * len(j)
+        data_dan_sen_index_new.extend(linshi)
+        data_dan_sen_index_new.extend(["-100"])
+
+    data_dan_sen_new = []
+    for i in data_dan_sen:
+        linshi = list(i)
+        data_dan_sen_new.extend(linshi)
+        data_dan_sen_new.extend(["\n"])
+
+    # Drop the trailing separator and mark sentence boundaries with "[SEP]"
+    data_dan_sen_index_new = data_dan_sen_index_new[:-1]
+    data_dan_sen_new = data_dan_sen_new[:-1]
+    data_dan_sen_new = ["[SEP]" if item == "\n" else item for item in data_dan_sen_new]
+    a_return1, b_return1 = split_lists_recursive(data_dan_sen_new, data_dan_sen_index_new, data_dan_sen_new, data_dan_sen_index_new,
+                                                 target_size=510)
+
+    data_zong_train_list = []
+    for i, j in zip(a_return1, b_return1):
+        data_zong_train_list.append({
+            "tokens": i,
+            "tokens_index": j
+        })
+
+    title_list = []
+    for i in data_zong_train_list:
+        dan_data = ner_predict(i["tokens"])
+        dan_data_new = []
+        linshi_label = []
+        linshi_str = []
+        for j in dan_data:
+            if j[0] != "[SEP]":
+                label = j[1][2:]  # strip the "B-"/"I-" prefix from the BIO tag
+                linshi_label.append(label)
+                linshi_str.append(j[0])
+            else:
+                linshi_label = list(set(linshi_label))
+                linshi_str = "".join(linshi_str)
+                dan_data_new.append([linshi_str, linshi_label])
+                if len(linshi_label) != 1:
+                    baocuo = True  # flag: one title span was tagged with more than one label
+                linshi_label = []
+                linshi_str = []
+
+        if linshi_str != []:
+            linshi_str = "".join(linshi_str)
+            linshi_label = list(set(linshi_label))
+            dan_data_new.append([linshi_str, linshi_label])
+            linshi_label = []
+            linshi_str = []
+
+        # data_dan_sen_index_new = [set(ii)[0] for ii in "".join(i["tokens_index"]).split("-100")]
+        # Recover one sentence index per title span from the "-100"-separated index list
+        data_dan_sen_index_new = []
+        linshi = []
+        for ii in i["tokens_index"]:
+            if ii == "-100":
+                data_dan_sen_index_new.append(list(set(linshi))[0])
+                linshi = []
+            else:
+                linshi.append(ii)
+
+        if linshi != []:
+            data_dan_sen_index_new.append(list(set(linshi))[0])
+
+        if len(dan_data_new) == len(data_dan_sen_index_new):
+            for ii, jj in zip(data_dan_sen_index_new, dan_data_new):
+                sen = jj[0]
+                label = jj[1][0]
+                title_list.append([ii, sen, label])
+
+    title_data_dict = {}
+    for i in title_list:
+        if i[0] not in title_data_dict:
+            title_data_dict[i[0]] = [[i[1], i[2]]]
+        else:
+            title_data_dict[i[0]] += [[i[1], i[2]]]
+
+    print(title_data_dict)
+
     # Extract all the titles and determine the level of each one
-    title_list = gen_title_cls(title_data)
+    # (handled above by the NER pass; the old classifier call is kept for reference)
+    # title_list = gen_title_cls(title_data)
 
     # Extract all the body-text sentences and label them one by one
+    print("Labelling body-text sentences one by one")
     content_list = gen_content_cls(content_data)
 
     paper_content_list_new = title_list + content_list
 
     # Merge and sort by original position
+    print("Merging and sorting")
     paper_content_list_new = sorted(paper_content_list_new, key=lambda item: item[0])
 
     paper_content_info_list = []
@@ -467,7 +805,7 @@ def main(content: str):
     return paper_content_info_list
 
 @app.route("/predict", methods=["POST"])
-def search():
+def predict():
     print(request.remote_addr)
     content = request.json["content"]
     response = main(content)
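# Editor's sketch (not part of the patch): calling the renamed /predict route.
# Host and port are assumptions (Flask's default is 5000); the response mirrors
# paper_content_info_list as built in main().
import requests

resp = requests.post(
    "http://127.0.0.1:5000/predict",
    json={"content": "中文摘要\n本文研究了……\n第一章 绪论\n这是正文的第一句。"},
)
print(resp.json())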