diff --git a/README.md b/README.md index 1cc4b03..14e372d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,4 @@ # 改写项目 - 基于unilm模型以及t5的生成式任务,使用keras框架,数据处理脚本在data_do文件夹下 训练数据 train_yy.txt @@ -27,8 +26,6 @@ python 合并数据.py ## 测试11篇数据 - - ## 测试数据是否有bug python 测试10000篇数据.py \ No newline at end of file diff --git a/cehsi.py b/cehsi.py new file mode 100644 index 0000000..e69de29 diff --git a/ceshishuzi.py b/ceshishuzi.py new file mode 100644 index 0000000..0455cb9 --- /dev/null +++ b/ceshishuzi.py @@ -0,0 +1,9 @@ +def has_numbers(input_string): + return any(char.isdigit() for char in input_string) + +# 示例用法 +input_str = "Hello, 123!" +if has_numbers(input_str): + print("字符串中包含数字") +else: + print("字符串中不包含数字") \ No newline at end of file diff --git a/ceshiyouxiaokuohao.py b/ceshiyouxiaokuohao.py new file mode 100644 index 0000000..06c5898 --- /dev/null +++ b/ceshiyouxiaokuohao.py @@ -0,0 +1,19 @@ +def is_contains_(str): + stack = [] + dict = {"]": "[", "}": "{", ")": "(", "”": "”", "’": "‘", "》": "《"} + for char in str: + if char in dict.values(): + stack.append(char) + elif char in dict.keys(): + if stack == [] or dict[char] != stack.pop(): + return False + else: + continue + if stack == []: + return True + else: + return False + +a = "d(a)a" + +print(is_contains_(a)) \ No newline at end of file diff --git a/config/__init__.py b/config/__init__.py new file mode 100644 index 0000000..796c0b5 --- /dev/null +++ b/config/__init__.py @@ -0,0 +1,9 @@ +# -*- coding: utf-8 -*- + +""" +@Time : 2023/3/27 16:51 +@Author : +@FileName: +@Software: +@Describe: +""" diff --git a/config/predict_t5_config.py b/config/predict_t5_config.py index 873209b..eff5338 100644 --- a/config/predict_t5_config.py +++ b/config/predict_t5_config.py @@ -12,7 +12,8 @@ import os pre_model_path = { "t5": { - "linux": "/home/zc-nlp-zyp/work_file/ssd_data/模型库/预训练模型集合/keras/mt5/mt5_base", + # "linux": "/home/zc-nlp-zyp/work_file/ssd_data/模型库/预训练模型集合/keras/mt5/mt5_base", + "linux": "mt5/mt5_base", "win32": r"E:\pycharm_workspace\premodel\keras\mt5\mt5_base" }, @@ -27,9 +28,9 @@ class DropT5Config: self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000') self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model') self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json') - self.savemodel_path = "./output_t5/best_model_t5_zong_sim_99.weights" + self.savemodel_path = "./output_t5/best_model_t5_0724.weights" self.maxlen = 256 - self.cuda_id = "1" + self.cuda_id = "0" class MultipleResultsDropT5Config: @@ -40,6 +41,6 @@ class MultipleResultsDropT5Config: self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000') self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model') self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json') - self.savemodel_path = "./output_t5/best_model_t5_dropout_0_3.weights" + self.savemodel_path = "./output_t5/best_model_t5_0724.weights" self.maxlen = 256 - self.cuda_id = "1" \ No newline at end of file + self.cuda_id = "0" \ No newline at end of file diff --git a/data_do/11篇t5预测strsim排序.py b/data_do/11篇t5预测strsim排序.py index 359fe50..a33fa98 100644 --- a/data_do/11篇t5预测strsim排序.py +++ b/data_do/11篇t5预测strsim排序.py @@ -10,7 +10,8 @@ import pandas as pd import difflib -file = "../data/11篇汇总txt_new_predict_t5.txt" +# file = "../data/11篇汇总txt_new_predict_t5.txt" +file = "../data/11篇汇总txt_new_predict_t5_0724.txt" try: with open(file, 'r', encoding="utf-8") as f: lines = [x.strip() 
for x in f if x.strip() != ''] @@ -30,4 +31,4 @@ for i in lines: print(data_new) data_new = sorted(data_new, key= lambda x:x[2], reverse=True) df = pd.DataFrame(data_new) -df.to_excel("../data/11篇_t5_strsim.xlsx", index=None) \ No newline at end of file +df.to_excel("../data/11篇_t5_strsim_0724.xlsx", index=None) \ No newline at end of file diff --git a/data_do/yy训练数据处理.py b/data_do/yy训练数据处理.py index f4e16fe..39a6cde 100644 --- a/data_do/yy训练数据处理.py +++ b/data_do/yy训练数据处理.py @@ -10,8 +10,8 @@ import pandas as pd -path = "../data/论文_yy_小说_3.xlsx" -df_list = pd.read_excel(path).values.tolist() +path = "../data/论文_yy_小说_3.csv" +df_list = pd.read_csv(path).values.tolist() df_list_new = [] print(len(df_list)) @@ -20,7 +20,7 @@ for i in df_list: b = i[1] df_list_new.append("\t".join([a, "to", b])) -with open("../data/train_yy_1.txt", "w", encoding='utf-8') as file: +with open("../data/train_yy_pre.txt", "w", encoding='utf-8') as file: for i in df_list_new: file.write(i + '\n') file.close() diff --git a/data_do/合并数据.py b/data_do/合并数据.py index 9d71ad2..20770ba 100644 --- a/data_do/合并数据.py +++ b/data_do/合并数据.py @@ -22,10 +22,12 @@ if __name__ == '__main__': data = [] # path_list = ["train_yy_sim_10.txt", "train_yy_1_sim_10.txt"] - path_list = ["../data/train_yy.txt", "../data/train_yy_1.txt"] + path_list = ["../data/train_new/train_yy.txt", "../data/train_new/train_yy_1.txt"] for i in path_list: data += read_text(i) - fileName = '../data/train_yy_zong.txt' + + print(len(data)) + fileName = '../data/train_new/train_yy.txt' with open(fileName, 'w', encoding='utf-8') as file: for i in data: file.write(str(i) + '\n') diff --git a/data_do/处理yy数据原始数据.py b/data_do/处理yy数据原始数据.py index 7f0459a..0f639f1 100644 --- a/data_do/处理yy数据原始数据.py +++ b/data_do/处理yy数据原始数据.py @@ -32,7 +32,7 @@ def walkFile(file): # for d in dirs: # print(os.path.join(root, d)) def main(): - walkFile("../data/yy_reduce_data_20221219-20230131") + walkFile("../data/yy_reduce_data_20230210-20230718") main() @@ -41,16 +41,16 @@ data = [] rootpath_list = [] for i in data_path_list: - danpath_list = str(i).split("\\") - rootpath_list.append("\\".join(danpath_list[:-1])) + danpath_list = str(i).split("/") + rootpath_list.append("/".join(danpath_list[:-1])) print(len(rootpath_list)) rootpath_list = list(set(rootpath_list)) for i in tqdm(rootpath_list): try: - soup_source = BeautifulSoup(open("{}\\source".format(i), encoding='utf-8'), + soup_source = BeautifulSoup(open("{}/source".format(i), encoding='utf-8'), "html.parser") - soup_result = BeautifulSoup(open("{}\\result".format(i), encoding='utf-8'), + soup_result = BeautifulSoup(open("{}/result".format(i), encoding='utf-8'), "html.parser") except: continue @@ -84,4 +84,4 @@ df = pd.DataFrame(data,columns=["原文","yy降重"]) for col in df.columns: df[col] = df[col].apply(lambda x: data_clean(x)) -df.to_excel("../data/论文_yy_小说_1.xlsx",index=None) +df.to_csv("../data/论文_yy_小说_1.csv",index=None) diff --git a/data_do/汇总.py b/data_do/汇总.py index 893310f..2b184f9 100644 --- a/data_do/汇总.py +++ b/data_do/汇总.py @@ -14,12 +14,15 @@ path_2 = "../data/11篇临时拼接" path_3 = "../data/11篇临时拼接2" path_yy = "../data/11篇_yy_strsim.xlsx" path_t5 = "../data/11篇_t5_strsim.xlsx" +path_t5_0724 = "../data/11篇_t5_strsim_0724.xlsx" data_yy = pd.read_excel(path_yy).values.tolist() data_t5 = pd.read_excel(path_t5).values.tolist() +data_t5_0724 = pd.read_excel(path_t5_0724).values.tolist() data_yy_dict = {} data_t5_dict = {} +data_t5_dict_0724 = {} for i in data_yy: str_data_yuan = str(i[0]).strip("。").strip() str_data_lable = 
str(i[1]).strip("。").strip() @@ -29,6 +32,10 @@ for i in data_t5: str_data_yuan = str(i[0]).strip("。").strip() str_data_lable = str(i[1]).strip("。").strip() data_t5_dict[str_data_yuan] = str_data_lable +for i in data_t5_0724: + str_data_yuan = str(i[0]).strip("。").strip() + str_data_lable = str(i[1]).strip("。").strip() + data_t5_dict_0724[str_data_yuan] = str_data_lable @@ -52,12 +59,13 @@ for file_name in path_list: str_data = str(data_1[i][0]).strip() try: + data_t5_0724_dan = data_t5_dict_0724[str_data] data_t5_dan = data_t5_dict[str_data] data_yy_dan = data_yy_dict[str_data] - data_new.append(data_1[i] + [data_2[i][1], data_3[i][1], data_t5_dan, data_yy_dan]) + data_new.append(data_1[i] + [data_2[i][1], data_3[i][1], data_t5_dan, data_t5_0724_dan, data_yy_dan]) except: print(str_data) - df = pd.DataFrame(data_new,columns=["原文","simbert","simbert_datasim07","bertsim_simsim","t5","yy"]) + df = pd.DataFrame(data_new,columns=["原文","simbert","simbert_datasim07","bertsim_simsim","t5","t5-0724", "yy"]) df.to_excel("../data/11篇测试excel_汇总_3/{}.xlsx".format(file_name_0), index=None) diff --git a/data_do/筛选训练数据new.py b/data_do/筛选训练数据new.py new file mode 100644 index 0000000..c093c85 --- /dev/null +++ b/data_do/筛选训练数据new.py @@ -0,0 +1,299 @@ +# -*- coding: utf-8 -*- + +""" +@Time : 2023/1/31 19:02 +@Author : +@FileName: +@Software: +@Describe: +""" +import os +# os.environ["TF_KERAS"] = "1" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +import json +import numpy as np +from bert4keras.backend import keras, set_gelu +from bert4keras.tokenizers import Tokenizer, load_vocab +from bert4keras.models import build_transformer_model +from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr +from bert4keras.snippets import sequence_padding, DataGenerator +from bert4keras.snippets import open +from keras.layers import Lambda, Dense +import tensorflow as tf +from keras.backend import set_session +from sklearn.metrics.pairwise import cosine_similarity +from rouge import Rouge # pip install rouge +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction +from tqdm import tqdm +import jieba +from gensim.models import KeyedVectors, word2vec, Word2Vec +import random +import difflib +import re + +config = tf.ConfigProto() +config.gpu_options.allow_growth = True +set_session(tf.Session(config=config)) # 此处不同 + +class Word2vecModel: + def __init__(self): + self.path = "E:\pycharm_workspace\查重分析\word2vec_model\\word2vec_add_new_18.model" + self.model = Word2Vec.load(self.path) + + def word2vec_res(self,seg_0_list, seg_1_list): + sentence_0_list = [] + sentence_1_list = [] + for i in seg_0_list: + a = self.model.wv[i] + sentence_0_list.append(a) + + for i in seg_1_list: + a = self.model.wv[i] + sentence_1_list.append(a) + + return sentence_0_list, sentence_1_list + +class Evaluator(keras.callbacks.Callback): + """评估与保存 + """ + + def __init__(self): + self.rouge = Rouge() + self.smooth = SmoothingFunction().method1 + self.best_bleu = 0. 
+ + # def on_epoch_end(self, epoch, logs=None): + # metrics = self.evaluate(valid_data) # 评测模型 + # if metrics['bleu'] > self.best_bleu: + # self.best_bleu = metrics['bleu'] + # model.save_weights('./best_model.weights') # 保存模型 + # metrics['best_bleu'] = self.best_bleu + # print('valid_data:', metrics) + + + def evaluate_t(self, data_1, data_2, topk=1): + total = 0 + rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0 + + scores = self.rouge.get_scores(hyps=[data_1], refs=[data_2]) + rouge_1 += scores[0]['rouge-1']['f'] + rouge_2 += scores[0]['rouge-2']['f'] + rouge_l += scores[0]['rouge-l']['f'] + bleu += sentence_bleu( + references=[data_1.split(' ')], + hypothesis=data_2.split(' '), + smoothing_function=self.smooth + ) + # rouge_1 /= total + # rouge_2 /= total + # rouge_l /= total + # bleu /= total + return [rouge_1, rouge_2, rouge_l, bleu] + +class bertModel: + def __init__(self): + + # modelpath = "E:\pycharm_workspace\premodel\keras\chinese_simbert_L-12_H-768_A-12" + # modelpath = "E:\pycharm_workspace\premodel\keras\chinese_roberta_wwm_ext_L-12_H-768_A-12" + # modelpath = "E:\pycharm_workspace\premodel\keras\chinese_L-12_H-768_A-12" + modelpath = "/home/majiahui/project/models-llm/keras/chinese_L-12_H-768_A-12" + self.config_path = modelpath + r'/bert_config.json' + self.checkpoint_path = modelpath + r'/bert_model.ckpt' + self.dict_path = modelpath + r'/vocab.txt' + self.token_dict, self.keep_tokens = load_vocab( + dict_path=self.dict_path, + simplified=True, + startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'], + ) + self.tokenizer = Tokenizer(self.token_dict, do_lower_case=True) + self.buildmodel() + + + def buildmodel(self): + bert = build_transformer_model( + config_path=self.config_path, + checkpoint_path=self.checkpoint_path, + return_keras_model=False, + ) + + output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output) + self.model = keras.models.Model(bert.model.input, output) + self.model.summary() + + def predict(self,text): + batch_token_ids, batch_segment_ids = [], [] + token_ids, segment_ids = self.tokenizer.encode(text, maxlen=256) + batch_token_ids.append(token_ids) + batch_segment_ids.append(segment_ids) + return self.model.predict([batch_token_ids, batch_segment_ids]) + + def predict_batch(self,text_list): + batch_token_ids, batch_segment_ids = [], [] + + for t in text_list: + token_ids, segment_ids = self.tokenizer.encode(t, maxlen=256) + batch_token_ids.append(token_ids) + batch_segment_ids.append(segment_ids) + + batch_token_ids = sequence_padding(batch_token_ids) + batch_segment_ids = sequence_padding(batch_segment_ids) + return self.model.predict([batch_token_ids, batch_segment_ids]) + +def simbert(data_1, data_2): + pass + +def word2vec(): + pass + +def bleu(): + pass + +def bool_len_strsim(data_1, data_2): + str_sim_value = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio() + if len(data_2) - len(data_1) < 0: + if len(data_2) / len(data_1) > 0.8: + num_yu = 1 - len(data_2) / len(data_1) + str_sim_value = 1 - str_sim_value * num_yu + else: + return False, "" + + if str_sim_value < 0.65: + return True, str_sim_value + else: + return False, "" + + +def has_numbers(input_string): + return any(char.isdigit() for char in input_string) + + +def bool_num(data_1, data_2): + bool_1 = has_numbers(data_1) + bool_2 = has_numbers(data_2) + if bool_1 == True and bool_2 == True: + return True + else: + return False + +def is_contains_english(str): + my_re = re.compile(r'[A-Za-z]', re.S) + res = re.findall(my_re, str) + if len(res): + return True + else: + return 
False + + +def is_contains_kongge(str): + if " " in str or "\t" in str: + return True + else: + return False + +if __name__ == '__main__': + file = "../data/train_yy_pre.txt" + # file = "../data/train_yy_zong_sim_99.txt" + model = bertModel() + eval_class = Evaluator() + data_new = [] + + data_1_list = [] + data_2_list = [] + + # word2vecmodel = Word2vecModel() + try: + with open(file, 'r', encoding="utf-8") as f: + lines = [x.strip() for x in f if x.strip() != ''] + except: + with open(file, 'r', encoding="gbk") as f: + lines = [x.strip() for x in f if x.strip() != ''] + + bertsim_list = [] + bleusim_list = [] + word2vecsim_list = [] + data_train_text = [] + + # random.shuffle(lines) + print(len(lines)) + for txt in tqdm(lines): + + text = txt.split('\t') + if len(text) == 3: + data_1 = text[0] + data_2 = text[2] + + # 判断是否包含数字 + bool_num_ = bool_num(data_1, data_2) + if bool_num_ == False: + continue + + # 判断是否包含英文 + # data_english_bool = is_contains_english(data_1) + # if data_english_bool == True: + # continue + + # 判断是否包含空格 + data_kongge_bool = is_contains_kongge(data_1) + if data_kongge_bool == True: + continue + + # 判断是否符合字符相似度标准 + bool_len_strsim_v, strsim = bool_len_strsim(data_1,data_2) + if bool_len_strsim_v == True: + continue + + # # 第一种方法 + # y1 = model.predict(data_1)[0] + # y2 = model.predict(data_2)[0] + # cos_sim = cosine_similarity(y1.reshape(1, -1), y2.reshape(1, -1)) + # # bertsim_list.append((cos_sim[0][0], strsim, data_1, data_2)) + # if cos_sim[0][0] > 0.9: + # cos_sim_bool = True + # else: + # cos_sim_bool = False + # + # if cos_sim_bool == False: + # continue + # + # data_new.append("\t".join([data_1, "to", data_2])) + + + # data_train_text.append("\t".join([data_1, "to", data_2])) + + # 第二种方法 + y = model.predict_batch([data_1, data_2]) + y1 = y[0] + y2 = y[1] + cos_sim = cosine_similarity(y1.reshape(1, -1), y2.reshape(1, -1)) + # bertsim_list.append((cos_sim[0][0], strsim, data_1, data_2)) + if cos_sim[0][0] > 0.9: + cos_sim_bool = True + else: + cos_sim_bool = False + + if cos_sim_bool == False: + continue + + data_new.append("\t".join([data_1, "to", data_2])) + + + + # bertsim_list.sort(reverse=True) + # with open("../data/tongji_len_strsim_nertsim_1.txt", "w", encoding="utf-8") as f: + # for i in bertsim_list: + # f.write(str(i[0])) + # f.write(str("\t")) + # f.write(str(i[1])) + # f.write(str("\t")) + # f.write(str(i[2])) + # f.write(str("\t")) + # f.write(str(i[3])) + # f.write("\n") + # print(len(data_train_text)) + fileName = '../data/train_new/train_yy_1.txt' + # fileName = '../data/train_new/train_yy.txt' + with open(fileName, 'w', encoding='utf-8') as f: + for i in data_new: + f.write(str(i) + '\n') + f.close() + diff --git a/data_do/筛选训练数据层级细分strsim.py b/data_do/筛选训练数据层级细分strsim.py index a7ef7ff..2b8ca11 100644 --- a/data_do/筛选训练数据层级细分strsim.py +++ b/data_do/筛选训练数据层级细分strsim.py @@ -135,7 +135,7 @@ def bleu(): if __name__ == '__main__': - file = "../data/train_yy_zong.txt" + file = "../data/train_yy_zong_sim_99.txt" sim_value = [1, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0] model = bertModel() eval_class = Evaluator() @@ -182,5 +182,5 @@ if __name__ == '__main__': data_train_text = sorted(data_train_text, key=lambda x:x[2], reverse=True) df = pd.DataFrame(data_train_text) print(df) - df.to_csv("../data/yy改写相似度.csv", index=None) - df.to_excel("../data/yy改写相似度.xlsx", index=None) + df.to_csv("../data/yy改写相似度_1.csv", index=None) + df.to_excel("../data/yy改写相似度_1.xlsx", index=None) diff --git a/data_do/进一步处理降重数据.py b/data_do/进一步处理降重数据.py index dfad0c0..33db2a6 
100644 --- a/data_do/进一步处理降重数据.py +++ b/data_do/进一步处理降重数据.py @@ -12,8 +12,8 @@ from tqdm import tqdm import json -path = "../data/论文_yy_小说_1.xlsx" -df_list = pd.read_excel(path).values.tolist() +path = "../data/论文_yy_小说_1.csv" +df_list = pd.read_csv(path).values.tolist() def sentence_do(source,result): @@ -40,4 +40,4 @@ for i in df_list: df_list_new.append([source,result]) df = pd.DataFrame(df_list_new, columns=["原文","yy降重"]) -df.to_excel("../data/论文_yy_小说_3.xlsx",index=None) \ No newline at end of file +df.to_csv("../data/论文_yy_小说_3.csv",index=None) \ No newline at end of file diff --git a/flask_multiple_results.py b/flask_multiple_results.py new file mode 100644 index 0000000..f90de27 --- /dev/null +++ b/flask_multiple_results.py @@ -0,0 +1,293 @@ +import os +from config.predict_t5_config import MultipleResultsDropT5Config +t5config = MultipleResultsDropT5Config() +from config.predict_sim_config import DropSimBertConfig +simbertconfig = DropSimBertConfig() +os.environ["CUDA_VISIBLE_DEVICES"] = t5config.cuda_id +from flask import Flask, jsonify +from flask import request +from predict_t5 import (GenerateModel as T5GenerateModel, + AutoTitle as T5AutoTitle) +from predict_sim import (GenerateModel as SimBertGenerateModel, + AutoTitle as SimBertT5AutoTitle) +import json +from threading import Thread +import time +import re +import requests + + +app = Flask(__name__) +app.config["JSON_AS_ASCII"] = False + +import logging +pattern = r"[。]" +RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”") +fuhao_end_sentence = ["。",",","?","!","…"] + +t5generatemodel = T5GenerateModel(t5config.config_path, + t5config.checkpoint_path, + t5config.spm_path, + t5config.keep_tokens_path, + t5config.maxlen, + t5config.savemodel_path) + +encoder, decoder, model, tokenizer = t5generatemodel.device_setup() +t5autotitle = T5AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120) + +simbertgeneratemodel = SimBertGenerateModel(simbertconfig.config_path, + simbertconfig.checkpoint_path, + simbertconfig.dict_path, + simbertconfig.maxlen, + simbertconfig.savemodel_path) +encoder, seq2seq, tokenizer = simbertgeneratemodel.device_setup() +simbertautotitle = SimBertT5AutoTitle(seq2seq, tokenizer, start_id=None, end_id=tokenizer._token_end_id, maxlen=120) + + +def requests_chatGPT(data): + res = requests.post('http://98.142.138.229:9999/chatgpt', data=data) + return res.json()['res'] + +def get_dialogs_index(line: str): + """ + 获取对话及其索引 + :param line 文本 + :return dialogs 对话内容 + dialogs_index: 对话位置索引 + other_index: 其他内容位置索引 + """ + dialogs = re.finditer(RE_DIALOG, line) + dialogs_text = re.findall(RE_DIALOG, line) + dialogs_index = [] + for dialog in dialogs: + all_ = [i for i in range(dialog.start(), dialog.end())] + dialogs_index.extend(all_) + other_index = [i for i in range(len(line)) if i not in dialogs_index] + + return dialogs_text, dialogs_index, other_index + + +def chulichangju_1(text, chulipangban_return_list, short_num): + fuhao = [",","?","!","…"] + dialogs_text, dialogs_index, other_index = get_dialogs_index(text) + text_1 = text[:120] + text_2 = text[120:] + text_1_new = "" + if text_2 == "": + chulipangban_return_list.append([text_1, short_num]) + return chulipangban_return_list + for i in range(len(text_1)-1, -1, -1): + if text_1[i] in fuhao: + if i in dialogs_index: + continue + text_1_new = text_1[:i] + text_1_new += text_1[i] + chulipangban_return_list.append([text_1_new, short_num]) + if text_2 != "": + if i+1 != 120: + text_2 = text_1[i+1:] + text_2 + break + # 
else: + # chulipangban_return_list.append(text_1) + if text_1_new == "": + chulipangban_return_list.append([text_1, short_num]) + if text_2 != "": + short_num += 1 + chulipangban_return_list = chulichangju_1(text_2, chulipangban_return_list, short_num) + return chulipangban_return_list + + +def chulipangban_test_1(text): + # 引号处理 + + dialogs_text, dialogs_index, other_index = get_dialogs_index(text) + for dialogs_text_dan in dialogs_text: + text_dan_list = text.split(dialogs_text_dan) + if "。" in dialogs_text_dan: + dialogs_text_dan = str(dialogs_text_dan).replace("。", "&") + text = dialogs_text_dan.join(text_dan_list) + + # text_new_str = "".join(text_new) + + sentence_list = text.split("。") + # sentence_list_new = [] + # for i in sentence_list: + # if i != "": + # sentence_list_new.append(i) + # sentence_list = sentence_list_new + sentence_batch_list = [] + sentence_batch_one = [] + sentence_batch_length = 0 + return_list = [] + for sentence in sentence_list: + if len(sentence) < 120: + sentence_batch_length += len(sentence) + sentence_batch_list.append([sentence, 0]) + # sentence_pre = autotitle.gen_synonyms_short(sentence) + # return_list.append(sentence_pre) + else: + + sentence_split_list = chulichangju_1(sentence,[], 0) + for sentence_short in sentence_split_list: + sentence_batch_list.append(sentence_short) + return sentence_batch_list + + +def paragraph_test(texts:str): + + + text_list = chulipangban_test_1(texts) + + + # text_new_str = "".join(text_new) + return text_list + + +def batch_data_process(text_list): + sentence_batch_length = 0 + sentence_batch_one = [] + sentence_batch_list = [] + + for sentence in text_list: + sentence_batch_length += len(sentence[0]) + sentence_batch_one.append(sentence) + if sentence_batch_length > 500: + sentence_batch_length = 0 + sentence_ = sentence_batch_one.pop(-1) + sentence_batch_list.append(sentence_batch_one) + sentence_batch_one = [] + sentence_batch_one.append(sentence_) + sentence_batch_list.append(sentence_batch_one) + return sentence_batch_list + +def batch_predict(batch_data_list): + ''' + 一个bacth数据预测 + @param data_text: + @return: + ''' + batch_data_list_new = [] + batch_data_text_list = [] + batch_data_snetence_id_list = [] + for i in batch_data_list: + batch_data_text_list.append(i[0]) + batch_data_snetence_id_list.append(i[1:]) + # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list) + batch_pre_data_list = batch_data_text_list + for text,sentence_id in zip(batch_pre_data_list,batch_data_snetence_id_list): + batch_data_list_new.append([text] + sentence_id) + + return batch_data_list_new + + +def one_predict(data_text): + ''' + 一个条数据预测 + @param data_text: + @return: + ''' + return_data_list = [] + if data_text[0] != "": + data_inputs = data_text[0].replace("&", "。") + prompt_list = ["请帮我改写一下这个句子", "请帮美化一下下面句子", "请帮我修改下面句子让这句话更完美"] + pre_data_list = [] + for i in prompt_list: + pre_data = requests_chatGPT( + data={ + 'prompt':i, + 'text':data_inputs + } + ) + pre_data_list.append(pre_data) + modelclass_list = [t5autotitle, simbertautotitle] + for model in modelclass_list: + pre_data_list.append(model.generate(data_inputs)) + else: + pre_data_list = [""] * 5 + for pre_data in pre_data_list: + return_data_list.append([pre_data] + data_text[1:]) + + return return_data_list + + +def predict_data_post_processing(text_list, index): + text_list_sentence = [] + # text_list_sentence.append([text_list[0][0], text_list[0][1]]) + + for i in range(len(text_list)): + if text_list[i][index][2] != 0: + 
text_list_sentence[-1][0] += text_list[i][index][0] + else: + text_list_sentence.append([text_list[i][0], text_list[i][1]]) + + return_list = {} + sentence_one = [] + sentence_id = text_list_sentence[0][1] + for i in text_list_sentence: + if i[1] == sentence_id: + sentence_one.append(i[0]) + else: + return_list[sentence_id] = "。".join(sentence_one) + sentence_id = i[1] + sentence_one = [] + sentence_one.append(i[0]) + if sentence_one != []: + return_list[sentence_id] = "。".join(sentence_one) + return return_list + + +# def main(text:list): +# # text_list = paragraph_test(text) +# # batch_data = batch_data_process(text_list) +# # text_list = [] +# # for i in batch_data: +# # text_list.extend(i) +# # return_list = predict_data_post_processing(text_list) +# # return return_list + +def main(text: str): + text_list = paragraph_test(text) + text_list_new = [] + return_list = [] + for i in text_list: + pre_list = one_predict(i) + text_list_new.append(pre_list) + + for index in range(len(text_list_new[0])): + return_list.append(predict_data_post_processing(text_list_new, index)) + return return_list + +@app.route('/multiple_results_droprepeat/', methods=['POST']) +def sentence(): + print(request.remote_addr) + texts = request.json["texts"] + print("原始语句" + str(texts)) + # question = question.strip('。、!??') + + + if isinstance(texts, str): + texts_list = [] + y_pred_label_list = [] + position_list = [] + + # texts = texts.replace('\'', '\"') + if texts is None: + return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False} + return jsonify(return_text) + else: + texts_list = main(texts) + return_text = {"texts": texts_list, "probabilities": None, "status_code": True} + else: + return_text = {"texts":"输入格式应该为list", "probabilities": None, "status_code":False} + return jsonify(return_text) + + +if __name__ == "__main__": + fh = logging.FileHandler(mode='a', encoding='utf-8', filename='chitchat.log') + logging.basicConfig( + handlers=[fh], + level=logging.DEBUG, + format='%(asctime)s - %(levelname)s - %(message)s', + datefmt='%a, %d %b %Y %H:%M:%S', + ) + app.run(host="0.0.0.0", port=14000, threaded=True, debug=False) diff --git a/flask_predict_no_batch_t5.py b/flask_predict_no_batch_t5.py index 7fc3105..0d92253 100644 --- a/flask_predict_no_batch_t5.py +++ b/flask_predict_no_batch_t5.py @@ -14,8 +14,18 @@ import json from threading import Thread import time import re +import logging + +logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 + filename='rewrite.log', + filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 + # a是追加模式,默认如果不写的话,就是追加模式 + format= + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' + # 日志格式 + ) -pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50, db=1) +pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=6, password="zhicheng123*") redis_ = redis.Redis(connection_pool=pool, decode_responses=True) db_key_query = 'query' @@ -41,6 +51,25 @@ encoder, decoder, model, tokenizer = generatemodel.device_setup() autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120) +class log: + def __init__(self): + pass + + def log(*args, **kwargs): + format = '%Y/%m/%d-%H:%M:%S' + format_h = '%Y-%m-%d' + value = time.localtime(int(time.time())) + dt = time.strftime(format, value) + dt_log_file = time.strftime(format_h, value) + log_file = 'log_file/access-%s' % dt_log_file + ".log" + if not os.path.exists(log_file): + with open(os.path.join(log_file), 
'w', encoding='utf-8') as f: + print(dt, *args, file=f, **kwargs) + else: + with open(os.path.join(log_file), 'a+', encoding='utf-8') as f: + print(dt, *args, file=f, **kwargs) + + def get_dialogs_index(line: str): """ 获取对话及其索引 @@ -290,13 +319,23 @@ def classify(): # 调用模型,设置最大batch_size texts_list = [] return_text = {"texts": texts_list, "probabilities": None, "status_code": 200} - redis_.srem(db_key_querying, query_id) load_result_path = "./new_data_logs/{}.json".format(query_id) + + print("query_id: ", query_id) + print("load_result_path: ", load_result_path) + with open(load_result_path, 'w', encoding='utf8') as f2: # ensure_ascii=False才能输入中文,否则是Unicode字符 # indent=2 JSON数据的缩进,美观 json.dump(return_text, f2, ensure_ascii=False, indent=4) - redis_.set(query_id, load_result_path, 28800) + debug_id_1 = 1 + redis_.set(query_id, load_result_path, 86400) + debug_id_2 = 2 + redis_.srem(db_key_querying, query_id) + debug_id_3 = 3 + log.log('start at', + 'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format( + query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3)) @app.route("/predict", methods=["POST"]) @@ -309,6 +348,7 @@ def handle_query(): return jsonify(return_text) if isinstance(texts, dict): id_ = str(uuid.uuid1()) # 为query生成唯一标识 + print("uuid: ", uuid) d = {'id': id_, 'text': texts, "text_type": text_type} # 绑定文本和query id load_request_path = './request_data_logs/{}.json'.format(id_) @@ -329,11 +369,12 @@ t = Thread(target=classify) t.start() if __name__ == "__main__": - fh = logging.FileHandler(mode='a', encoding='utf-8', filename='chitchat.log') - logging.basicConfig( - handlers=[fh], - level=logging.DEBUG, - format='%(asctime)s - %(levelname)s - %(message)s', - datefmt='%a, %d %b %Y %H:%M:%S', - ) + logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 + filename='rewrite.log', + filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 + # a是追加模式,默认如果不写的话,就是追加模式 + format= + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' + # 日志格式 + ) app.run(host="0.0.0.0", port=14000, threaded=True, debug=False) diff --git a/new_data_logs/21f95512-bd79-11ed-8961-4c77cb423b31.json b/new_data_logs/21f95512-bd79-11ed-8961-4c77cb423b31.json deleted file mode 100644 index 58ecffa..0000000 --- a/new_data_logs/21f95512-bd79-11ed-8961-4c77cb423b31.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "texts": { - "0": "李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,", - "1": "李正旺你真是傻逼讪笑,挥手道", - "2": "李正旺你真是傻逼讪笑,挥手道", - "3": "李正旺你真是傻逼讪笑,挥手道", - "4": "李正旺你真是傻逼讪笑,挥手道", - "5": "李正旺你真是傻逼讪笑,挥手道", - "6": "李正旺你真是傻逼讪笑,挥手道", - "7": "李正旺你真是傻逼讪笑,挥手道", - "8": 
"李正旺你真是傻逼讪笑,挥手李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手" - }, - "probabilities": null, - "status_code": 200 -} \ No newline at end of file diff --git a/predict_11pian.py b/predict_11pian.py index c575a3d..ae380d8 100644 --- a/predict_11pian.py +++ b/predict_11pian.py @@ -1080,7 +1080,7 @@ if __name__ == '__main__': # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ path = './data/11篇txt' - path_new = './data/11篇model1' + path_new = './data/11篇model-0724' path_list = [] for file_name in os.listdir(path): diff --git a/predict_no_batch_1.py b/predict_no_batch_1.py new file mode 100644 index 0000000..2bdfc44 --- /dev/null +++ b/predict_no_batch_1.py @@ -0,0 +1,343 @@ +import os +from config.predict_t5_config import DropT5Config + +config = DropT5Config() +os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id +from flask import Flask, jsonify +from flask import request +# from linshi import autotitle +import requests +from predict_t5 import GenerateModel, AutoTitle +import redis +import uuid +import json +from threading import Thread +import time +import re +import logging + +logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 + filename='rewrite.log', + filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 + # a是追加模式,默认如果不写的话,就是追加模式 + format= + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' + # 日志格式 + ) + +pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=100, db=1) +redis_ = redis.Redis(connection_pool=pool, decode_responses=True) + +db_key_query = 'query' +db_key_querying = 'querying' +db_key_queryset = 'queryset' +batch_size = 32 + +app = Flask(__name__) +app.config["JSON_AS_ASCII"] = False + +import logging + +pattern = r"[。]" +RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”") +fuhao_end_sentence = ["。", ",", "?", "!", "…"] + +generatemodel = GenerateModel(config.config_path, + config.checkpoint_path, + config.spm_path, + config.keep_tokens_path, + config.maxlen, + config.savemodel_path) +encoder, decoder, model, tokenizer = generatemodel.device_setup() +autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120) + + +class log: + def __init__(self): + pass + + def log(*args, **kwargs): + format = '%Y/%m/%d-%H:%M:%S' + format_h = '%Y-%m-%d' + value = time.localtime(int(time.time())) + dt = time.strftime(format, value) + dt_log_file = time.strftime(format_h, value) + log_file = 'log_file/access-%s' % dt_log_file + ".log" + if not os.path.exists(log_file): + with open(os.path.join(log_file), 'w', encoding='utf-8') as f: + print(dt, *args, file=f, **kwargs) + else: + with open(os.path.join(log_file), 'a+', encoding='utf-8') as f: + print(dt, *args, file=f, **kwargs) + + +def get_dialogs_index(line: str): + """ + 获取对话及其索引 + :param line 文本 + :return dialogs 对话内容 + dialogs_index: 对话位置索引 + other_index: 其他内容位置索引 + """ + dialogs = 
re.finditer(RE_DIALOG, line) + dialogs_text = re.findall(RE_DIALOG, line) + dialogs_index = [] + for dialog in dialogs: + all_ = [i for i in range(dialog.start(), dialog.end())] + dialogs_index.extend(all_) + other_index = [i for i in range(len(line)) if i not in dialogs_index] + + return dialogs_text, dialogs_index, other_index + + +def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num): + fuhao = [",", "?", "!", "…"] + dialogs_text, dialogs_index, other_index = get_dialogs_index(text) + text_1 = text[:120] + text_2 = text[120:] + text_1_new = "" + if text_2 == "": + chulipangban_return_list.append([text_1, snetence_id, short_num]) + return chulipangban_return_list + for i in range(len(text_1) - 1, -1, -1): + if text_1[i] in fuhao: + if i in dialogs_index: + continue + text_1_new = text_1[:i] + text_1_new += text_1[i] + chulipangban_return_list.append([text_1_new, snetence_id, short_num]) + if text_2 != "": + if i + 1 != 120: + text_2 = text_1[i + 1:] + text_2 + break + # else: + # chulipangban_return_list.append(text_1) + if text_1_new == "": + chulipangban_return_list.append([text_1, snetence_id, short_num]) + if text_2 != "": + short_num += 1 + chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num) + return chulipangban_return_list + + +def chulipangban_test_1(snetence_id, text): + # 引号处理 + + dialogs_text, dialogs_index, other_index = get_dialogs_index(text) + for dialogs_text_dan in dialogs_text: + text_dan_list = text.split(dialogs_text_dan) + if "。" in dialogs_text_dan: + dialogs_text_dan = str(dialogs_text_dan).replace("。", "&") + text = dialogs_text_dan.join(text_dan_list) + + # text_new_str = "".join(text_new) + + sentence_list = text.split("。") + # sentence_list_new = [] + # for i in sentence_list: + # if i != "": + # sentence_list_new.append(i) + # sentence_list = sentence_list_new + sentence_batch_list = [] + sentence_batch_one = [] + sentence_batch_length = 0 + return_list = [] + for sentence in sentence_list: + if len(sentence) < 120: + sentence_batch_length += len(sentence) + sentence_batch_list.append([sentence, snetence_id, 0]) + # sentence_pre = autotitle.gen_synonyms_short(sentence) + # return_list.append(sentence_pre) + else: + + sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0) + for sentence_short in sentence_split_list: + sentence_batch_list.append(sentence_short) + return sentence_batch_list + + +def paragraph_test(texts: dict): + text_new = [] + for i, text in texts.items(): + text_list = chulipangban_test_1(i, text) + text_new.extend(text_list) + + # text_new_str = "".join(text_new) + return text_new + + +def batch_data_process(text_list): + sentence_batch_length = 0 + sentence_batch_one = [] + sentence_batch_list = [] + + for sentence in text_list: + sentence_batch_length += len(sentence[0]) + sentence_batch_one.append(sentence) + if sentence_batch_length > 500: + sentence_batch_length = 0 + sentence_ = sentence_batch_one.pop(-1) + sentence_batch_list.append(sentence_batch_one) + sentence_batch_one = [] + sentence_batch_one.append(sentence_) + sentence_batch_list.append(sentence_batch_one) + return sentence_batch_list + + +def batch_predict(batch_data_list): + ''' + 一个bacth数据预测 + @param data_text: + @return: + ''' + batch_data_list_new = [] + batch_data_text_list = [] + batch_data_snetence_id_list = [] + for i in batch_data_list: + batch_data_text_list.append(i[0]) + batch_data_snetence_id_list.append(i[1:]) + # batch_pre_data_list = 
autotitle.generate_beam_search_batch(batch_data_text_list) + batch_pre_data_list = batch_data_text_list + for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list): + batch_data_list_new.append([text] + sentence_id) + + return batch_data_list_new + + +def one_predict(data_text): + ''' + 一个条数据预测 + @param data_text: + @return: + ''' + if data_text[0] != "": + data_inputs = data_text[0].replace("&", "。") + pre_data = autotitle.generate(data_inputs) + else: + pre_data = "" + data_new = [pre_data] + data_text[1:] + return data_new + + +def predict_data_post_processing(text_list): + text_list_sentence = [] + # text_list_sentence.append([text_list[0][0], text_list[0][1]]) + + for i in range(len(text_list)): + if text_list[i][2] != 0: + text_list_sentence[-1][0] += text_list[i][0] + else: + text_list_sentence.append([text_list[i][0], text_list[i][1]]) + + return_list = {} + sentence_one = [] + sentence_id = text_list_sentence[0][1] + for i in text_list_sentence: + if i[1] == sentence_id: + sentence_one.append(i[0]) + else: + return_list[sentence_id] = "。".join(sentence_one) + sentence_id = i[1] + sentence_one = [] + sentence_one.append(i[0]) + if sentence_one != []: + return_list[sentence_id] = "。".join(sentence_one) + return return_list + + +# def main(text:list): +# # text_list = paragraph_test(text) +# # batch_data = batch_data_process(text_list) +# # text_list = [] +# # for i in batch_data: +# # text_list.extend(i) +# # return_list = predict_data_post_processing(text_list) +# # return return_list + +def main(text: dict): + text_list = paragraph_test(text) + text_list_new = [] + for i in text_list: + pre = one_predict(i) + text_list_new.append(pre) + return_list = predict_data_post_processing(text_list_new) + return return_list + + +@app.route('/droprepeat/', methods=['POST']) +def sentence(): + print(request.remote_addr) + texts = request.json["texts"] + text_type = request.json["text_type"] + print("原始语句" + str(texts)) + # question = question.strip('。、!??') + + if isinstance(texts, dict): + texts_list = [] + y_pred_label_list = [] + position_list = [] + + # texts = texts.replace('\'', '\"') + if texts is None: + return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False} + return jsonify(return_text) + else: + assert text_type in ['focus', 'chapter'] + if text_type == 'focus': + texts_list = main(texts) + if text_type == 'chapter': + texts_list = main(texts) + return_text = {"texts": texts_list, "probabilities": None, "status_code": True} + else: + return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False} + return jsonify(return_text) + + +def classify(): # 调用模型,设置最大batch_size + while True: + if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 + time.sleep(3) + continue + query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text + data_dict_path = json.loads(query) + path = data_dict_path['path'] + # text_type = data_dict["text_type"] + + with open(path, encoding='utf8') as f1: + # 加载文件的对象 + data_dict = json.load(f1) + + query_id = data_dict['id'] + texts = data_dict["text"] + text_type = data_dict["text_type"] + + assert text_type in ['focus', 'chapter'] + if text_type == 'focus': + texts_list = main(texts) + elif text_type == 'chapter': + texts_list = main(texts) + else: + texts_list = [] + + return_text = {"texts": texts_list, "probabilities": None, "status_code": 200} + load_result_path = "./new_data_logs/{}.json".format(query_id) + + print("query_id: ", query_id) + print("load_result_path: ", load_result_path) + + 
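+        # Persist and publish the result for this query (see the block below):
+        # 1) dump return_text to ./new_data_logs/<query_id>.json (ensure_ascii=False keeps the Chinese readable),
+        # 2) cache that file path in Redis under query_id with an 86400 s (24 h) TTL,
+        # 3) remove query_id from the 'querying' set so callers can tell the job has finished.
+        # The debug_id_* assignments only feed the trailing log.log trace line.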
with open(load_result_path, 'w', encoding='utf8') as f2: + # ensure_ascii=False才能输入中文,否则是Unicode字符 + # indent=2 JSON数据的缩进,美观 + json.dump(return_text, f2, ensure_ascii=False, indent=4) + debug_id_1 = 1 + redis_.set(query_id, load_result_path, 86400) + debug_id_2 = 2 + redis_.srem(db_key_querying, query_id) + debug_id_3 = 3 + log.log('start at', + 'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format( + query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3)) + +if __name__ == '__main__': + classify() + diff --git a/predict_no_batch_2.py b/predict_no_batch_2.py new file mode 100644 index 0000000..2bdfc44 --- /dev/null +++ b/predict_no_batch_2.py @@ -0,0 +1,343 @@ +import os +from config.predict_t5_config import DropT5Config + +config = DropT5Config() +os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id +from flask import Flask, jsonify +from flask import request +# from linshi import autotitle +import requests +from predict_t5 import GenerateModel, AutoTitle +import redis +import uuid +import json +from threading import Thread +import time +import re +import logging + +logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 + filename='rewrite.log', + filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 + # a是追加模式,默认如果不写的话,就是追加模式 + format= + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' + # 日志格式 + ) + +pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=100, db=1) +redis_ = redis.Redis(connection_pool=pool, decode_responses=True) + +db_key_query = 'query' +db_key_querying = 'querying' +db_key_queryset = 'queryset' +batch_size = 32 + +app = Flask(__name__) +app.config["JSON_AS_ASCII"] = False + +import logging + +pattern = r"[。]" +RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”") +fuhao_end_sentence = ["。", ",", "?", "!", "…"] + +generatemodel = GenerateModel(config.config_path, + config.checkpoint_path, + config.spm_path, + config.keep_tokens_path, + config.maxlen, + config.savemodel_path) +encoder, decoder, model, tokenizer = generatemodel.device_setup() +autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120) + + +class log: + def __init__(self): + pass + + def log(*args, **kwargs): + format = '%Y/%m/%d-%H:%M:%S' + format_h = '%Y-%m-%d' + value = time.localtime(int(time.time())) + dt = time.strftime(format, value) + dt_log_file = time.strftime(format_h, value) + log_file = 'log_file/access-%s' % dt_log_file + ".log" + if not os.path.exists(log_file): + with open(os.path.join(log_file), 'w', encoding='utf-8') as f: + print(dt, *args, file=f, **kwargs) + else: + with open(os.path.join(log_file), 'a+', encoding='utf-8') as f: + print(dt, *args, file=f, **kwargs) + + +def get_dialogs_index(line: str): + """ + 获取对话及其索引 + :param line 文本 + :return dialogs 对话内容 + dialogs_index: 对话位置索引 + other_index: 其他内容位置索引 + """ + dialogs = re.finditer(RE_DIALOG, line) + dialogs_text = re.findall(RE_DIALOG, line) + dialogs_index = [] + for dialog in dialogs: + all_ = [i for i in range(dialog.start(), dialog.end())] + dialogs_index.extend(all_) + other_index = [i for i in range(len(line)) if i not in dialogs_index] + + return dialogs_text, dialogs_index, other_index + + +def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num): + fuhao = [",", "?", "!", "…"] + dialogs_text, dialogs_index, other_index = get_dialogs_index(text) + text_1 = text[:120] + text_2 = text[120:] + text_1_new = "" + if text_2 == "": + 
chulipangban_return_list.append([text_1, snetence_id, short_num]) + return chulipangban_return_list + for i in range(len(text_1) - 1, -1, -1): + if text_1[i] in fuhao: + if i in dialogs_index: + continue + text_1_new = text_1[:i] + text_1_new += text_1[i] + chulipangban_return_list.append([text_1_new, snetence_id, short_num]) + if text_2 != "": + if i + 1 != 120: + text_2 = text_1[i + 1:] + text_2 + break + # else: + # chulipangban_return_list.append(text_1) + if text_1_new == "": + chulipangban_return_list.append([text_1, snetence_id, short_num]) + if text_2 != "": + short_num += 1 + chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num) + return chulipangban_return_list + + +def chulipangban_test_1(snetence_id, text): + # 引号处理 + + dialogs_text, dialogs_index, other_index = get_dialogs_index(text) + for dialogs_text_dan in dialogs_text: + text_dan_list = text.split(dialogs_text_dan) + if "。" in dialogs_text_dan: + dialogs_text_dan = str(dialogs_text_dan).replace("。", "&") + text = dialogs_text_dan.join(text_dan_list) + + # text_new_str = "".join(text_new) + + sentence_list = text.split("。") + # sentence_list_new = [] + # for i in sentence_list: + # if i != "": + # sentence_list_new.append(i) + # sentence_list = sentence_list_new + sentence_batch_list = [] + sentence_batch_one = [] + sentence_batch_length = 0 + return_list = [] + for sentence in sentence_list: + if len(sentence) < 120: + sentence_batch_length += len(sentence) + sentence_batch_list.append([sentence, snetence_id, 0]) + # sentence_pre = autotitle.gen_synonyms_short(sentence) + # return_list.append(sentence_pre) + else: + + sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0) + for sentence_short in sentence_split_list: + sentence_batch_list.append(sentence_short) + return sentence_batch_list + + +def paragraph_test(texts: dict): + text_new = [] + for i, text in texts.items(): + text_list = chulipangban_test_1(i, text) + text_new.extend(text_list) + + # text_new_str = "".join(text_new) + return text_new + + +def batch_data_process(text_list): + sentence_batch_length = 0 + sentence_batch_one = [] + sentence_batch_list = [] + + for sentence in text_list: + sentence_batch_length += len(sentence[0]) + sentence_batch_one.append(sentence) + if sentence_batch_length > 500: + sentence_batch_length = 0 + sentence_ = sentence_batch_one.pop(-1) + sentence_batch_list.append(sentence_batch_one) + sentence_batch_one = [] + sentence_batch_one.append(sentence_) + sentence_batch_list.append(sentence_batch_one) + return sentence_batch_list + + +def batch_predict(batch_data_list): + ''' + 一个bacth数据预测 + @param data_text: + @return: + ''' + batch_data_list_new = [] + batch_data_text_list = [] + batch_data_snetence_id_list = [] + for i in batch_data_list: + batch_data_text_list.append(i[0]) + batch_data_snetence_id_list.append(i[1:]) + # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list) + batch_pre_data_list = batch_data_text_list + for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list): + batch_data_list_new.append([text] + sentence_id) + + return batch_data_list_new + + +def one_predict(data_text): + ''' + 一个条数据预测 + @param data_text: + @return: + ''' + if data_text[0] != "": + data_inputs = data_text[0].replace("&", "。") + pre_data = autotitle.generate(data_inputs) + else: + pre_data = "" + data_new = [pre_data] + data_text[1:] + return data_new + + +def predict_data_post_processing(text_list): + text_list_sentence = [] + # 
text_list_sentence.append([text_list[0][0], text_list[0][1]]) + + for i in range(len(text_list)): + if text_list[i][2] != 0: + text_list_sentence[-1][0] += text_list[i][0] + else: + text_list_sentence.append([text_list[i][0], text_list[i][1]]) + + return_list = {} + sentence_one = [] + sentence_id = text_list_sentence[0][1] + for i in text_list_sentence: + if i[1] == sentence_id: + sentence_one.append(i[0]) + else: + return_list[sentence_id] = "。".join(sentence_one) + sentence_id = i[1] + sentence_one = [] + sentence_one.append(i[0]) + if sentence_one != []: + return_list[sentence_id] = "。".join(sentence_one) + return return_list + + +# def main(text:list): +# # text_list = paragraph_test(text) +# # batch_data = batch_data_process(text_list) +# # text_list = [] +# # for i in batch_data: +# # text_list.extend(i) +# # return_list = predict_data_post_processing(text_list) +# # return return_list + +def main(text: dict): + text_list = paragraph_test(text) + text_list_new = [] + for i in text_list: + pre = one_predict(i) + text_list_new.append(pre) + return_list = predict_data_post_processing(text_list_new) + return return_list + + +@app.route('/droprepeat/', methods=['POST']) +def sentence(): + print(request.remote_addr) + texts = request.json["texts"] + text_type = request.json["text_type"] + print("原始语句" + str(texts)) + # question = question.strip('。、!??') + + if isinstance(texts, dict): + texts_list = [] + y_pred_label_list = [] + position_list = [] + + # texts = texts.replace('\'', '\"') + if texts is None: + return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False} + return jsonify(return_text) + else: + assert text_type in ['focus', 'chapter'] + if text_type == 'focus': + texts_list = main(texts) + if text_type == 'chapter': + texts_list = main(texts) + return_text = {"texts": texts_list, "probabilities": None, "status_code": True} + else: + return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False} + return jsonify(return_text) + + +def classify(): # 调用模型,设置最大batch_size + while True: + if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 + time.sleep(3) + continue + query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text + data_dict_path = json.loads(query) + path = data_dict_path['path'] + # text_type = data_dict["text_type"] + + with open(path, encoding='utf8') as f1: + # 加载文件的对象 + data_dict = json.load(f1) + + query_id = data_dict['id'] + texts = data_dict["text"] + text_type = data_dict["text_type"] + + assert text_type in ['focus', 'chapter'] + if text_type == 'focus': + texts_list = main(texts) + elif text_type == 'chapter': + texts_list = main(texts) + else: + texts_list = [] + + return_text = {"texts": texts_list, "probabilities": None, "status_code": 200} + load_result_path = "./new_data_logs/{}.json".format(query_id) + + print("query_id: ", query_id) + print("load_result_path: ", load_result_path) + + with open(load_result_path, 'w', encoding='utf8') as f2: + # ensure_ascii=False才能输入中文,否则是Unicode字符 + # indent=2 JSON数据的缩进,美观 + json.dump(return_text, f2, ensure_ascii=False, indent=4) + debug_id_1 = 1 + redis_.set(query_id, load_result_path, 86400) + debug_id_2 = 2 + redis_.srem(db_key_querying, query_id) + debug_id_3 = 3 + log.log('start at', + 'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format( + query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3)) + +if __name__ == '__main__': + classify() + diff --git a/predict_t5.py b/predict_t5.py index 
a326d33..78a0ced 100644 --- a/predict_t5.py +++ b/predict_t5.py @@ -8,7 +8,11 @@ @Describe: """ #! -*- coding: utf-8 -*- +import os +from config.predict_t5_config import DropT5Config +config = DropT5Config() +os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id import glob from numpy import random random.seed(1001) @@ -410,11 +414,10 @@ class AutoTitle(AutoRegressiveDecoder): return output_str -def just_show_sentence(file): +def just_show_sentence(text): """ - @param file:list + @param text:list """ - text = file[0] pre = autotitle.generate(text) return pre @@ -426,10 +429,7 @@ def just_show_sentence_batch(file: list) -> object: if __name__ == '__main__': - import os - from config.predict_t5_config import DropT5Config - config = DropT5Config() - os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id + generatemodel = GenerateModel(config.config_path, config.checkpoint_path, config.spm_path, @@ -462,31 +462,32 @@ if __name__ == '__main__': # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ - # import os - # - # file = "./data/11篇汇总txt_new.txt" - # file_t5 = "./data/11篇汇总txt_new_predict_t5.txt" - # - # try: - # with open(file, 'r', encoding="utf-8") as f: - # lines = [x.strip() for x in f if x.strip() != ''] - # except: - # with open(file, 'r', encoding="gbk") as f: - # lines = [x.strip() for x in f if x.strip() != ''] - # - # zishu = 0 - # data = [] - # for i in tqdm(lines): - # - # zishu += len(i) - # pre = just_show_sentence([i]) - # data.append([i, pre]) - # - # with open(file_t5, "w", encoding='utf-8') as file: - # for i in data: - # file.write("\t".join(i) + '\n') - # file.close() - # print(zishu) + import os + + file = "./data/11篇汇总txt_new.txt" + file_t5 = "./data/11篇汇总txt_new_predict_t5.txt" + file_t5_0724 = "./data/11篇汇总txt_new_predict_t5_0724.txt" + + try: + with open(file, 'r', encoding="utf-8") as f: + lines = [x.strip() for x in f if x.strip() != ''] + except: + with open(file, 'r', encoding="gbk") as f: + lines = [x.strip() for x in f if x.strip() != ''] + + zishu = 0 + data = [] + for i in tqdm(lines): + + zishu += len(i) + pre = just_show_sentence([i]) + data.append([i, pre]) + + with open(file_t5_0724, "w", encoding='utf-8') as file: + for i in data: + file.write("\t".join(i) + '\n') + file.close() + print(zishu) #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ @@ -496,11 +497,13 @@ if __name__ == '__main__': # "强调轻资产经营, 更加重视经营风险的规避", # "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择", # "是时代对于青年们的深切呼唤"] - text = ["随着经济的发展,人们生活水平的提高,环境:问题也日益突出。", - "环境问题中的化学污染是影响我国居民生活质量不可忽视的重要因素,而仪器分析作为化工专业课程中必不可少的一门课程也不例外。", - "所以对学生对应用仪器分析解决实际问题的能力要求很高。", - "随着经济的发展,人们生活水平的提高,环境问题也日益突出。"] - print(just_show_sentence(text)) + # text = ["随着经济的发展,人们生活水平的提高,环境:问题也日益突出。", + # "环境问题中的化学污染是影响我国居民生活质量不可忽视的重要因素,而仪器分析作为化工专业课程中必不可少的一门课程也不例外。", + # "所以对学生对应用仪器分析解决实际问题的能力要求很高。", + # "随着经济的发展,人们生活水平的提高,环境问题也日益突出。"] + # + # for i in text: + # print(just_show_sentence(i)) # print(just_show_sentence_top(text)) # print(just_show_chachong_random(text)) diff --git a/predict_t5_multiple_results.py b/predict_t5_multiple_results.py new file mode 100644 index 0000000..2ae1427 --- /dev/null +++ b/predict_t5_multiple_results.py @@ -0,0 +1,511 @@ +# -*- coding: utf-8 -*- + +""" +@Time : 2023/1/16 14:59 +@Author : +@FileName: +@Software: +@Describe: +""" +#! 
-*- coding: utf-8 -*- + +import os +from config.predict_t5_config import MultipleResultsDropT5Config +config = MultipleResultsDropT5Config() +os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id +import glob +from numpy import random +random.seed(1001) +from tqdm import tqdm +import numpy as np +import pandas as pd +import json +from tqdm import tqdm +from bert4keras.backend import keras, K +from bert4keras.layers import Loss +from bert4keras.models import build_transformer_model +from bert4keras.tokenizers import SpTokenizer +from bert4keras.optimizers import Adam +from bert4keras.snippets import sequence_padding, open +from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder +from keras.models import Model +# from rouge import Rouge # pip install rouge +# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction +import tensorflow as tf +from keras.backend import set_session +tfconfig = tf.ConfigProto() +tfconfig.gpu_options.allow_growth = True +set_session(tf.Session(config=tfconfig)) # 此处不同 +global graph +graph = tf.get_default_graph() +sess = tf.Session(graph=graph) +set_session(sess) + +# global graph,model +# graph = tf.get_default_graph() +# sess = tf.Session(graph=graph) +# K.set_session(sess) + + +# 基本参数 + +class GenerateModel(object): + def __init__(self): + + self.config_path = config.config_path + self.checkpoint_path = config.checkpoint_path + self.spm_path = config.spm_path + self.keep_tokens_path = config.keep_tokens_path + self.maxlen = config.maxlen + + def device_setup(self): + tokenizer = SpTokenizer(self.spm_path, token_start=None, token_end='') + keep_tokens = json.load(open(self.keep_tokens_path)) + + t5 = build_transformer_model( + config_path=self.config_path, + checkpoint_path=self.checkpoint_path, + keep_tokens=keep_tokens, + model='mt5.1.1', + return_keras_model=False, + name='T5', + ) + + # output = CrossEntropy(2)(model.inputs + model.outputs) + # + # model = Model(model.inputs, output) + encoder = t5.encoder + decoder = t5.decoder + model = t5.model + model.summary() + + output = CrossEntropy(1)([model.inputs[1], model.outputs[0]]) + + model = Model(model.inputs, output) + path_model = config.savemodel_path + model.load_weights(path_model) + + return encoder, decoder, model, tokenizer + + +class CrossEntropy(Loss): + """交叉熵作为loss,并mask掉输入部分 + """ + + def compute_loss(self, inputs, mask=None): + y_true, y_pred = inputs + y_true = y_true[:, 1:] # 目标token_ids + y_mask = K.cast(mask[1], K.floatx())[:, 1:] # 解码器自带mask + y_pred = y_pred[:, :-1] # 预测序列,错开一位 + loss = K.sparse_categorical_crossentropy(y_true, y_pred) + loss = K.sum(loss * y_mask) / K.sum(y_mask) + return loss + + +class Beamdataone(object): + def __init__(self, num_beams, batch_id, text, end_id, minlen, min_ends, tokenizer, output_ids): + """ + Initialize n-best list of hypotheses. 
+ """ + self.num_beams = num_beams + self.batch_id = batch_id + self.beams = [] + self.minlen = minlen + self.min_ends = min_ends + self.end_id = end_id + self.text = text + self.output_scores = np.zeros(1) + self.output_ids = [output_ids] + self.return_str = "" + self.over = False + self.tokenizer = tokenizer + # self.data() + self.output_str = "" + self.text_2_textids( + self.text + ) + self.scores = np.zeros(1) + self.inputs_vector = 0 + + def text_2_textids(self,text): + token_ids, segment_ids = self.tokenizer.encode(text[0], maxlen=120) + self.text_ids = [token_ids] + + def add_data(self, step, output_scores): + ''' + 还存有的数据,直接可以被迭代, + @param text: + @return: + ''' + # inputs = [np.array([i]) for i in inputs] + # output_ids, output_scores = self.first_output_ids, np.zeros(1) + # + # scores, states = self.predict( + # inputs, output_ids, states, temperature, 'logits' + # ) # 计算当前得分 + # if step == 0: # 第1步预测后将输入重复topk次 + # inputs = [np.repeat(i, self.num_beams, axis=0) for i in self.inputs] + # inputs = [self.token_ids, self.segment_ids] + # inputs = [np.array([i]) for i in inputs] + self.output_ids = np.array(self.output_ids) + if step == 0: # 第1步预测后将输入重复topk次 + self.text_ids = [np.repeat(i, self.num_beams, axis=0) for i in self.text_ids] + scores = output_scores.reshape((-1, 1)) + self.scores # 综合累积得分 + # scores = output_probas + scores = self.output_scores.reshape((-1, 1)) + scores # 综合累积得分 + indices = scores.argpartition(-self.num_beams, axis=None)[-self.num_beams:] # 仅保留topk + indices_1 = indices // scores.shape[1] # 行索引 + indices_2 = (indices % scores.shape[1]).reshape((-1, 1)) # 列索引 + self.output_ids = np.concatenate([self.output_ids[indices_1], indices_2], + 1) # 更新输出 + self.output_scores = np.take_along_axis( + scores, indices, axis=None + ) # 更新得分 + + is_end = self.output_ids[:, -1] == self.end_id # 标记是否以end标记结束 + self.end_counts = (self.output_ids == self.end_id).sum(1) # 统计出现的end标记 + if self.output_ids.shape[1] >= self.minlen: # 最短长度判断 + best = self.output_scores.argmax() # 得分最大的那个 + if is_end[best] and self.end_counts[best] >= self.min_ends: # 如果已经终止 + # return output_ids[best] # 直接输出 + self.return_str_main(self.output_ids, best) + self.over = True + else: # 否则,只保留未完成部分 + flag = ~is_end | (self.end_counts < self.min_ends) # 标记未完成序列 + if not flag.all(): # 如果有已完成的 + self.output_ids = self.output_ids[flag] # 扔掉已完成序列 + self.output_scores = self.output_scores[flag] # 扔掉已完成序列 + self.end_counts = self.end_counts[flag] # 扔掉已完成end计数 + self.num_beams = flag.sum() # topk相应变化 + self.output_ids = self.output_ids.tolist() + self.output_str = [tokenizer.decode(ids) for ids in self.output_ids] + self.text_ids = [self.text_ids[0] for i in range(len(self.output_ids))] + + + # # 达到长度直接输出 + # return output_ids[output_scores.argmax()] + + + # def data(self): + # token_ids, segment_ids = self.tokenizer.encode(self.text, maxlen=256) + # self.token_ids = token_ids + # self.segment_ids = segment_ids + + + # input_str = [text for i in range(self.num_beams)] + # output_str = self.output_str + # return input_str, output_str + + def return_str_main(self, output_ids, best): + output_ids_best = output_ids[best] + self.return_str = self.tokenizer.decode(output_ids_best) + + +class AutoTitle(AutoRegressiveDecoder): + """seq2seq解码器 + """ + def __init__(self, encoder, decoder, model, tokenizer, start_id, end_id, maxlen, minlen=1): + super(AutoTitle, self).__init__(start_id, end_id, maxlen, minlen) + self.encoder = encoder + self.decoder = decoder + self.model = model + self.tokenizer = tokenizer + 
self.start_id = start_id + self.end_id = end_id + self.minlen = minlen + self.models = {} + if start_id is None: + self.first_output_ids = np.empty((1, 0), dtype=int) + else: + self.first_output_ids = np.array([[self.start_id]]) + + # @AutoRegressiveDecoder.wraps(default_rtype='probas') + # def predict(self, inputs, output_ids, states): + # token_ids, segment_ids = inputs + # token_ids = np.concatenate([token_ids, output_ids], 1) + # segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1) + # with graph.as_default(): + # K.set_session(sess) + # nodes = self.last_token(self.model).predict([token_ids, segment_ids]) + # return nodes + # # return self.last_token(self.model).predict([token_ids, segment_ids]) + + # @AutoRegressiveDecoder.wraps(default_rtype='probas') + # def predict(self, inputs, output_ids, states): + # c_encoded = inputs[0] + # with graph.as_default(): + # K.set_session(sess) + # nodes = self.last_token(self.decoder).predict([c_encoded, output_ids]) + # return nodes + + @AutoRegressiveDecoder.wraps(default_rtype='probas') + def predict(self, inputs, output_ids, states): + c_encoded = inputs[0] + with graph.as_default(): + K.set_session(sess) + nodes = self.last_token(decoder).predict([c_encoded, output_ids]) + return nodes + + def predict_batch(self, inputs): + # inputs, output_ids, states, temperature, 'probas' + token_ids, output_ids = inputs + # token_ids = np.concatenate([token_ids, output_ids], 1) + # segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1) + with graph.as_default(): + K.set_session(sess) + nodes = self.decoder.predict([token_ids, output_ids]) + return nodes + + def data_generator(self, token_ids, output_ids): + + batch_token_ids = [] + for i,j in zip(token_ids, output_ids): + + batch_token_ids = sequence_padding(token_ids) + batch_segment_ids = sequence_padding(output_ids) + return batch_token_ids, batch_segment_ids + + def beam_search_batch( + self, + inputs_str, + states=None, + temperature=1, + min_ends=1, + num_beam=3 + ): + """随机采样n个结果 + 说明:非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp + 表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。 + 返回:n个解码序列组成的list。 + """ + output_str = [] + # token_ids, segment_ids = self.data_generator(inputs, output_ids) + batch_nums = len(inputs_str) + return_str_batch = [0] * batch_nums + # output_ids = np.empty((batch_nums, 0), dtype=int) + output_ids = np.array([self.start_id]) + generated = [Beamdataone(num_beam, i, [inputs_str[i]], self.end_id, self.minlen, min_ends, self.tokenizer, output_ids) for i in range(batch_nums)] + # index_data = [i for i in range(batch_nums)] + + c_token_ids = [] + for i in generated: + text_ids = i.text_ids + c_token_ids.extend(text_ids) + c_token_ids = sequence_padding(c_token_ids) + c_encoded = encoder.predict(np.array(c_token_ids)) + + # probas_bool = np.array(token_ids, dtype=bool) + # # np.array(np.where(probas_bool == True)) + # for i, sentence in enumerate(probas_bool): + # lie = np.array(np.where(sentence == True))[0] + # probas_new.append(probas[i, lie[-1]]) + + for i in range(len(generated)): + probas_bool = np.array(generated[i].text_ids[0], dtype=bool) + lie = np.array(np.where(probas_bool == True))[0] + # c_encoded_dan = c_encoded[i, lie[-1]] + c_encoded_dan = c_encoded[np.ix_([i], lie)] + generated[i].inputs_vector = c_encoded_dan[0] + + + for step in range(self.maxlen): + # if step == 0: + # token_ids, segment_ids = self.data_generator(inputs_str, output_str) + # else: + # inputs_str, output_str = [], [] + inputs_vector_batch, output_ids_batch = [], [] + 
batch_input_num_beam_num = [] + for i in generated: + inputs_vector = i.inputs_vector + # if step != 0: + # output_ids_batch.extend(i.output_ids) + # text_ids_batch.extend(text_ids) + # else: + inputs_vector_batch.append(inputs_vector) + output_ids_batch.extend(i.output_ids) + if step != 0: + batch_input_num_beam_num.append(i.num_beams) + + # token_ids, output_ids_batch = self.data_generator(inputs_vector_batch, output_ids_batch) + + # token_ids_batch = sequence_padding(token_ids_batch) + # segment_ids_batch = sequence_padding(segment_ids_batch) + # output_ids_batch = np.array(output_ids_batch) + # if step == 0: + + inputs = [inputs_vector_batch, output_ids_batch] + + probas = self.predict_batch( + inputs + ) # 计算当前概率 + + probas_new = [] + probas_bool = np.array(inputs_vector_batch, dtype=bool) + # np.array(np.where(probas_bool == True)) + for i, sentence in enumerate(probas_bool): + lie = np.array(np.where(sentence == True))[0] + probas_new.append(probas[i, lie[-1]]) + probas = np.array(probas_new) + + + if step != 0: + num = 0 + if len(generated) > 1: + index = 0 + for index in range(len(batch_input_num_beam_num)-1): + cc = num + num += batch_input_num_beam_num[index] + generated[index].add_data(step, probas[cc:num,:]) + generated[index+1].add_data(step, probas[num:,:]) + else: + generated[0].add_data(step, probas[:,:]) + + else: + for index in range(len(generated)): + generated[index].add_data(step, probas[index,:]) + # i = 0 + # while True: + # bool_ = generated[i].over + # if bool_ == True: + # one_sentence = generated.pop(i) + # return_str_batch[i] = one_sentence.return_str + # if i > len(generated) - 1: + # break + # else: + # i += 1 + # if i > len(generated) - 1: + # break + + generated_new = [] + for i in range(len(generated)): + bool_ = generated[i].over + if bool_ == False: + generated_new.append(generated[i]) + else: + return_str_batch[generated[i].batch_id] = generated[i].return_str + generated = generated_new + + + if generated == []: + return return_str_batch + return return_str_batch + + + def generate(self, text, topk=5): + c_token_ids, _ = tokenizer.encode(text, maxlen=120) + with graph.as_default(): + K.set_session(sess) + c_encoded = encoder.predict(np.array([c_token_ids]))[0] + output_ids = self.beam_search([c_encoded], topk=topk) # 基于beam search + return_text = tokenizer.decode([int(i) for i in output_ids]) + return_text = return_text.replace(",", ",") + return return_text + + def generate_random(self, text, n=30, topp=0.9): + c_token_ids, _ = self.tokenizer.encode(text, maxlen=120) + with graph.as_default(): + K.set_session(sess) + c_encoded = self.encoder.predict(np.array([c_token_ids]))[0] + output_ids = self.random_sample([c_encoded], n, topp=topp) # 基于随机采样 + text = [] + for ids in output_ids: + text.append(tokenizer.decode([int(i) for i in ids])) + return text + + def generate_beam_search_batch(self, text): + output_str = self.beam_search_batch(text) # 基于随机采样 + return output_str + + +generatemodel = GenerateModel() +encoder, decoder, model, tokenizer = generatemodel.device_setup() +autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120) + + + + +def just_show_sentence(file): + """ + @param file:list + """ + text = file[0] + pre = autotitle.generate(text) + return pre + + +def just_show_sentence_batch(file: list) -> object: + text = file + pre = autotitle.generate_beam_search_batch(text) + return pre + + +if __name__ == '__main__': + # file = "train_2842.txt" + # just_show(file) + # text = 
["历史和当下都证明,创新是民族生存、发展的不竭源泉,是自身发展的必然选择,是时代对于青年们的深切呼唤"] + # a = just_show_sentence(text) + # print(a) + # print(type(a)) + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + # is_novel = False + # path = "./data/700条论文测试.xlsx" + # df_list = pd.read_excel(path).values.tolist() + # + # + # df_list_new = [] + # print(len(df_list)) + # for i in tqdm(df_list): + # pre = just_show_sentence([i[0]]) + # + # df_list_new.append([i[0], i[1], pre]) + # + # df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "t5模型"]) + # df.to_excel("./data/700条论文测试_7.xlsx", index=None) + + # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + import os + + file = "./data/11篇汇总txt_new.txt" + file_t5 = "./data/11篇汇总txt_new_predict_t5.txt" + file_t5_0724 = "./data/11篇汇总txt_new_predict_t5_0724.txt" + + try: + with open(file, 'r', encoding="utf-8") as f: + lines = [x.strip() for x in f if x.strip() != ''] + except: + with open(file, 'r', encoding="gbk") as f: + lines = [x.strip() for x in f if x.strip() != ''] + + zishu = 0 + data = [] + for i in tqdm(lines): + + zishu += len(i) + pre = just_show_sentence([i]) + data.append([i, pre]) + + with open(file_t5_0724, "w", encoding='utf-8') as file: + for i in data: + file.write("\t".join(i) + '\n') + file.close() + print(zishu) + + #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + + # text = ["'李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会", + # "历史和当下都证明,创新是民族生存、“发展的不竭源泉”,是是自身发展的必然选择", + # "自身发展的必然选择", + # "强调轻资产经营, 更加重视经营风险的规避", + # "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择", + # "是时代对于青年们的深切呼唤"] + # text = ["随着经济的发展,人们生活水平的提高,环境问题也日益突出。", + # "环境问题中的化学污染是影响我国居民生活质量不可忽视的重要因素,而仪器分析作为化工专业课程中必不可少的一门课程也不例外。", + # "所以对学生对应用仪器分析解决实际问题的能力要求很高。", + # "随着经济的发展,人们生活水平的提高,环境问题也日益突出。"] + # print(just_show_sentence(text)) + # print(just_show_sentence_top(text)) + # print(just_show_chachong_random(text)) + + # print(tokenizer.encode("\"", maxlen=120)) + # print(just_show_sentence_batch(text)) \ No newline at end of file diff --git a/predict_tf_sim.py b/predict_tf_sim.py index a56c279..892dabe 100644 --- a/predict_tf_sim.py +++ b/predict_tf_sim.py @@ -703,7 +703,7 @@ def just_show_csv_beam(file): if __name__ == '__main__': # file = "train_2842.txt" # just_show(file) - text = ["历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择,是时代对于青年们的深切呼唤"] + text = ["随着经济的发展,人们生活水平的提高,环境问题也日益突出。"] just_show_sentence(text) # "简言之,她不好过,李四也别想好过!" 
# s = "张三的对话" diff --git a/redis_check_uuid.py b/redis_check_uuid.py index 411ea09..e5d3141 100644 --- a/redis_check_uuid.py +++ b/redis_check_uuid.py @@ -28,7 +28,7 @@ from threading import Thread import time app = flask.Flask(__name__) -pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50, db=1) +pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=6, password="zhicheng123*") redis_ = redis.Redis(connection_pool=pool, decode_responses=True) db_key_query = 'query' @@ -74,6 +74,12 @@ def handle_query(): result_text = {'code': "202", 'text': "", 'probabilities': None} else: result_text = {'code': "203", 'text': "", 'probabilities': None} + load_request_path = './request_data_logs_203/{}.json'.format(id_) + with open(load_request_path, 'w', encoding='utf8') as f2: + # ensure_ascii=False才能输入中文,否则是Unicode字符 + # indent=2 JSON数据的缩进,美观 + json.dump(result_text, f2, ensure_ascii=False, indent=4) + return flask.jsonify(result_text) # 返回结果 diff --git a/run_app.sh b/run_app.sh deleted file mode 100644 index 3561691..0000000 --- a/run_app.sh +++ /dev/null @@ -1 +0,0 @@ -nohup python predict_flask.py > myout.file 2>&1 & \ No newline at end of file diff --git a/run_app_flask.sh b/run_app_flask.sh deleted file mode 100644 index ee35f0a..0000000 --- a/run_app_flask.sh +++ /dev/null @@ -1 +0,0 @@ -gunicorn flask_predict_no_batch_t5:app -c gunicorn_config.py \ No newline at end of file diff --git a/task_seq2seq_t5.py b/task_seq2seq_t5.py index 3a8e084..450fae9 100644 --- a/task_seq2seq_t5.py +++ b/task_seq2seq_t5.py @@ -16,7 +16,7 @@ # 补充了评测指标bleu、rouge-1、rouge-2、rouge-l import os # os.environ["TF_KERAS"] = "1" -os.environ["CUDA_VISIBLE_DEVICES"] = "0" +os.environ["CUDA_VISIBLE_DEVICES"] = "3" import json import numpy as np from tqdm import tqdm @@ -40,7 +40,7 @@ for gpu in gpus: max_c_len = 128 max_t_len = 128 batch_size = 28 -epochs = 10000 +epochs = 10 # 模型路径 config_path = 'mt5/mt5_base_dropout_0_3_config.json' @@ -49,7 +49,7 @@ spm_path = 'mt5/mt5_base/sentencepiece_cn.model' keep_tokens_path = 'mt5/mt5_base/sentencepiece_cn_keep_tokens.json' -file = "data/train_yy_zong_sim_99.txt" +file = "data/train_new/train_yy.txt" try: with open(file, 'r', encoding="utf-8") as f: lines = [x.strip() for x in f if x.strip() != ''] @@ -205,7 +205,7 @@ class Evaluator(keras.callbacks.Callback): # 保存最优 if logs['loss'] <= self.lowest: self.lowest = logs['loss'] - model.save_weights('./output_t5/best_model_t5_zong_sim_99.weights') + model.save_weights('./output_t5/best_model_t5_0724.weights') # 演示效果7 just_show() diff --git a/测试redis命名.py b/测试redis命名.py new file mode 100644 index 0000000..f68bb8c --- /dev/null +++ b/测试redis命名.py @@ -0,0 +1,9 @@ +import redis + +pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50, db=2) +redis_ = redis.Redis(connection_pool=pool, decode_responses=True) + + +api_key_list_ip_1 = "api_key_192.168.1.17" +for i in range(10): + redis_.rpush(api_key_list_ip_1, i) \ No newline at end of file diff --git a/请求改写文本.py b/请求改写文本.py new file mode 100644 index 0000000..ed6b27d --- /dev/null +++ b/请求改写文本.py @@ -0,0 +1,37 @@ +import requests + + +def dialog_line_parse(url, text): + """ + 将数据输入模型进行分析并输出结果 + :param url: 模型url + :param text: 进入模型的数据 + :return: 模型返回结果 + """ + + response = requests.post( + url, + json=text, + timeout=100000 + ) + if response.status_code == 200: + return response.json() + else: + # logger.error( + # "【{}】 Failed to get a proper response from remote " + # "server. Status Code: {}. 
Response: {}" + # "".format(url, response.status_code, response.text) + # ) + print("【{}】 Failed to get a proper response from remote " + "server. Status Code: {}. Response: {}" + "".format(url, response.status_code, response.text)) + print(text) + return {} + + +with open("data/drop_weight_data.txt", encoding="utf-8") as f: + text_list = [i for i in f.read().split("\n")] + for i in text_list[:-1]: + text = dialog_line_parse("http://192.168.31.74:19000", {"texts": f"改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{i}"}) + print("原文",i) + print("模型预测", text)
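The core update in Beamdataone.add_data (predict_t5_multiple_results.py above) is a flattened top-k selection: the per-step scores are added to each live beam's cumulative score, argpartition picks the k best entries of the flattened (beam, token) matrix, and the flat indices are split back into a beam row and a token column. Below is a self-contained NumPy sketch of that index arithmetic; the beam width, vocabulary size and score values are made-up example numbers, only the indexing pattern comes from the diff.

import numpy as np

num_beams = 3                                             # assumed beam width for the example
vocab_size = 7                                            # toy vocabulary size
output_scores = np.array([0.0, -1.2, -0.7])               # running log-score of each live beam
step_scores = np.log(np.random.dirichlet(np.ones(vocab_size), size=num_beams))  # fake per-step log-probs, shape (3, 7)

scores = output_scores.reshape((-1, 1)) + step_scores     # combined score for every (beam, token) pair
indices = scores.argpartition(-num_beams, axis=None)[-num_beams:]  # flat indices of the top-k pairs
indices_1 = indices // scores.shape[1]                    # row: which beam each survivor extends
indices_2 = (indices % scores.shape[1]).reshape((-1, 1))  # column: which token it appends
output_scores = np.take_along_axis(scores, indices, axis=None)     # updated cumulative scores

print(indices_1, indices_2.ravel(), output_scores)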
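For a quick smoke test of the two entry points that predict_t5_multiple_results.py exposes, a minimal usage sketch follows. It assumes the mt5 config/vocabulary and the weights referenced by MultipleResultsDropT5Config are in place and that importing the module is acceptable (the module builds the encoder/decoder and loads the weights at import time); the sample sentences are taken from the test strings already present in the file.

# usage sketch only; run from the repository root so the relative config paths resolve
from predict_t5_multiple_results import just_show_sentence, just_show_sentence_batch

sentences = [
    "随着经济的发展,人们生活水平的提高,环境问题也日益突出。",
    "所以对学生对应用仪器分析解决实际问题的能力要求很高。",
]

# just_show_sentence expects a list and rewrites its first element with beam search
print(just_show_sentence([sentences[0]]))

# just_show_sentence_batch keeps one beam per sentence and decodes the whole batch together
print(just_show_sentence_batch(sentences))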
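The redis_check_uuid.py hunk above moves the service to a different Redis instance (port 63179, db 6, password auth, 100 connections). A minimal connectivity check is sketched below, with the connection parameters copied from that hunk; 'query' is the db_key_query list the service polls.

import redis

pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=6,
                            password="zhicheng123*")
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

print(redis_.ping())         # True when the new instance accepts this password
print(redis_.llen('query'))  # number of pending requests on the db_key_query list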
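One caveat on the new 203 logging in handle_query: json.dump writes into ./request_data_logs_203/, which the diff does not create. A guard such as the following sketch, placed before the open call (and assuming os is, or gets, imported in redis_check_uuid.py), would avoid a FileNotFoundError on the first 203 response.

import os

os.makedirs('./request_data_logs_203', exist_ok=True)  # create the 203 log directory if it is missing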