diff --git a/11篇_xlsx2txt.py b/11篇_xlsx2txt.py
index a6348b8..6ddd112 100644
--- a/11篇_xlsx2txt.py
+++ b/11篇_xlsx2txt.py
@@ -13,11 +13,11 @@
 def biaot(biao):
     biao_len = 8 - len(biao)
     return biao + " " * biao_len + ":"
-biaoti = ["model_1", "model_2", "model_3"]
+biaoti = ["model_1", "model_2", "model_3", "model_4", "model_5"]
 import pandas as pd
 import os
-path = './data/11篇测试excel_汇总_2'
+path = './data/11篇测试excel_汇总_4'
 path_list = []
 for file_name in os.listdir(path):
     path_list.append(file_name)
@@ -45,10 +45,19 @@ for file_name in path_list:
     txt = data_one[3]
     txt_list.append(biaoti_one + txt)
 
+    biaoti_one = biaot(biaoti[3])
+    txt = data_one[4]
+    txt_list.append(biaoti_one + txt)
+
+    biaoti_one = biaot(biaoti[4])
+    txt = data_one[5]
+    txt_list.append(biaoti_one + txt)
+
     txt_list.append("\n")
-    with open("./data/11篇测试txt_汇总_1/{}.txt".format(file_name_0), "w", encoding='utf-8') as file:
+    with open("./data/11篇测试txt_汇总_2/{}.txt".format(file_name_0), "w", encoding='utf-8') as file:
         for i in txt_list:
             file.write(i + '\n')
         file.close()
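Note: biaot() above pads each model label to a fixed width of 8 characters before the colon so the txt columns line up. A minimal equivalent using str.ljust (sketch only, not part of the diff):

    def biaot(biao: str) -> str:
        # pad the label to 8 characters, then append the column separator
        return biao.ljust(8) + ":"

    assert biaot("model_1") == "model_1 :"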
+df.to_excel("../data/11篇_yy_strsim.xlsx", index=None) \ No newline at end of file diff --git a/data_do/11篇t5预测strsim排序.py b/data_do/11篇t5预测strsim排序.py new file mode 100644 index 0000000..359fe50 --- /dev/null +++ b/data_do/11篇t5预测strsim排序.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- + +""" +@Time : 2023/2/27 18:24 +@Author : +@FileName: +@Software: +@Describe: +""" +import pandas as pd +import difflib + +file = "../data/11篇汇总txt_new_predict_t5.txt" +try: + with open(file, 'r', encoding="utf-8") as f: + lines = [x.strip() for x in f if x.strip() != ''] +except: + with open(file, 'r', encoding="gbk") as f: + lines = [x.strip() for x in f if x.strip() != ''] + +data_new = [] +for i in lines: + data_dan = i.split("\t") + if len(data_dan) != 2: + continue + data_1 = data_dan[0] + data_2 = data_dan[1] + str_sim_value = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio() + data_new.append(data_dan + [str_sim_value]) +print(data_new) +data_new = sorted(data_new, key= lambda x:x[2], reverse=True) +df = pd.DataFrame(data_new) +df.to_excel("../data/11篇_t5_strsim.xlsx", index=None) \ No newline at end of file diff --git a/data_do/合并数据.py b/data_do/合并数据.py index 82a6d98..9d71ad2 100644 --- a/data_do/合并数据.py +++ b/data_do/合并数据.py @@ -21,10 +21,11 @@ def read_text(file): if __name__ == '__main__': data = [] - path_list = ["train_yy_sim_10.txt", "train_yy_1_sim_10.txt"] + # path_list = ["train_yy_sim_10.txt", "train_yy_1_sim_10.txt"] + path_list = ["../data/train_yy.txt", "../data/train_yy_1.txt"] for i in path_list: data += read_text(i) - fileName = '../data/train_yy_sim.txt' + fileName = '../data/train_yy_zong.txt' with open(fileName, 'w', encoding='utf-8') as file: for i in data: file.write(str(i) + '\n') diff --git a/data_do/处理11篇yy数据.py b/data_do/处理11篇yy数据.py new file mode 100644 index 0000000..8aa84f9 --- /dev/null +++ b/data_do/处理11篇yy数据.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- + +""" +@Time : 2022/12/20 10:35 +@Author : +@FileName: +@Software: +@Describe: +""" +import os +from bs4 import BeautifulSoup +import pandas as pd +import re +# 遍历文件夹 + + + +yuanshi = "../data/11篇yy/paperyyreduce20230221120936.html" +soup_source = BeautifulSoup(open(yuanshi, encoding='utf-8'), + "html.parser") + +yyshuju = "../data/11篇yy/paperyyreduce_result20230221120936" +soup_result = BeautifulSoup(open(yyshuju, encoding='utf-8'), + "html.parser") + +source_sentence_list = soup_source.select('p > em') +result_sentence_list = soup_result.select('p > em') + + +data = [] +for sentence_index in range(len(source_sentence_list)): + try: + print(source_sentence_list[sentence_index]["id"]) + print(result_sentence_list[sentence_index]["id"]) + print(result_sentence_list[sentence_index]["class"]) + if source_sentence_list[sentence_index]["id"] == result_sentence_list[sentence_index]["id"] \ + and (result_sentence_list[sentence_index]["class"] == ['similar','red'] + or result_sentence_list[sentence_index]["class"] == ['similar']): + # if source_sentence_list[sentence_index]["id"] == result_sentence_list[sentence_index]["id"]: + source_text = source_sentence_list[sentence_index].string + result_text = result_sentence_list[sentence_index].string + source_text = source_text.strip("\n") + result_text = result_text.strip("\n") + if source_text != None and result_text != None: + data.append([source_text,result_text]) + except: + print(sentence_index) + + # print(data) + + +def data_clean(text): + # 清洗excel中的非法字符,都是不常见的不可显示字符,例如退格,响铃等 + ILLEGAL_CHARACTERS_RE = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]') + text = 
diff --git a/data_do/11篇t5预测strsim排序.py b/data_do/11篇t5预测strsim排序.py
new file mode 100644
index 0000000..359fe50
--- /dev/null
+++ b/data_do/11篇t5预测strsim排序.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2023/2/27 18:24
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+import pandas as pd
+import difflib
+
+file = "../data/11篇汇总txt_new_predict_t5.txt"
+try:
+    with open(file, 'r', encoding="utf-8") as f:
+        lines = [x.strip() for x in f if x.strip() != '']
+except UnicodeDecodeError:
+    with open(file, 'r', encoding="gbk") as f:
+        lines = [x.strip() for x in f if x.strip() != '']
+
+data_new = []
+for i in lines:
+    data_dan = i.split("\t")
+    if len(data_dan) != 2:
+        continue
+    data_1 = data_dan[0]
+    data_2 = data_dan[1]
+    str_sim_value = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio()
+    data_new.append(data_dan + [str_sim_value])
+print(data_new)
+data_new = sorted(data_new, key=lambda x: x[2], reverse=True)
+df = pd.DataFrame(data_new)
+df.to_excel("../data/11篇_t5_strsim.xlsx", index=None)
\ No newline at end of file
diff --git a/data_do/合并数据.py b/data_do/合并数据.py
index 82a6d98..9d71ad2 100644
--- a/data_do/合并数据.py
+++ b/data_do/合并数据.py
@@ -21,10 +21,11 @@ def read_text(file):
 
 if __name__ == '__main__':
     data = []
-    path_list = ["train_yy_sim_10.txt", "train_yy_1_sim_10.txt"]
+    # path_list = ["train_yy_sim_10.txt", "train_yy_1_sim_10.txt"]
+    path_list = ["../data/train_yy.txt", "../data/train_yy_1.txt"]
     for i in path_list:
         data += read_text(i)
-    fileName = '../data/train_yy_sim.txt'
+    fileName = '../data/train_yy_zong.txt'
     with open(fileName, 'w', encoding='utf-8') as file:
         for i in data:
             file.write(str(i) + '\n')
diff --git a/data_do/处理11篇yy数据.py b/data_do/处理11篇yy数据.py
new file mode 100644
index 0000000..8aa84f9
--- /dev/null
+++ b/data_do/处理11篇yy数据.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2022/12/20 10:35
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+import os
+from bs4 import BeautifulSoup
+import pandas as pd
+import re
+# walk the paired source/result reports and align their sentences
+
+
+yuanshi = "../data/11篇yy/paperyyreduce20230221120936.html"
+soup_source = BeautifulSoup(open(yuanshi, encoding='utf-8'), "html.parser")
+
+yyshuju = "../data/11篇yy/paperyyreduce_result20230221120936"
+soup_result = BeautifulSoup(open(yyshuju, encoding='utf-8'), "html.parser")
+
+source_sentence_list = soup_source.select('p > em')
+result_sentence_list = soup_result.select('p > em')
+
+
+data = []
+for sentence_index in range(len(source_sentence_list)):
+    try:
+        print(source_sentence_list[sentence_index]["id"])
+        print(result_sentence_list[sentence_index]["id"])
+        print(result_sentence_list[sentence_index]["class"])
+        if source_sentence_list[sentence_index]["id"] == result_sentence_list[sentence_index]["id"] \
+                and (result_sentence_list[sentence_index]["class"] == ['similar', 'red']
+                     or result_sentence_list[sentence_index]["class"] == ['similar']):
+            # if source_sentence_list[sentence_index]["id"] == result_sentence_list[sentence_index]["id"]:
+            source_text = source_sentence_list[sentence_index].string
+            result_text = result_sentence_list[sentence_index].string
+            # .string is None for tags with nested markup, so check before stripping
+            if source_text is not None and result_text is not None:
+                source_text = source_text.strip("\n")
+                result_text = result_text.strip("\n")
+                data.append([source_text, result_text])
+    except (KeyError, IndexError):
+        print(sentence_index)
+
+    # print(data)
+
+
+def data_clean(text):
+    # strip characters that are illegal in Excel cells: rare non-printable
+    # control characters such as backspace and bell
+    ILLEGAL_CHARACTERS_RE = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]')
+    text = ILLEGAL_CHARACTERS_RE.sub(r'', text)
+    return text
+
+
+print(data)
+df = pd.DataFrame(data, columns=["原文", "yy降重"])
+for col in df.columns:
+    df[col] = df[col].apply(lambda x: data_clean(x))
+
+df.to_excel("../data/11篇_yy.xlsx", index=None)
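Note: data_clean() exists because openpyxl refuses to write most C0 control characters into xlsx cells. A sketch of the same regex on an invented dirty string:

    import re

    ILLEGAL_CHARACTERS_RE = re.compile(r'[\000-\010]|[\013-\014]|[\016-\037]')
    dirty = "正常文本\x07响铃\x08退格"
    print(ILLEGAL_CHARACTERS_RE.sub('', dirty))  # -> 正常文本响铃退格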
diff --git a/data_do/汇总.py b/data_do/汇总.py
index b6a67be..893310f 100644
--- a/data_do/汇总.py
+++ b/data_do/汇总.py
@@ -12,6 +12,26 @@
 import pandas as pd
 
 path_1 = '../data/11篇excel'
 path_2 = "../data/11篇临时拼接"
 path_3 = "../data/11篇临时拼接2"
+path_yy = "../data/11篇_yy_strsim.xlsx"
+path_t5 = "../data/11篇_t5_strsim.xlsx"
+
+
+data_yy = pd.read_excel(path_yy).values.tolist()
+data_t5 = pd.read_excel(path_t5).values.tolist()
+data_yy_dict = {}
+data_t5_dict = {}
+for i in data_yy:
+    str_data_yuan = str(i[0]).strip("。").strip()
+    str_data_lable = str(i[1]).strip("。").strip()
+    data_yy_dict[str_data_yuan] = str_data_lable
+
+for i in data_t5:
+    str_data_yuan = str(i[0]).strip("。").strip()
+    str_data_lable = str(i[1]).strip("。").strip()
+    data_t5_dict[str_data_yuan] = str_data_lable
+
+
 path_list = []
 for file_name in os.listdir(path_1):
     path_list.append(file_name)
@@ -26,8 +46,18 @@ for file_name in path_list:
     file_name_ = file_name_0 + "_." + file_name_1
     data_3 = pd.read_excel(path_3 + "/" + file_name_).values.tolist()
     for i in range(len(data_1)):
-        data_new.append(data_1[i] + [data_2[i][1]] + [data_3[i][1]])
+        # print(data_1[i])
+        # skip placeholder rows that contain only a full stop
+        if data_1[i][0] == "。":
+            continue
+
+        str_data = str(data_1[i][0]).strip()
+        try:
+            data_t5_dan = data_t5_dict[str_data]
+            data_yy_dan = data_yy_dict[str_data]
+            data_new.append(data_1[i] + [data_2[i][1], data_3[i][1], data_t5_dan, data_yy_dan])
+        except KeyError:
+            print(str_data)
 
-    df = pd.DataFrame(data_new, columns=["原文", "simbert", "simbert_datasim07", "bertsim_simsim"])
-    df.to_excel("../data/11篇测试excel_汇总_1/{}.xlsx".format(file_name_0), index=None)
+    df = pd.DataFrame(data_new, columns=["原文", "simbert", "simbert_datasim07", "bertsim_simsim", "t5", "yy"])
+    df.to_excel("../data/11篇测试excel_汇总_3/{}.xlsx".format(file_name_0), index=None)
diff --git a/data_do/筛选训练数据strsim.py b/data_do/筛选训练数据strsim.py
index 789f39e..dfbe374 100644
--- a/data_do/筛选训练数据strsim.py
+++ b/data_do/筛选训练数据strsim.py
@@ -165,8 +165,6 @@ if __name__ == '__main__':
             if str_sim_value < 0.70:
                 data_train_text.append("\t".join([data_1, "to", data_2]))
 
-
-
         # eval_list = eval_class.evaluate_t(' '.join(data_1), ' '.join(data_2))
         # bleusim_list.append(eval_list[3])
@@ -185,7 +183,6 @@ if __name__ == '__main__':
         # print(sentence_0_array)
         # cos_sim = cosine_similarity(sentence_0_array.reshape(1, -1), sentence_1_array.reshape(1, -1))
         # word2vecsim_list.append(cos_sim[0][0])
-
     # bertsim_list = sorted(bertsim_list)
     # zong_num = len(bertsim_list)
     # print(bertsim_list)
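Note: 汇总.py above joins the per-file sheets to the t5/yy outputs by dictionary lookup, keyed on the source sentence normalized with .strip("。").strip(). A sketch of that lookup (sample data invented):

    def norm(s) -> str:
        # same normalization as 汇总.py: trim full stops, then whitespace
        return str(s).strip("。").strip()

    yy_lookup = {norm("历史和当下都证明。"): "yy rewrite"}
    print(yy_lookup[norm(" 历史和当下都证明 ")])  # both keys normalize identically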
diff --git a/data_do/筛选训练数据层级细分strsim.py b/data_do/筛选训练数据层级细分strsim.py
new file mode 100644
index 0000000..a7ef7ff
--- /dev/null
+++ b/data_do/筛选训练数据层级细分strsim.py
@@ -0,0 +1,186 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2023/1/31 19:02
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+import os
+# os.environ["TF_KERAS"] = "1"
+import pandas as pd
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+import json
+import numpy as np
+from bert4keras.backend import keras, set_gelu
+from bert4keras.tokenizers import Tokenizer, load_vocab
+from bert4keras.models import build_transformer_model
+from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
+from bert4keras.snippets import sequence_padding, DataGenerator
+from bert4keras.snippets import open
+from keras.layers import Lambda, Dense
+import tensorflow as tf
+from keras.backend import set_session
+from sklearn.metrics.pairwise import cosine_similarity
+from rouge import Rouge  # pip install rouge
+from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
+from tqdm import tqdm
+import jieba
+from gensim.models import KeyedVectors, word2vec, Word2Vec
+import random
+import difflib
+
+config = tf.ConfigProto()
+config.gpu_options.allow_growth = True
+set_session(tf.Session(config=config))  # session setup differs from the other scripts
+
+
+class Word2vecModel:
+    def __init__(self):
+        self.path = r"E:\pycharm_workspace\查重分析\word2vec_model\word2vec_add_new_18.model"
+        self.model = Word2Vec.load(self.path)
+
+    def word2vec_res(self, seg_0_list, seg_1_list):
+        sentence_0_list = []
+        sentence_1_list = []
+        for i in seg_0_list:
+            a = self.model.wv[i]
+            sentence_0_list.append(a)
+
+        for i in seg_1_list:
+            a = self.model.wv[i]
+            sentence_1_list.append(a)
+
+        return sentence_0_list, sentence_1_list
+
+
+class Evaluator(keras.callbacks.Callback):
+    """Evaluation and checkpointing callback.
+    """
+
+    def __init__(self):
+        self.rouge = Rouge()
+        self.smooth = SmoothingFunction().method1
+        self.best_bleu = 0.
+
+    # def on_epoch_end(self, epoch, logs=None):
+    #     metrics = self.evaluate(valid_data)  # evaluate the model
+    #     if metrics['bleu'] > self.best_bleu:
+    #         self.best_bleu = metrics['bleu']
+    #         model.save_weights('./best_model.weights')  # save the model
+    #     metrics['best_bleu'] = self.best_bleu
+    #     print('valid_data:', metrics)
+
+    def evaluate_t(self, data_1, data_2, topk=1):
+        total = 0
+        rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0
+
+        scores = self.rouge.get_scores(hyps=[data_1], refs=[data_2])
+        rouge_1 += scores[0]['rouge-1']['f']
+        rouge_2 += scores[0]['rouge-2']['f']
+        rouge_l += scores[0]['rouge-l']['f']
+        bleu += sentence_bleu(
+            references=[data_1.split(' ')],
+            hypothesis=data_2.split(' '),
+            smoothing_function=self.smooth
+        )
+        # rouge_1 /= total
+        # rouge_2 /= total
+        # rouge_l /= total
+        # bleu /= total
+        return [rouge_1, rouge_2, rouge_l, bleu]
+
+
+class bertModel:
+    def __init__(self):
+        self.config_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
+        self.checkpoint_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
+        self.dict_path = '../chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'
+        self.token_dict, self.keep_tokens = load_vocab(
+            dict_path=self.dict_path,
+            simplified=True,
+            startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
+        )
+        self.tokenizer = Tokenizer(self.token_dict, do_lower_case=True)
+        self.buildmodel()
+
+    def buildmodel(self):
+        bert = build_transformer_model(
+            config_path=self.config_path,
+            checkpoint_path=self.checkpoint_path,
+            return_keras_model=False,
+        )
+
+        output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
+        self.model = keras.models.Model(bert.model.input, output)
+        self.model.summary()
+
+    def predict(self, text):
+        batch_token_ids, batch_segment_ids = [], []
+        token_ids, segment_ids = self.tokenizer.encode(text, maxlen=256)
+        batch_token_ids.append(token_ids)
+        batch_segment_ids.append(segment_ids)
+        return self.model.predict([batch_token_ids, batch_segment_ids])
+
+
+def simbert(data_1, data_2):
+    pass
+
+
+def word2vec():
+    pass
+
+
+def bleu():
+    pass
+
+
+if __name__ == '__main__':
+    file = "../data/train_yy_zong.txt"
+    sim_value = [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0]
+    model = bertModel()
+    eval_class = Evaluator()
+    # word2vecmodel = Word2vecModel()
+    try:
+        with open(file, 'r', encoding="utf-8") as f:
+            lines = [x.strip() for x in f if x.strip() != '']
+    except UnicodeDecodeError:
+        with open(file, 'r', encoding="gbk") as f:
+            lines = [x.strip() for x in f if x.strip() != '']
+
+    bertsim_list = []
+    bleusim_list = []
+    word2vecsim_list = []
+    data_train_text = []
+
+    random.shuffle(lines)
+    print(len(lines))
+    for txt in tqdm(lines):
+        text = txt.split('\t')
+        if len(text) == 3:
+            data_1 = text[0]
+            data_2 = text[2]
+            str_sim_value = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio()
+            # if len(data_2) - len(data_1) < 0 and len(data_2) / len(data_1) > 0.8:
+            #     num_yu = 1 - len(data_2) / len(data_1)
+            #     str_sim_value = 1 - str_sim_value * num_yu
+
+            if 1 >= str_sim_value > 0.95:
+                data_train_text.append([data_1, data_2, str(str_sim_value), "1-0.95"])
+            elif 0.95 >= str_sim_value > 0.9:
+                data_train_text.append([data_1, data_2, str(str_sim_value), "0.95-0.9"])
+            elif 0.9 >= str_sim_value > 0.85:
+                data_train_text.append([data_1, data_2, str(str_sim_value), "0.9-0.85"])
+            elif 0.85 >= str_sim_value > 0.8:
+                data_train_text.append([data_1, data_2, str(str_sim_value), "0.85-0.8"])
+            elif 0.8 >= str_sim_value > 0.75:
+                data_train_text.append([data_1, data_2, str(str_sim_value), "0.8-0.75"])
+            elif 0.75 >= str_sim_value > 0.7:
+                data_train_text.append([data_1, data_2, str(str_sim_value), "0.75-0.7"])
+            else:
+                data_train_text.append([data_1, data_2, str(str_sim_value), "0.7-0"])
+
+    # x[2] holds the similarity as a string, so compare numerically
+    data_train_text = sorted(data_train_text, key=lambda x: float(x[2]), reverse=True)
+    df = pd.DataFrame(data_train_text)
+    print(df)
+    df.to_csv("../data/yy改写相似度.csv", index=None)
+    df.to_excel("../data/yy改写相似度.xlsx", index=None)
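Note: the elif ladder above assigns each pair to one of seven upper-inclusive similarity bands. A hypothetical, equivalent refactor with bisect (band edges and labels copied from the script; sketch only):

    import bisect

    BOUNDS = [0.7, 0.75, 0.8, 0.85, 0.9, 0.95]
    LABELS = ["0.7-0", "0.75-0.7", "0.8-0.75", "0.85-0.8",
              "0.9-0.85", "0.95-0.9", "1-0.95"]

    def band(sim: float) -> str:
        # bisect_left counts edges strictly below sim, which matches the
        # ladder's "upper bound inclusive" convention
        return LABELS[bisect.bisect_left(BOUNDS, sim)]

    assert band(0.93) == "0.95-0.9"
    assert band(0.95) == "0.95-0.9"
    assert band(0.5) == "0.7-0"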
diff --git a/flask_predict_no_batch_t5.py b/flask_predict_no_batch_t5.py
new file mode 100644
index 0000000..9ac64a0
--- /dev/null
+++ b/flask_predict_no_batch_t5.py
@@ -0,0 +1,277 @@
+import os
+# os.environ["TF_KERAS"] = "1"
+# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
+from flask import Flask, jsonify
+from flask import request
+# from linshi import autotitle
+import requests
+from predict_t5 import autotitle
+
+import re
+
+app = Flask(__name__)
+app.config["JSON_AS_ASCII"] = False
+
+import logging
+
+pattern = r"[。]"
+RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
+fuhao_end_sentence = ["。", ",", "?", "!", "…"]
+
+config = {
+    "batch_size": 1000
+}
+
+
+def get_dialogs_index(line: str):
+    """
+    Find dialogue spans and their character indices.
+    :param line: the input text
+    :return dialogs_text: the dialogue contents
+            dialogs_index: character positions inside dialogue
+            other_index: character positions outside dialogue
+    """
+    dialogs = re.finditer(RE_DIALOG, line)
+    dialogs_text = re.findall(RE_DIALOG, line)
+    dialogs_index = []
+    for dialog in dialogs:
+        all_ = [i for i in range(dialog.start(), dialog.end())]
+        dialogs_index.extend(all_)
+    other_index = [i for i in range(len(line)) if i not in dialogs_index]
+
+    return dialogs_text, dialogs_index, other_index
+
+
+def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
+    # recursively split an over-long sentence: cut at the last comma-level
+    # punctuation mark inside the first 120 characters, then recurse on the rest
+    fuhao = [",", "?", "!", "…"]
+    text_1 = text[:120]
+    text_2 = text[120:]
+    text_1_new = ""
+    for i in range(len(text_1) - 1, -1, -1):
+        if text_1[i] in fuhao:
+            text_1_new = text_1[:i]
+            text_1_new += text_1[i]
+            chulipangban_return_list.append([text_1_new, snetence_id, short_num])
+            # carry the tail beyond the cut over to the next chunk
+            # (even when text_2 started out empty)
+            if i + 1 != 120:
+                text_2 = text_1[i + 1:] + text_2
+            break
+    # else:
+    #     chulipangban_return_list.append(text_1)
+    if text_1_new == "":
+        chulipangban_return_list.append([text_1, snetence_id, short_num])
+    if text_2 != "":
+        short_num += 1
+        chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num)
+    return chulipangban_return_list
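Note: a quick check of the splitter above (sketch only, input invented). With a comma every six characters, a 180-character sentence becomes a 120-character piece plus a 60-character continuation, and short_num marks the continuation:

    long_sentence = "第一段内容," * 30  # 180 characters
    chunks = chulichangju_1(long_sentence, 0, [], 0)
    print([len(c[0]) for c in chunks])  # [120, 60]
    print([c[2] for c in chunks])       # [0, 1]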
+
+def chulipangban_test_1(text, snetence_id):
+    sentence_list = text.split("。")
+    # sentence_list_new = []
+    # for i in sentence_list:
+    #     if i != "":
+    #         sentence_list_new.append(i)
+    # sentence_list = sentence_list_new
+    sentence_batch_list = []
+    sentence_batch_one = []
+    sentence_batch_length = 0
+    return_list = []
+    for sentence in sentence_list:
+        if len(sentence) < 120:
+            sentence_batch_length += len(sentence)
+            sentence_batch_list.append([sentence, snetence_id, 0])
+            # sentence_pre = autotitle.gen_synonyms_short(sentence)
+            # return_list.append(sentence_pre)
+        else:
+            sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0)
+            for sentence_short in sentence_split_list:
+                sentence_batch_list.append(sentence_short)
+    return sentence_batch_list
+
+
+def paragraph_test_(text: list, text_new: list):
+    # earlier, unused variant (note that it rebinds text inside the loop)
+    for i in range(len(text)):
+        text = chulipangban_test_1(text, i)
+        text = "。".join(text)
+        text_new.append(text)
+
+    # text_new_str = "".join(text_new)
+    return text_new
+
+
+def paragraph_test(text: list):
+    text_new = []
+    for i in range(len(text)):
+        text_list = chulipangban_test_1(text[i], i)
+        text_new.extend(text_list)
+
+    # text_new_str = "".join(text_new)
+    return text_new
+
+
+def batch_data_process(text_list):
+    sentence_batch_length = 0
+    sentence_batch_one = []
+    sentence_batch_list = []
+
+    for sentence in text_list:
+        sentence_batch_length += len(sentence[0])
+        sentence_batch_one.append(sentence)
+        if sentence_batch_length > 500:
+            sentence_batch_length = 0
+            sentence_ = sentence_batch_one.pop(-1)
+            sentence_batch_list.append(sentence_batch_one)
+            sentence_batch_one = []
+            sentence_batch_one.append(sentence_)
+    sentence_batch_list.append(sentence_batch_one)
+    return sentence_batch_list
+
+
+def batch_predict(batch_data_list):
+    '''
+    Predict one batch of data.
+    @param batch_data_list:
+    @return:
+    '''
+    batch_data_list_new = []
+    batch_data_text_list = []
+    batch_data_snetence_id_list = []
+    for i in batch_data_list:
+        batch_data_text_list.append(i[0])
+        batch_data_snetence_id_list.append(i[1:])
+    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
+    batch_pre_data_list = batch_data_text_list
+    for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list):
+        batch_data_list_new.append([text] + sentence_id)
+
+    return batch_data_list_new
+
+
+def one_predict(data_text):
+    '''
+    Predict a single piece of data.
+    @param data_text:
+    @return:
+    '''
+    if data_text[0] != "":
+        pre_data = autotitle.generate(data_text[0])
+    else:
+        pre_data = ""
+    data_new = [pre_data] + data_text[1:]
+    return data_new
+
+
+def predict_data_post_processing(text_list):
+    # first merge continuation chunks (short_num != 0) back into their sentence,
+    # then join the sentences of each paragraph with "。"
+    text_list_sentence = []
+    # text_list_sentence.append([text_list[0][0], text_list[0][1]])
+    for i in range(len(text_list)):
+        if text_list[i][2] != 0:
+            text_list_sentence[-1][0] += text_list[i][0]
+        else:
+            text_list_sentence.append([text_list[i][0], text_list[i][1]])
+
+    return_list = []
+    sentence_one = []
+    sentence_id = 0
+    for i in text_list_sentence:
+        if i[1] == sentence_id:
+            sentence_one.append(i[0])
+        else:
+            sentence_id = i[1]
+            return_list.append("。".join(sentence_one))
+            sentence_one = []
+            sentence_one.append(i[0])
+    if sentence_one != []:
+        return_list.append("。".join(sentence_one))
+    return return_list
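Note: a round-trip sketch of the split/reassemble pair above, skipping the model call; feeding the split pieces straight into the post-processing reproduces the input paragraphs:

    paragraphs = ["第一句。第二句。", "第三句。"]
    pieces = paragraph_test(paragraphs)          # [[text, paragraph_id, short_num], ...]
    print(predict_data_post_processing(pieces))  # ['第一句。第二句。', '第三句。']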
+
+# def main(text:list):
+#     # text_list = paragraph_test(text)
+#     # batch_data = batch_data_process(text_list)
+#     # text_list = []
+#     # for i in batch_data:
+#     #     text_list.extend(i)
+#     # return_list = predict_data_post_processing(text_list)
+#     # return return_list
+
+
+def main(text: list):
+    text_list = paragraph_test(text)
+    text_list_new = []
+    for i in text_list:
+        pre = one_predict(i)
+        text_list_new.append(pre)
+    return_list = predict_data_post_processing(text_list_new)
+    return return_list
+
+
+@app.route('/droprepeat/', methods=['POST'])
+def sentence():
+    print(request.remote_addr)
+    texts = request.json["texts"]
+    text_type = request.json["text_type"]
+    print("原始语句" + str(texts))
+    # question = question.strip('。、!??')
+
+    if isinstance(texts, list):
+        texts_list = []
+        y_pred_label_list = []
+        position_list = []
+
+        # texts = texts.replace('\'', '\"')
+        if texts is None:
+            return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False}
+            return jsonify(return_text)
+        else:
+            assert text_type in ['focus', 'chapter']
+            if text_type == 'focus':
+                texts_list = main(texts)
+            if text_type == 'chapter':
+                texts_list = main(texts)
+
+            return_text = {"texts": texts_list, "probabilities": None, "status_code": True}
+    else:
+        return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False}
+    return jsonify(return_text)
+
+
+# @app.route('/chapter/', methods=['POST'])
+# def chapter():
+#     texts = request.json["texts"]
+#
+#     print("原始语句" + str(texts))
+#     # question = question.strip('。、!??')
+#
+#     if isinstance(texts, str):
+#         texts_list = []
+#         y_pred_label_list = []
+#         position_list = []
+#
+#         # texts = texts.replace('\'', '\"')
+#         if texts is None:
+#             return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False}
+#             return jsonify(return_text)
+#         else:
+#             texts = texts.split("\n")
+#             for text in texts:
+#                 text = text.strip()
+#                 return_str = autotitle.generate_random_shortest(text)
+#                 texts_list.append(return_str)
+#             texts_str = "\n".join(texts_list)
+#             return_text = {"texts": texts_str, "probabilities": None, "status_code": True}
+#     else:
+#         return_text = {"texts": "输入格式应该为str", "probabilities": None, "status_code": False}
+#     return jsonify(return_text)
+
+
+if __name__ == "__main__":
+    fh = logging.FileHandler(mode='a', encoding='utf-8', filename='chitchat.log')
+    logging.basicConfig(
+        handlers=[fh],
+        level=logging.DEBUG,
+        format='%(asctime)s - %(levelname)s - %(message)s',
+        datefmt='%a, %d %b %Y %H:%M:%S',
+    )
+    app.run(host="0.0.0.0", port=14000, threaded=True, debug=False)
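Note: a client-side sketch for the route above (local host assumed; port 14000 as in app.run; the route requires a JSON list plus a text_type of "focus" or "chapter"):

    import requests

    resp = requests.post(
        "http://127.0.0.1:14000/droprepeat/",
        json={"texts": ["历史和当下都证明,创新是民族生存、发展的不竭源泉。"],
              "text_type": "focus"},
        timeout=1000,
    )
    print(resp.json())  # {'texts': [...], 'probabilities': None, 'status_code': True}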
diff --git a/predict_11pian.py b/predict_11pian.py
index aa865eb..c575a3d 100644
--- a/predict_11pian.py
+++ b/predict_11pian.py
@@ -1079,32 +1079,32 @@ if __name__ == '__main__':
     # file.close()
 
 # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-# path = './data/11篇txt'
-# path_new = './data/11篇model1'
-# path_list = []
-#
-# for file_name in os.listdir(path):
-#     path_list.append(file_name)
-# for docx_name in path_list:
-#     df_list_new = []
-#     with open(path + "/" + docx_name, 'r', encoding="utf-8") as f:
-#         lines = [x.strip() for x in f if x.strip() != '']
-#     for dan in tqdm(lines):
-#         break_ = False
-#         for i in dan:
-#             if i == "章":
-#                 break_ = True
-#                 break
-#         if break_ == True:
-#             df_list_new.append(dan)
-#             continue
-#         pre = just_show_sentence([dan])
-#         df_list_new.append(pre)
-#
-#
-#
-#     with open(path_new + "/" + docx_name, "w", encoding='utf-8') as file:
-#         for i in df_list_new:
-#             file.write(i + '\n')
-#         file.close()
+    path = './data/11篇txt'
+    path_new = './data/11篇model1'
+    path_list = []
+
+    for file_name in os.listdir(path):
+        path_list.append(file_name)
+    for docx_name in path_list:
+        df_list_new = []
+        with open(path + "/" + docx_name, 'r', encoding="utf-8") as f:
+            lines = [x.strip() for x in f if x.strip() != '']
+        for dan in tqdm(lines):
+            # pass chapter-heading lines (those containing "章") through unchanged
+            break_ = False
+            for i in dan:
+                if i == "章":
+                    break_ = True
+                    break
+            if break_ == True:
+                df_list_new.append(dan)
+                continue
+            pre = just_show_sentence([dan])
+            df_list_new.append(pre)
+
+        with open(path_new + "/" + docx_name, "w", encoding='utf-8') as file:
+            for i in df_list_new:
+                file.write(i + '\n')
+            file.close()
diff --git a/predict_t5.py b/predict_t5.py
index 336d8d2..da7146d 100644
--- a/predict_t5.py
+++ b/predict_t5.py
@@ -85,7 +85,7 @@ class GenerateModel(object):
         output = CrossEntropy(1)([model.inputs[1], model.outputs[0]])
         model = Model(model.inputs, output)
-        path_model = "output_t5/best_model_t5_dropout_0_3.weights"
+        path_model = "output_t5/best_model_t5.weights"
         model.load_weights(path_model)
 
         return encoder, decoder, model, tokenizer
@@ -104,6 +104,106 @@ class CrossEntropy(Loss):
         loss = K.sum(loss * y_mask) / K.sum(y_mask)
         return loss
 
+
+class Beamdataone(object):
+    def __init__(self, num_beams, batch_id, text, end_id, minlen, min_ends, tokenizer, output_ids):
+        """
+        Initialize the n-best list of hypotheses for one batch sample.
+        """
+        self.num_beams = num_beams
+        self.batch_id = batch_id
+        self.beams = []
+        self.minlen = minlen
+        self.min_ends = min_ends
+        self.end_id = end_id
+        self.text = text
+        self.output_scores = np.zeros(1)
+        self.output_ids = [output_ids]
+        self.return_str = ""
+        self.over = False
+        self.tokenizer = tokenizer
+        # self.data()
+        self.output_str = ""
+        self.text_2_textids(self.text)
+        self.scores = np.zeros(1)
+        self.inputs_vector = 0
+
+    def text_2_textids(self, text):
+        token_ids, segment_ids = self.tokenizer.encode(text[0], maxlen=120)
+        self.text_ids = [token_ids]
+
+    def add_data(self, step, output_scores):
+        '''
+        Fold one decoding step's scores into this sample's surviving beams.
+        @param step:
+        @return:
+        '''
+        # inputs = [np.array([i]) for i in inputs]
+        # output_ids, output_scores = self.first_output_ids, np.zeros(1)
+        #
+        # scores, states = self.predict(
+        #     inputs, output_ids, states, temperature, 'logits'
+        # )  # compute the current scores
+        # if step == 0:  # after the first step, tile the inputs topk times
+        #     inputs = [np.repeat(i, self.num_beams, axis=0) for i in self.inputs]
+        # inputs = [self.token_ids, self.segment_ids]
+        # inputs = [np.array([i]) for i in inputs]
+        self.output_ids = np.array(self.output_ids)
+        if step == 0:  # after the first step, tile the inputs num_beams times
+            self.text_ids = [np.repeat(i, self.num_beams, axis=0) for i in self.text_ids]
+        scores = output_scores.reshape((-1, 1)) + self.scores  # accumulated scores
+        # scores = output_probas
+        scores = self.output_scores.reshape((-1, 1)) + scores  # accumulated scores
+        indices = scores.argpartition(-self.num_beams, axis=None)[-self.num_beams:]  # keep only the top-k
+        indices_1 = indices // scores.shape[1]  # row index: which beam
+        indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column index: which token
+        self.output_ids = np.concatenate([self.output_ids[indices_1], indices_2], 1)  # update the outputs
+        self.output_scores = np.take_along_axis(scores, indices, axis=None)  # update the scores
+
+        is_end = self.output_ids[:, -1] == self.end_id  # flag beams whose last token is the end token
+        self.end_counts = (self.output_ids == self.end_id).sum(1)  # count end tokens per beam
+        if self.output_ids.shape[1] >= self.minlen:  # minimum-length check
+            best = self.output_scores.argmax()  # the best-scoring beam
+            if is_end[best] and self.end_counts[best] >= self.min_ends:  # if it has terminated
+                # return output_ids[best]  # emit directly
+                self.return_str_main(self.output_ids, best)
+                self.over = True
+            else:  # otherwise keep only the unfinished beams
+                flag = ~is_end | (self.end_counts < self.min_ends)  # flag unfinished sequences
+                if not flag.all():  # if any have finished
+                    self.output_ids = self.output_ids[flag]  # drop finished sequences
+                    self.output_scores = self.output_scores[flag]  # drop their scores
+                    self.end_counts = self.end_counts[flag]  # drop their end counts
+                    self.num_beams = flag.sum()  # shrink the beam count accordingly
+        self.output_ids = self.output_ids.tolist()
+        self.output_str = [tokenizer.decode(ids) for ids in self.output_ids]
+        self.text_ids = [self.text_ids[0] for i in range(len(self.output_ids))]
+
+        # # emit directly once the maximum length is reached
+        # return output_ids[output_scores.argmax()]
+
+    # def data(self):
+    #     token_ids, segment_ids = self.tokenizer.encode(self.text, maxlen=256)
+    #     self.token_ids = token_ids
+    #     self.segment_ids = segment_ids
+
+    #     input_str = [text for i in range(self.num_beams)]
+    #     output_str = self.output_str
+    #     return input_str, output_str
+
+    def return_str_main(self, output_ids, best):
+        output_ids_best = output_ids[best]
+        self.return_str = self.tokenizer.decode(output_ids_best)
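Note: the top-k selection in Beamdataone.add_data flattens the (beams x vocab) score matrix, takes the k best entries with argpartition, and recovers which beam and token each winner came from with // and %. A standalone sketch of that index arithmetic:

    import numpy as np

    k = 2
    scores = np.array([[0.1, 0.9, 0.3],
                       [0.8, 0.2, 0.7]])       # (beams, vocab) cumulative scores
    flat = scores.argpartition(-k, axis=None)[-k:]
    beam_idx = flat // scores.shape[1]          # row: which beam to extend
    token_idx = flat % scores.shape[1]          # column: which token extends it
    print(beam_idx, token_idx)                  # beams {1, 0} with tokens {0, 1}
    print(np.take_along_axis(scores, flat, axis=None))  # their scores: 0.8 and 0.9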
@@ -149,11 +249,156 @@ class AutoTitle(AutoRegressiveDecoder):
         nodes = self.last_token(decoder).predict([c_encoded, output_ids])
         return nodes
 
+    def predict_batch(self, inputs):
+        # inputs, output_ids, states, temperature, 'probas'
+        token_ids, output_ids = inputs
+        # token_ids = np.concatenate([token_ids, output_ids], 1)
+        # segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1)
+        with graph.as_default():
+            K.set_session(sess)
+            nodes = self.decoder.predict([token_ids, output_ids])
+        return nodes
 
-
-
-    def generate(self, text, topk=3):
+    def data_generator(self, token_ids, output_ids):
+        # unused helper kept from development
+        batch_token_ids = []
+        for i, j in zip(token_ids, output_ids):
+            batch_token_ids = sequence_padding(token_ids)
+            batch_segment_ids = sequence_padding(output_ids)
+        return batch_token_ids, batch_segment_ids
+
+    def beam_search_batch(
+            self,
+            inputs_str,
+            states=None,
+            temperature=1,
+            min_ends=1,
+            num_beam=3
+    ):
+        """Beam-search decode a whole batch of inputs at once.
+        Each sample keeps its own Beamdataone hypothesis list; samples that
+        finish drop out of the batch while the rest keep decoding.
+        Returns: a list with one decoded string per input.
+        """
+        output_str = []
+        # token_ids, segment_ids = self.data_generator(inputs, output_ids)
+        batch_nums = len(inputs_str)
+        return_str_batch = [0] * batch_nums
+        # output_ids = np.empty((batch_nums, 0), dtype=int)
+        output_ids = np.array([self.start_id])
+        generated = [Beamdataone(num_beam, i, [inputs_str[i]], self.end_id, self.minlen, min_ends,
+                                 self.tokenizer, output_ids)
+                     for i in range(batch_nums)]
+        # index_data = [i for i in range(batch_nums)]
+
+        c_token_ids = []
+        for i in generated:
+            text_ids = i.text_ids
+            c_token_ids.extend(text_ids)
+        c_token_ids = sequence_padding(c_token_ids)
+        c_encoded = encoder.predict(np.array(c_token_ids))
+
+        # probas_bool = np.array(token_ids, dtype=bool)
+        # # np.array(np.where(probas_bool == True))
+        # for i, sentence in enumerate(probas_bool):
+        #     lie = np.array(np.where(sentence == True))[0]
+        #     probas_new.append(probas[i, lie[-1]])
+
+        for i in range(len(generated)):
+            # keep only the encoder positions of real (non-padding) tokens
+            probas_bool = np.array(generated[i].text_ids[0], dtype=bool)
+            lie = np.array(np.where(probas_bool == True))[0]
+            # c_encoded_dan = c_encoded[i, lie[-1]]
+            c_encoded_dan = c_encoded[np.ix_([i], lie)]
+            generated[i].inputs_vector = c_encoded_dan[0]
+
+        for step in range(self.maxlen):
+            # if step == 0:
+            #     token_ids, segment_ids = self.data_generator(inputs_str, output_str)
+            # else:
+            # inputs_str, output_str = [], []
+            inputs_vector_batch, output_ids_batch = [], []
+            batch_input_num_beam_num = []
+            for i in generated:
+                inputs_vector = i.inputs_vector
+                # if step != 0:
+                #     output_ids_batch.extend(i.output_ids)
+                #     text_ids_batch.extend(text_ids)
+                # else:
+                inputs_vector_batch.append(inputs_vector)
+                output_ids_batch.extend(i.output_ids)
+                if step != 0:
+                    batch_input_num_beam_num.append(i.num_beams)
+
+            # token_ids, output_ids_batch = self.data_generator(inputs_vector_batch, output_ids_batch)
+            # token_ids_batch = sequence_padding(token_ids_batch)
+            # segment_ids_batch = sequence_padding(segment_ids_batch)
+            # output_ids_batch = np.array(output_ids_batch)
+            # if step == 0:
+
+            inputs = [inputs_vector_batch, output_ids_batch]
+            probas = self.predict_batch(inputs)  # probabilities for the current step
+
+            probas_new = []
+            probas_bool = np.array(inputs_vector_batch, dtype=bool)
+            # np.array(np.where(probas_bool == True))
+            for i, sentence in enumerate(probas_bool):
+                lie = np.array(np.where(sentence == True))[0]
+                probas_new.append(probas[i, lie[-1]])
+            probas = np.array(probas_new)
+
+            if step != 0:
+                # slice the stacked probabilities back out per sample,
+                # using each sample's current beam count
+                num = 0
+                if len(generated) > 1:
+                    index = 0
+                    for index in range(len(batch_input_num_beam_num) - 1):
+                        cc = num
+                        num += batch_input_num_beam_num[index]
+                        generated[index].add_data(step, probas[cc:num, :])
+                    generated[index + 1].add_data(step, probas[num:, :])
+                else:
+                    generated[0].add_data(step, probas[:, :])
+            else:
+                for index in range(len(generated)):
+                    generated[index].add_data(step, probas[index, :])
+            # i = 0
+            # while True:
+            #     bool_ = generated[i].over
+            #     if bool_ == True:
+            #         one_sentence = generated.pop(i)
+            #         return_str_batch[i] = one_sentence.return_str
+            #         if i > len(generated) - 1:
+            #             break
+            #     else:
+            #         i += 1
+            #         if i > len(generated) - 1:
+            #             break
+
+            generated_new = []
+            for i in range(len(generated)):
+                bool_ = generated[i].over
+                if bool_ == False:
+                    generated_new.append(generated[i])
+                else:
+                    return_str_batch[generated[i].batch_id] = generated[i].return_str
+            generated = generated_new
+
+            if generated == []:
+                return return_str_batch
+        return return_str_batch
+
+    def generate(self, text, topk=5):
         c_token_ids, _ = tokenizer.encode(text, maxlen=120)
-        c_encoded = encoder.predict(np.array([c_token_ids]))[0]
+        with graph.as_default():
+            K.set_session(sess)
+            c_encoded = encoder.predict(np.array([c_token_ids]))[0]
         output_ids = self.beam_search([c_encoded], topk=topk)  # beam search
         return tokenizer.decode([int(i) for i in output_ids])
@@ -168,6 +413,9 @@ class AutoTitle(AutoRegressiveDecoder):
             text.append(tokenizer.decode([int(i) for i in ids]))
         return text
 
+    def generate_beam_search_batch(self, text):
+        output_str = self.beam_search_batch(text)  # batched beam search
+        return output_str
 
 generatemodel = GenerateModel()
@@ -185,6 +433,13 @@ def just_show_sentence(file):
     pre = autotitle.generate(text)
     return pre
 
+
+def just_show_sentence_batch(file: list) -> object:
+    text = file
+    pre = autotitle.generate_beam_search_batch(text)
+    return pre
+
+
 if __name__ == '__main__':
     # file = "train_2842.txt"
     # just_show(file)
@@ -192,17 +447,62 @@ if __name__ == '__main__':
     # a = just_show_sentence(text)
     # print(a)
     # print(type(a))
-    is_novel = False
-    path = "./data/700条论文测试.xlsx"
-    df_list = pd.read_excel(path).values.tolist()
-
-
-    df_list_new = []
-    print(len(df_list))
-    for i in tqdm(df_list):
-        pre = just_show_sentence([i[0]])
-
-        df_list_new.append([i[0], i[1], pre])
-
-    df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "t5模型"])
-    df.to_excel("./data/700条论文测试_7.xlsx", index=None)
+    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+    # is_novel = False
+    # path = "./data/700条论文测试.xlsx"
+    # df_list = pd.read_excel(path).values.tolist()
+    #
+    #
+    # df_list_new = []
+    # print(len(df_list))
+    # for i in tqdm(df_list):
+    #     pre = just_show_sentence([i[0]])
+    #
+    #     df_list_new.append([i[0], i[1], pre])
+    #
+    # df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "t5模型"])
+    # df.to_excel("./data/700条论文测试_7.xlsx", index=None)
+
+    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+    # import os
+    #
+    # file = "./data/11篇汇总txt_new.txt"
+    # file_t5 = "./data/11篇汇总txt_new_predict_t5.txt"
+    #
+    # try:
+    #     with open(file, 'r', encoding="utf-8") as f:
+    #         lines = [x.strip() for x in f if x.strip() != '']
+    # except:
+    #     with open(file, 'r', encoding="gbk") as f:
+    #         lines = [x.strip() for x in f if x.strip() != '']
+    #
+    # zishu = 0
+    # data = []
+    # for i in tqdm(lines):
+    #
+    #     zishu += len(i)
+    #     pre = just_show_sentence([i])
+    #     data.append([i, pre])
+    #
+    # with open(file_t5, "w", encoding='utf-8') as file:
+    #     for i in data:
+    #         file.write("\t".join(i) + '\n')
+    #     file.close()
+    # print(zishu)
+
+    # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
+
+    text = ["'李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会",
+            "历史和当下都证明,创新是民族生存、“发展的不竭源泉”,是是自身发展的必然选择",
+            "自身发展的必然选择",
+            "强调轻资产经营, 更加重视经营风险的规避",
+            "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择",
+            "是时代对于青年们的深切呼唤"]
+    # text = ["基本消除“热桥”影响。"]
+    print(just_show_sentence(text))
+    # print(just_show_sentence_top(text))
+    # print(just_show_chachong_random(text))
+
+    # print(tokenizer.encode("\"", maxlen=120))
+    # print(just_show_sentence_batch(text))
\ No newline at end of file
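Note: predict_batch() and generate() above both wrap their predict calls in with graph.as_default(): K.set_session(sess). That is the usual TF1-era pattern for calling a Keras model from Flask worker threads; a sketch of where graph and sess are assumed to come from (names captured at model-load time, as the diff implies):

    import tensorflow as tf
    from keras import backend as K

    sess = tf.Session()
    K.set_session(sess)             # build/load the model under this session
    graph = tf.get_default_graph()  # capture the graph at load time

    def thread_safe_predict(model, inputs):
        # re-enter the captured graph/session on every request thread
        with graph.as_default():
            K.set_session(sess)
            return model.predict(inputs)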
"t5模型"]) + # df.to_excel("./data/700条论文测试_7.xlsx", index=None) + + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + # import os + # + # file = "./data/11篇汇总txt_new.txt" + # file_t5 = "./data/11篇汇总txt_new_predict_t5.txt" + # + # try: + # with open(file, 'r', encoding="utf-8") as f: + # lines = [x.strip() for x in f if x.strip() != ''] + # except: + # with open(file, 'r', encoding="gbk") as f: + # lines = [x.strip() for x in f if x.strip() != ''] + # + # zishu = 0 + # data = [] + # for i in tqdm(lines): + # + # zishu += len(i) + # pre = just_show_sentence([i]) + # data.append([i, pre]) + # + # with open(file_t5, "w", encoding='utf-8') as file: + # for i in data: + # file.write("\t".join(i) + '\n') + # file.close() + # print(zishu) + + #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + text = ["'李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会", + "历史和当下都证明,创新是民族生存、“发展的不竭源泉”,是是自身发展的必然选择", + "自身发展的必然选择", + "强调轻资产经营, 更加重视经营风险的规避", + "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择", + "是时代对于青年们的深切呼唤"] + # text = ["基本消除“热桥”影响。"] + print(just_show_sentence(text)) + # print(just_show_sentence_top(text)) + # print(just_show_chachong_random(text)) + + # print(tokenizer.encode("\"", maxlen=120)) + # print(just_show_sentence_batch(text)) \ No newline at end of file diff --git a/request_drop.py b/request_drop.py index 1a8303d..99f158e 100644 --- a/request_drop.py +++ b/request_drop.py @@ -47,6 +47,10 @@ ceshi_1 = [ "我" * 110 ] +ceshi_2 = [ + "李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会。关于这件事,校长特别交代过了,我也非常认同。你这是见义勇为,是勇斗歹徒、义救同学的英雄,我们清江一中决不让英雄流血又流泪!”。" + ] + jishu = 0 for i in ceshi_1: for j in i: