From 2e0d1f28beb1439855fe27214e4d139a0808d85d Mon Sep 17 00:00:00 2001
From: "majiahui@haimaqingfan.com"
Date: Wed, 8 Mar 2023 18:56:02 +0800
Subject: [PATCH] V1.0 complete
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 README.md          |  43 +++++----
 ceshi10000.py      | 250 ------------------------------------------------
 测试10000篇数据.py | 250 ++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 274 insertions(+), 269 deletions(-)
 delete mode 100644 ceshi10000.py
 create mode 100644 测试10000篇数据.py

diff --git a/README.md b/README.md
index 107260f..f54116a 100644
--- a/README.md
+++ b/README.md
@@ -1,29 +1,34 @@
-# Novel rewriting project
-
-A generative task based on the unilm model, using the keras framework; the data-processing scripts are in the data_do folder
-Training data: train_cat_data_4.txt
+# Rewriting project
+
+ A generative task based on the unilm and t5 models, using the keras framework; the data-processing scripts are in the data_do folder
+ Training data: train_yy.txt
 
 ## Training
- Training with quality checking added: bash train.sh
- Training with quality checking added: bash train_sim.sh
+ Train t5: python task_seq2seq_t5.py
+ Train simbert: python simbert_train.py
 
 ## Prediction
-
- With quality checking: python predict_tf_sim.py
- Without quality checking: python predict_tf.py
+ simbert: python predict_sim.py
+ t5: python predict_t5.py
 
 ## API serve
+ Start the sentence-uuid request service: bash run_app_nohub_t5.sh
+ Start the service that looks up rewriting results by uuid: bash run_app_nohub_search_redis.sh
 
-Current startup method: bash run_app.sh
-One-command startup: bash run_app_gunicorn.sh
+## Request/response examples
+ Request a sentence uuid: https://console-docs.apipost.cn/preview/e3717e390cbdb50e/f4479038c8015f34
+ Request a rewriting result: https://console-docs.apipost.cn/preview/6b9de12817e8ef08/b158334d2c9534d2
 
-## Request example
-    requests.post(
-        "http://192.168.1.17:14000",
-        json={"texts": ["张三要爬上高位的,才能够翻云覆雨。"]},
-        timeout=1000
-    )
+## Generating training data from yy data
+ python data_do/yy数据处理.py
+ python data_do/进一步处理降重数据.py
+ python data_do/yy训练数据处理.py
+ python 合并数据.py
+ python 筛选训练数据strsim.py
+## Testing 11 articles
+
+
-## Response
-    {'probabilities': None, 'texts': ['张三要上了巅峰,他就可以为所欲为了。']}
\ No newline at end of file
+## Checking the test data for bugs
+ python 测试10000篇数据.py
\ No newline at end of file
diff --git a/ceshi10000.py b/ceshi10000.py
deleted file mode 100644
index 17e4c80..0000000
--- a/ceshi10000.py
+++ /dev/null
@@ -1,250 +0,0 @@
-# -*- coding: utf-8 -*-
-
-"""
-@Time    : 2023/3/3 14:22
-@Author  :
-@FileName:
-@Software:
-@Describe:
-"""
-import os
-# os.environ["TF_KERAS"] = "1"
-# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
-from flask import Flask, jsonify
-from predict_t5 import autotitle
-import re
-import json
-from tqdm import tqdm
-
-
-db_key_query = 'query'
-db_key_result = 'result'
-batch_size = 32
-
-app = Flask(__name__)
-app.config["JSON_AS_ASCII"] = False
-
-pattern = r"[。]"
-RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
-fuhao_end_sentence = ["。", ",", "?", "!", "…"]
-
-config = {
-    "batch_size": 1000
-}
-
-
-def get_dialogs_index(line: str):
-    """
-    Get dialogs and their index positions.
-    :param line: text
-    :return dialogs_text: dialog contents
-            dialogs_index: character positions inside dialogs
-            other_index: character positions of all other content
-    """
-    dialogs = re.finditer(RE_DIALOG, line)
-    dialogs_text = re.findall(RE_DIALOG, line)
-    dialogs_index = []
-    for dialog in dialogs:
-        all_ = [i for i in range(dialog.start(), dialog.end())]
-        dialogs_index.extend(all_)
-    other_index = [i for i in range(len(line)) if i not in dialogs_index]
-
-    return dialogs_text, dialogs_index, other_index
-
-
-def chulichangju_1(text, sentence_id, chulipangban_return_list, short_num):
-    # Split an over-long sentence (>120 chars) at the last punctuation mark
-    # that is not inside a dialog, then recurse on the remainder.
-    fuhao = [",", "?", "!", "…"]
-    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
-    text_1 = text[:120]
-    text_2 = text[120:]
-    text_1_new = ""
-    if text_2 == "":
-        chulipangban_return_list.append([text_1, sentence_id, short_num])
-        return chulipangban_return_list
-    for i in range(len(text_1)-1, -1, -1):
-        if text_1[i] in fuhao:
-            if i in dialogs_index:
-                continue
-            text_1_new = text_1[:i]
-            text_1_new += text_1[i]
-            chulipangban_return_list.append([text_1_new, sentence_id, short_num])
-            if text_2 != "":
-                if i+1 != 120:
-                    text_2 = text_1[i+1:] + text_2
-            break
-        # else:
-        #     chulipangban_return_list.append(text_1)
-    if text_1_new == "":
-        chulipangban_return_list.append([text_1, sentence_id, short_num])
-    if text_2 != "":
-        short_num += 1
-        chulipangban_return_list = chulichangju_1(text_2, sentence_id, chulipangban_return_list, short_num)
-    return chulipangban_return_list
-
-
-def chulipangban_test_1(sentence_id, text):
-    # Quotation-mark handling: mask "。" inside dialog with "&" so the split
-    # on "。" below keeps each dialog in one piece.
-    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
-    for dialogs_text_dan in dialogs_text:
-        text_dan_list = text.split(dialogs_text_dan)
-        if "。" in dialogs_text_dan:
-            dialogs_text_dan = str(dialogs_text_dan).replace("。", "&")
-        text = dialogs_text_dan.join(text_dan_list)
-
-    # text_new_str = "".join(text_new)
-
-    sentence_list = text.split("。")
-    # sentence_list_new = []
-    # for i in sentence_list:
-    #     if i != "":
-    #         sentence_list_new.append(i)
-    # sentence_list = sentence_list_new
-    sentence_batch_list = []
-    sentence_batch_one = []
-    sentence_batch_length = 0
-    return_list = []
-    for sentence in sentence_list:
-        if len(sentence) < 120:
-            sentence_batch_length += len(sentence)
-            sentence_batch_list.append([sentence, sentence_id, 0])
-            # sentence_pre = autotitle.gen_synonyms_short(sentence)
-            # return_list.append(sentence_pre)
-        else:
-            sentence_split_list = chulichangju_1(sentence, sentence_id, [], 0)
-            for sentence_short in sentence_split_list:
-                sentence_batch_list.append(sentence_short)
-    return sentence_batch_list
-
-
-def paragraph_test_(text: list, text_new: list):
-    # Legacy helper, unused below; chulipangban_test_1 takes (sentence_id,
-    # text) and returns [sentence, id, short_num] triples, so only the
-    # sentence strings are joined back together here.
-    for i in range(len(text)):
-        sentence_list = chulipangban_test_1(i, text[i])
-        text_new.append("。".join(sentence[0] for sentence in sentence_list))
-
-    # text_new_str = "".join(text_new)
-    return text_new
-
-
-def paragraph_test(texts: dict):
-    text_new = []
-    for i, text in texts.items():
-        try:
-            text_list = chulipangban_test_1(i, text)
-        except:
-            print(i, text)
-            continue
-        text_new.extend(text_list)
-
-    # text_new_str = "".join(text_new)
-    return text_new
-
-
-def batch_data_process(text_list):
-    sentence_batch_length = 0
-    sentence_batch_one = []
-    sentence_batch_list = []
-
-    for sentence in text_list:
-        sentence_batch_length += len(sentence[0])
-        sentence_batch_one.append(sentence)
-        if sentence_batch_length > 500:
-            sentence_batch_length = 0
-            sentence_ = sentence_batch_one.pop(-1)
-            sentence_batch_list.append(sentence_batch_one)
-            sentence_batch_one = []
-            sentence_batch_one.append(sentence_)
-    sentence_batch_list.append(sentence_batch_one)
-    return sentence_batch_list
-
-
-def batch_predict(batch_data_list):
-    '''
-    Predict one batch of data
-    @param batch_data_list:
-    @return:
-    '''
-    batch_data_list_new = []
-    batch_data_text_list = []
-    batch_data_sentence_id_list = []
-    for i in batch_data_list:
-        batch_data_text_list.append(i[0])
-        batch_data_sentence_id_list.append(i[1:])
-    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
-    batch_pre_data_list = batch_data_text_list
-    for text, sentence_id in zip(batch_pre_data_list, batch_data_sentence_id_list):
-        batch_data_list_new.append([text] + sentence_id)
-
-    return batch_data_list_new
-
-
-def one_predict(data_text):
-    '''
-    Predict a single piece of data
-    @param data_text:
-    @return:
-    '''
-    if data_text[0] != "":
-        data_inputs = data_text[0].replace("&", "。")
-        pre_data = autotitle.generate(data_inputs)
-    else:
-        pre_data = ""
-    data_new = [pre_data] + data_text[1:]
-    return data_new
-
-
-def predict_data_post_processing(text_list):
-    # Re-attach split fragments (short_num != 0) to their sentence, then
-    # group sentences back into paragraphs keyed by sentence id.
-    text_list_sentence = []
-    # text_list_sentence.append([text_list[0][0], text_list[0][1]])
-
-    for i in range(len(text_list)):
-        if text_list[i][2] != 0:
-            text_list_sentence[-1][0] += text_list[i][0]
-        else:
-            text_list_sentence.append([text_list[i][0], text_list[i][1]])
-
-    return_list = {}
-    sentence_one = []
-    sentence_id = "0"
-    for i in text_list_sentence:
-        if i[1] == sentence_id:
-            sentence_one.append(i[0])
-        else:
-            return_list[sentence_id] = "。".join(sentence_one)
-            sentence_id = i[1]
-            sentence_one = []
-            sentence_one.append(i[0])
-    if sentence_one != []:
-        return_list[sentence_id] = "。".join(sentence_one)
-    return return_list
-
-
-# def main(text: list):
-#     text_list = paragraph_test(text)
-#     batch_data = batch_data_process(text_list)
-#     text_list = []
-#     for i in batch_data:
-#         text_list.extend(i)
-#     return_list = predict_data_post_processing(text_list)
-#     return return_list
-
-def main(text: dict):
-    text_list = paragraph_test(text)
-    text_list_new = []
-    for i in tqdm(text_list):
-        pre = one_predict(i)
-        text_list_new.append(pre)
-    return_list = predict_data_post_processing(text_list_new)
-    return return_list
-
-
-if __name__ == '__main__':
-
-    filename = './data/yy_data.json'
-    with open(filename) as file_obj:
-        yy_data = json.load(file_obj)
-    rels = main(yy_data)
diff --git a/测试10000篇数据.py b/测试10000篇数据.py
new file mode 100644
index 0000000..17e4c80
--- /dev/null
+++ b/测试10000篇数据.py
@@ -0,0 +1,250 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2023/3/3 14:22
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+import os
+# os.environ["TF_KERAS"] = "1"
+# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+from flask import Flask, jsonify
+from predict_t5 import autotitle
+import re
+import json
+from tqdm import tqdm
+
+
+db_key_query = 'query'
+db_key_result = 'result'
+batch_size = 32
+
+app = Flask(__name__)
+app.config["JSON_AS_ASCII"] = False
+
+pattern = r"[。]"
+RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
+fuhao_end_sentence = ["。", ",", "?", "!", "…"]
+
+config = {
+    "batch_size": 1000
+}
+
+
+def get_dialogs_index(line: str):
+    """
+    Get dialogs and their index positions.
+    :param line: text
+    :return dialogs_text: dialog contents
+            dialogs_index: character positions inside dialogs
+            other_index: character positions of all other content
+    """
+    dialogs = re.finditer(RE_DIALOG, line)
+    dialogs_text = re.findall(RE_DIALOG, line)
+    dialogs_index = []
+    for dialog in dialogs:
+        all_ = [i for i in range(dialog.start(), dialog.end())]
+        dialogs_index.extend(all_)
+    other_index = [i for i in range(len(line)) if i not in dialogs_index]
+
+    return dialogs_text, dialogs_index, other_index
+
+
+def chulichangju_1(text, sentence_id, chulipangban_return_list, short_num):
+    # Split an over-long sentence (>120 chars) at the last punctuation mark
+    # that is not inside a dialog, then recurse on the remainder.
+    fuhao = [",", "?", "!", "…"]
+    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
+    text_1 = text[:120]
+    text_2 = text[120:]
+    text_1_new = ""
+    if text_2 == "":
+        chulipangban_return_list.append([text_1, sentence_id, short_num])
+        return chulipangban_return_list
+    for i in range(len(text_1)-1, -1, -1):
+        if text_1[i] in fuhao:
+            if i in dialogs_index:
+                continue
+            text_1_new = text_1[:i]
+            text_1_new += text_1[i]
+            chulipangban_return_list.append([text_1_new, sentence_id, short_num])
+            if text_2 != "":
+                if i+1 != 120:
+                    text_2 = text_1[i+1:] + text_2
+            break
+        # else:
+        #     chulipangban_return_list.append(text_1)
+    if text_1_new == "":
+        chulipangban_return_list.append([text_1, sentence_id, short_num])
+    if text_2 != "":
+        short_num += 1
+        chulipangban_return_list = chulichangju_1(text_2, sentence_id, chulipangban_return_list, short_num)
+    return chulipangban_return_list
+
+
+def chulipangban_test_1(sentence_id, text):
+    # Quotation-mark handling: mask "。" inside dialog with "&" so the split
+    # on "。" below keeps each dialog in one piece.
+    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
+    for dialogs_text_dan in dialogs_text:
+        text_dan_list = text.split(dialogs_text_dan)
+        if "。" in dialogs_text_dan:
+            dialogs_text_dan = str(dialogs_text_dan).replace("。", "&")
+        text = dialogs_text_dan.join(text_dan_list)
+
+    # text_new_str = "".join(text_new)
+
+    sentence_list = text.split("。")
+    # sentence_list_new = []
+    # for i in sentence_list:
+    #     if i != "":
+    #         sentence_list_new.append(i)
+    # sentence_list = sentence_list_new
+    sentence_batch_list = []
+    sentence_batch_one = []
+    sentence_batch_length = 0
+    return_list = []
+    for sentence in sentence_list:
+        if len(sentence) < 120:
+            sentence_batch_length += len(sentence)
+            sentence_batch_list.append([sentence, sentence_id, 0])
+            # sentence_pre = autotitle.gen_synonyms_short(sentence)
+            # return_list.append(sentence_pre)
+        else:
+            sentence_split_list = chulichangju_1(sentence, sentence_id, [], 0)
+            for sentence_short in sentence_split_list:
+                sentence_batch_list.append(sentence_short)
+    return sentence_batch_list
+
+
+def paragraph_test_(text: list, text_new: list):
+    # Legacy helper, unused below; chulipangban_test_1 takes (sentence_id,
+    # text) and returns [sentence, id, short_num] triples, so only the
+    # sentence strings are joined back together here.
+    for i in range(len(text)):
+        sentence_list = chulipangban_test_1(i, text[i])
+        text_new.append("。".join(sentence[0] for sentence in sentence_list))
+
+    # text_new_str = "".join(text_new)
+    return text_new
+
+
+def paragraph_test(texts: dict):
+    text_new = []
+    for i, text in texts.items():
+        try:
+            text_list = chulipangban_test_1(i, text)
+        except:
+            print(i, text)
+            continue
+        text_new.extend(text_list)
+
+    # text_new_str = "".join(text_new)
+    return text_new
+
+
+def batch_data_process(text_list):
+    sentence_batch_length = 0
+    sentence_batch_one = []
+    sentence_batch_list = []
+
+    for sentence in text_list:
+        sentence_batch_length += len(sentence[0])
+        sentence_batch_one.append(sentence)
+        if sentence_batch_length > 500:
+            sentence_batch_length = 0
+            sentence_ = sentence_batch_one.pop(-1)
+            sentence_batch_list.append(sentence_batch_one)
+            sentence_batch_one = []
+            sentence_batch_one.append(sentence_)
+    sentence_batch_list.append(sentence_batch_one)
+    return sentence_batch_list
+
+
+def batch_predict(batch_data_list):
+    '''
+    Predict one batch of data
+    @param batch_data_list:
+    @return:
+    '''
+    batch_data_list_new = []
+    batch_data_text_list = []
+    batch_data_sentence_id_list = []
+    for i in batch_data_list:
+        batch_data_text_list.append(i[0])
+        batch_data_sentence_id_list.append(i[1:])
+    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
+    batch_pre_data_list = batch_data_text_list
+    for text, sentence_id in zip(batch_pre_data_list, batch_data_sentence_id_list):
+        batch_data_list_new.append([text] + sentence_id)
+
+    return batch_data_list_new
+
+
+def one_predict(data_text):
+    '''
+    Predict a single piece of data
+    @param data_text:
+    @return:
+    '''
+    if data_text[0] != "":
+        data_inputs = data_text[0].replace("&", "。")
+        pre_data = autotitle.generate(data_inputs)
+    else:
+        pre_data = ""
+    data_new = [pre_data] + data_text[1:]
+    return data_new
+
+
+def predict_data_post_processing(text_list):
+    # Re-attach split fragments (short_num != 0) to their sentence, then
+    # group sentences back into paragraphs keyed by sentence id.
+    text_list_sentence = []
+    # text_list_sentence.append([text_list[0][0], text_list[0][1]])
+
+    for i in range(len(text_list)):
+        if text_list[i][2] != 0:
+            text_list_sentence[-1][0] += text_list[i][0]
+        else:
+            text_list_sentence.append([text_list[i][0], text_list[i][1]])
+
+    return_list = {}
+    sentence_one = []
+    sentence_id = "0"
+    for i in text_list_sentence:
+        if i[1] == sentence_id:
+            sentence_one.append(i[0])
+        else:
+            return_list[sentence_id] = "。".join(sentence_one)
+            sentence_id = i[1]
+            sentence_one = []
+            sentence_one.append(i[0])
+    if sentence_one != []:
+        return_list[sentence_id] = "。".join(sentence_one)
+    return return_list
+
+
+# def main(text: list):
+#     text_list = paragraph_test(text)
+#     batch_data = batch_data_process(text_list)
+#     text_list = []
+#     for i in batch_data:
+#         text_list.extend(i)
+#     return_list = predict_data_post_processing(text_list)
+#     return return_list
+
+def main(text: dict):
+    text_list = paragraph_test(text)
+    text_list_new = []
+    for i in tqdm(text_list):
+        pre = one_predict(i)
+        text_list_new.append(pre)
+    return_list = predict_data_post_processing(text_list_new)
+    return return_list
+
+
+if __name__ == '__main__':
+
+    filename = './data/yy_data.json'
+    with open(filename) as file_obj:
+        yy_data = json.load(file_obj)
+    rels = main(yy_data)
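-- 
The quote handling in chulipangban_test_1 is the heart of the splitting code:
dialog spans are found with RE_DIALOG so the "。" inside quotes can be masked
with "&" before the paragraph is split on "。", and one_predict unmasks them
before calling the model. A minimal standalone sketch of that step, assuming
only the regex from the script (mask_dialog_periods is an illustrative name,
not a function in this repo):

    import re

    RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")

    def mask_dialog_periods(text: str) -> str:
        # Mask sentence-ending periods inside quoted dialog so a later
        # text.split("。") cannot cut a dialog in half.
        for dialog in re.findall(RE_DIALOG, text):
            if "。" in dialog:
                text = text.replace(dialog, dialog.replace("。", "&"))
        return text

    sample = "他说:“今天下雨。明天放晴。”然后走了。"
    print(mask_dialog_periods(sample).split("。"))
    # ['他说:“今天下雨&明天放晴&”然后走了', '']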
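On the input side, main() takes the dict loaded from ./data/yy_data.json and
returns a dict with the same keys, so the expected file shape is
{paragraph_id: paragraph_text} (paragraph_test iterates texts.items()). A
guessed two-entry fixture for a dry run; the ids and sentences below are
illustrative only, and the data directory is assumed to exist:

    import json

    # Assumed input shape for ./data/yy_data.json: id -> paragraph text.
    yy_data = {
        "0": "张三要爬上高位的,才能够翻云覆雨。",
        "1": "他说:“今天下雨。明天放晴。”然后走了。",
    }
    with open("./data/yy_data.json", "w", encoding="utf-8") as f:
        json.dump(yy_data, f, ensure_ascii=False, indent=2)

    # rels = main(yy_data)  # -> {"0": "<rewritten>", "1": "<rewritten>"}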