diff --git a/ceshi10000.py b/ceshi10000.py
new file mode 100644
index 0000000..17e4c80
--- /dev/null
+++ b/ceshi10000.py
@@ -0,0 +1,250 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2023/3/3 14:22
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+import os
+# os.environ["TF_KERAS"] = "1"
+# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+from flask import Flask, jsonify
+from predict_t5 import autotitle
+import re
+import json
+from tqdm import tqdm
+
+
+db_key_query = 'query'
+db_key_result = 'result'
+batch_size = 32
+
+app = Flask(__name__)
+app.config["JSON_AS_ASCII"] = False
+
+pattern = r"[。]"
+RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
+fuhao_end_sentence = ["。", ",", "?", "!", "…"]
+
+config = {
+    "batch_size": 1000
+}
+
+
+def get_dialogs_index(line: str):
+    """
+    Locate dialog spans and their character indices.
+    :param line: input text
+    :return: dialogs_text  - dialog contents
+             dialogs_index - character positions inside dialogs
+             other_index   - character positions outside dialogs
+    """
+    dialogs = re.finditer(RE_DIALOG, line)
+    dialogs_text = re.findall(RE_DIALOG, line)
+    dialogs_index = []
+    for dialog in dialogs:
+        dialogs_index.extend(range(dialog.start(), dialog.end()))
+    other_index = [i for i in range(len(line)) if i not in dialogs_index]
+
+    return dialogs_text, dialogs_index, other_index
+
+
+def chulichangju_1(text, sentence_id, chulipangban_return_list, short_num):
+    # Recursively split an over-long sentence (>120 chars) at the last
+    # clause-level punctuation mark that does not fall inside a dialog span.
+    fuhao = [",", "?", "!", "…"]
+    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
+    text_1 = text[:120]
+    text_2 = text[120:]
+    text_1_new = ""
+    if text_2 == "":
+        chulipangban_return_list.append([text_1, sentence_id, short_num])
+        return chulipangban_return_list
+    for i in range(len(text_1) - 1, -1, -1):
+        if text_1[i] in fuhao:
+            if i in dialogs_index:
+                continue
+            text_1_new = text_1[:i]
+            text_1_new += text_1[i]
+            chulipangban_return_list.append([text_1_new, sentence_id, short_num])
+            if text_2 != "":
+                if i + 1 != 120:
+                    text_2 = text_1[i + 1:] + text_2
+            break
+    if text_1_new == "":
+        chulipangban_return_list.append([text_1, sentence_id, short_num])
+    if text_2 != "":
+        short_num += 1
+        chulipangban_return_list = chulichangju_1(text_2, sentence_id, chulipangban_return_list, short_num)
+    return chulipangban_return_list
+
+
+def chulipangban_test_1(sentence_id, text):
+    # Quote handling: mask sentence-ending 。 inside dialog spans as & so the
+    # split below cannot break a dialog apart.
+    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
+    for dialogs_text_dan in dialogs_text:
+        text_dan_list = text.split(dialogs_text_dan)
+        if "。" in dialogs_text_dan:
+            dialogs_text_dan = str(dialogs_text_dan).replace("。", "&")
+        text = dialogs_text_dan.join(text_dan_list)
+
+    sentence_list = text.split("。")
+    sentence_batch_list = []
+    for sentence in sentence_list:
+        if len(sentence) < 120:
+            sentence_batch_list.append([sentence, sentence_id, 0])
+        else:
+            sentence_split_list = chulichangju_1(sentence, sentence_id, [], 0)
+            for sentence_short in sentence_split_list:
+                sentence_batch_list.append(sentence_short)
+    return sentence_batch_list
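+
+# Worked example (comments only, hypothetical input): for sentence_id "0" and
+# the input '他说:"今天。明天"。好的', the 。 inside the straight quotes is
+# masked as & first, so splitting on 。 yields ['他说:"今天&明天"', '好的'] and
+# the call returns triples like ['他说:"今天&明天"', "0", 0]. A short_num
+# greater than 0 marks a continuation fragment from chulichangju_1; those are
+# glued back together by predict_data_post_processing, and one_predict turns
+# & back into 。 before calling the model.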
+
+
+def paragraph_test_(text: list, text_new: list):
+    # Unused variant of paragraph_test: takes a list of paragraphs and
+    # re-joins each paragraph's fragments with 。
+    for i in range(len(text)):
+        text_list = chulipangban_test_1(i, text[i])
+        text_new.append("。".join([sentence[0] for sentence in text_list]))
+
+    # text_new_str = "".join(text_new)
+    return text_new
+
+
+def paragraph_test(texts: dict):
+    text_new = []
+    for i, text in texts.items():
+        try:
+            text_list = chulipangban_test_1(i, text)
+        except Exception:
+            print(i, text)
+            continue
+        text_new.extend(text_list)
+
+    # text_new_str = "".join(text_new)
+    return text_new
+
+
+def batch_data_process(text_list):
+    # Greedily pack sentence triples into batches of roughly 500 characters.
+    sentence_batch_length = 0
+    sentence_batch_one = []
+    sentence_batch_list = []
+
+    for sentence in text_list:
+        sentence_batch_length += len(sentence[0])
+        sentence_batch_one.append(sentence)
+        if sentence_batch_length > 500:
+            # close the current batch; the overflowing sentence seeds the next
+            sentence_ = sentence_batch_one.pop(-1)
+            sentence_batch_length = len(sentence_[0])
+            sentence_batch_list.append(sentence_batch_one)
+            sentence_batch_one = [sentence_]
+    sentence_batch_list.append(sentence_batch_one)
+    return sentence_batch_list
+
+
+def batch_predict(batch_data_list):
+    '''
+    Predict one batch of data.
+    @param batch_data_list:
+    @return:
+    '''
+    batch_data_list_new = []
+    batch_data_text_list = []
+    batch_data_sentence_id_list = []
+    for i in batch_data_list:
+        batch_data_text_list.append(i[0])
+        batch_data_sentence_id_list.append(i[1:])
+    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
+    batch_pre_data_list = batch_data_text_list
+    for text, sentence_id in zip(batch_pre_data_list, batch_data_sentence_id_list):
+        batch_data_list_new.append([text] + sentence_id)
+
+    return batch_data_list_new
+
+
+def one_predict(data_text):
+    '''
+    Predict a single item.
+    @param data_text:
+    @return:
+    '''
+    if data_text[0] != "":
+        data_inputs = data_text[0].replace("&", "。")
+        pre_data = autotitle.generate(data_inputs)
+    else:
+        pre_data = ""
+    data_new = [pre_data] + data_text[1:]
+    return data_new
+
+
+def predict_data_post_processing(text_list):
+    # Re-assemble fragments: short_num != 0 marks a continuation of the
+    # previous fragment, which is appended to the last sentence collected.
+    text_list_sentence = []
+
+    for i in range(len(text_list)):
+        if text_list[i][2] != 0:
+            text_list_sentence[-1][0] += text_list[i][0]
+        else:
+            text_list_sentence.append([text_list[i][0], text_list[i][1]])
+
+    return_list = {}
+    sentence_one = []
+    sentence_id = "0"
+    for i in text_list_sentence:
+        if i[1] == sentence_id:
+            sentence_one.append(i[0])
+        else:
+            return_list[sentence_id] = "。".join(sentence_one)
+            sentence_id = i[1]
+            sentence_one = []
+            sentence_one.append(i[0])
+    if sentence_one != []:
+        return_list[sentence_id] = "。".join(sentence_one)
+    return return_list
+
+
+# def main(text: list):
+#     text_list = paragraph_test(text)
+#     batch_data = batch_data_process(text_list)
+#     text_list = []
+#     for i in batch_data:
+#         text_list.extend(i)
+#     return_list = predict_data_post_processing(text_list)
+#     return return_list
+
+def main(text: dict):
+    text_list = paragraph_test(text)
+    text_list_new = []
+    for i in tqdm(text_list):
+        pre = one_predict(i)
+        text_list_new.append(pre)
+    return_list = predict_data_post_processing(text_list_new)
+    return return_list
+
+
+if __name__ == '__main__':
+
+    filename = './data/yy_data.json'
+    with open(filename, encoding="utf-8") as file_obj:
+        yy_data = json.load(file_obj)
+    rels = main(yy_data)
diff --git a/data_do/合并已经有的yy降重的数据.py b/data_do/合并已经有的yy降重的数据.py
new file mode 100644
index 0000000..5ca34a5
--- /dev/null
+++ b/data_do/合并已经有的yy降重的数据.py
@@ -0,0 +1,31 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2023/3/3 14:55
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+import pandas as pd
+import json
+
+
+# Merge the two existing "yy" rewriting spreadsheets into one JSON file.
+yy_data_1 = "../data/论文_yy_小说.xlsx"
+yy_data_2 = "../data/论文_yy_小说_1.xlsx"
+df_1 = pd.read_excel(yy_data_1).values.tolist()
+df_2 = pd.read_excel(yy_data_2).values.tolist()
+df = df_1 + df_2
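+
+# Assumption implied by the indexing below: the first column of each sheet
+# holds the raw text, and row order is the only key, so the merged rows are
+# re-keyed as "0", "1", ... in the output JSON.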
+
+return_data = {}
+for i in range(len(df)):
+    return_data[str(i)] = df[i][0]
+
+filename = 'yy_data.json'
+with open(filename, 'w', encoding="utf-8") as file_obj:
+    # ensure_ascii=False keeps the Chinese text readable in the dump
+    json.dump(return_data, file_obj, ensure_ascii=False)
diff --git a/flask_predict_no_batch_t5.py b/flask_predict_no_batch_t5.py
index add4459..a2a44fb 100644
--- a/flask_predict_no_batch_t5.py
+++ b/flask_predict_no_batch_t5.py
@@ -8,9 +8,21 @@ from flask import request
 import requests
 from flask import request
 from predict_t5 import autotitle
+import redis
+import uuid
+import json
+from threading import Thread
+import time
+import re
 
-import re
+pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50, db=1)
+# Note: decode_responses on the client is ignored when a connection_pool is
+# supplied, so values read back from redis below are bytes and need .decode().
+redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
+
+db_key_query = 'query'
+db_key_result = 'result'
+batch_size = 32
+
 app = Flask(__name__)
 app.config["JSON_AS_ASCII"] = False
@@ -257,35 +269,46 @@ def sentence():
         return_text = {"texts":"输入格式应该为list", "probabilities": None, "status_code":False}
         return jsonify(return_text)
 
+def classify():  # worker loop: call the model, one queued query at a time
+    while True:
+        if redis_.llen(db_key_query) == 0:  # nothing queued yet, keep polling
+            continue
+        query = redis_.lpop(db_key_query).decode('UTF-8')  # fetch one query
+        data_dict = json.loads(query)
+        query_id = data_dict['id']
+        texts = data_dict['text']
+        text_type = data_dict["text_type"]
+        if isinstance(texts, dict):
+            texts_list = []
+            if texts is None:
+                return_text = {"texts": "输入了空值", "probabilities": None, "status_code": 402}
+            else:
+                assert text_type in ['focus', 'chapter']
+                if text_type == 'focus':
+                    texts_list = main(texts)
+                if text_type == 'chapter':
+                    texts_list = main(texts)
+
+                return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
+        else:
+            return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": 401}
+        redis_.set(query_id, json.dumps(return_text, ensure_ascii=False))  # store the result under the query id
+
+
+@app.route("/predict", methods=["POST"])
+def handle_query():
+    print(request.remote_addr)
+    texts = request.json["texts"]
+    text_type = request.json["text_type"]
+    id_ = str(uuid.uuid1())  # unique id for this query
+    d = {'id': id_, 'text': texts, "text_type": text_type}  # bind the text to its query id
+    redis_.rpush(db_key_query, json.dumps(d, ensure_ascii=False))  # enqueue in redis
+    result_text = d
+    return jsonify(result_text)  # return the id at once; the result is fetched later via /search
 
-# @app.route('/chapter/', methods=['POST'])
-# def chapter():
-#     texts = request.json["texts"]
-#
-#     print("原始语句" + str(texts))
-#     # question = question.strip('。、!??')
-#
-#     if isinstance(texts, str):
-#         texts_list = []
-#         y_pred_label_list = []
-#         position_list = []
-#
-#         # texts = texts.replace('\'', '\"')
-#         if texts is None:
-#             return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False}
-#             return jsonify(return_text)
-#         else:
-#             texts = texts.split("\n")
-#             for text in texts:
-#                 text = text.strip()
-#                 return_str = autotitle.generate_random_shortest(text)
-#                 texts_list.append(return_str)
-#             texts_str = "\n".join(texts_list)
-#             return_text = {"texts": texts_str, "probabilities": None, "status_code": True}
-#     else:
-#         return_text = {"texts": "输入格式应该为str", "probabilities": None, "status_code": False}
-#     return jsonify(return_text)
+t = Thread(target=classify)
+t.start()
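+# Note: the thread starts at import time, so under gunicorn each worker
+# process runs its own classify() loop; gunicorn_config.py pins workers = 1,
+# which keeps this queue consumer a de-facto singleton.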
 
 if __name__ == "__main__":
     fh = logging.FileHandler(mode='a', encoding='utf-8', filename='chitchat.log')
diff --git a/flask_predict_redis.py b/flask_predict_redis.py
new file mode 100644
index 0000000..6b8e4e8
--- /dev/null
+++ b/flask_predict_redis.py
@@ -0,0 +1,76 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2023/3/2 19:31
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+#
+# import redis
+#
+# redis_pool = redis.ConnectionPool(host='127.0.0.1', port=6379, password='', db=0)
+# redis_conn = redis.Redis(connection_pool=redis_pool)
+#
+#
+# name_dict = {
+#     'name_4' : 'Zarten_4',
+#     'name_5' : 'Zarten_5'
+# }
+# redis_conn.mset(name_dict)
+
+import flask
+import redis
+import uuid
+import json
+from threading import Thread
+import time
+
+app = flask.Flask(__name__)
+pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50)
+redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
+
+db_key_query = 'query'
+db_key_result = 'result'
+batch_size = 32
+
+
+def classify():  # worker loop: pop queued queries and write results back
+    while True:
+        if redis_.llen(db_key_query) == 0:  # nothing queued yet, keep polling
+            continue
+        query = redis_.lpop(db_key_query).decode('UTF-8')  # fetch one query
+        data_dict = json.loads(query)
+        query_id = data_dict['id']
+        text = data_dict['text']
+        result = text + "1111111"  # stand-in for the real model call
+        time.sleep(5)
+        # for (id_, res) in zip(query_ids, result):
+        #     res['score'] = str(res['score'])
+        #     redis_.set(id_, json.dumps(res))  # push the model results back
+        # d = {"id": query_id, "text": result}
+        redis_.set(query_id, json.dumps(result))  # store the result under the query id
+
+@app.route("/predict", methods=["POST"])
+def handle_query():
+    text = flask.request.json['text']  # the user's text, e.g. "I love you"
+    id_ = str(uuid.uuid1())  # unique id for this query
+    print(id_)
+    d = {'id': id_, 'text': text}  # bind the text to its query id
+    redis_.rpush(db_key_query, json.dumps(d))  # enqueue in redis
+    # while True:
+    #     result = redis_.get(id_)  # poll for this query's model result
+    #     if result is not None:
+    #         redis_.delete(id_)
+    #         result_text = {'code': "200", 'data': result.decode('UTF-8')}
+    #         break
+    result_text = {'id': id_, 'text': text}
+    return flask.jsonify(result_text)  # return the id; results are looked up separately
+
+
+if __name__ == "__main__":
+    t = Thread(target=classify)
+    t.start()
+    app.run(debug=False, host='127.0.0.1', port=9000)
\ No newline at end of file
diff --git a/flask_predict_t5.py b/flask_predict_t5.py
new file mode 100644
index 0000000..286a1c9
--- /dev/null
+++ b/flask_predict_t5.py
@@ -0,0 +1,9 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2023/3/2 16:40
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
diff --git a/gunicorn_check_uuid_config.py b/gunicorn_check_uuid_config.py
new file mode 100644
index 0000000..35684ca
--- /dev/null
+++ b/gunicorn_check_uuid_config.py
@@ -0,0 +1,24 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2023/3/6 17:58
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+# number of parallel workers
+workers = 1
+# listen on port 14001 (change as needed)
+bind = '0.0.0.0:14001'
+
+worker_class = "gevent"
+# daemonize: keep the server running after the terminal closes
+daemon = True
+# request timeout in seconds (default is 30); adjust as needed
+timeout = 120
+# access and error log paths
+accesslog = './check_logs/access.log'
+errorlog = './check_logs/error.log'
+# access_log_format = '%(h) - %(t)s - %(u)s - %(s)s %(H)s'
+# errorlog = '-'  # log to stdout
\ No newline at end of file
diff --git a/gunicorn_config.py b/gunicorn_config.py
new file mode 100644
index 0000000..c4fe472
--- /dev/null
+++ b/gunicorn_config.py
@@ -0,0 +1,15 @@
+# number of parallel workers
+workers = 1
+# listen on port 14000 (change as needed)
+bind = '0.0.0.0:14000'
+
+worker_class = "gevent"
+# daemonize: keep the server running after the terminal closes
+daemon = True
+# request timeout in seconds (default is 30); adjust as needed
+timeout = 120
+# access and error log paths
+accesslog = './logs/access.log'
+errorlog = './logs/error.log'
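+# Used by run_app_flask.sh in this repo:
+#   gunicorn flask_predict_no_batch_t5:app -c gunicorn_config.py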
+# access_log_format = '%(h) - %(t)s - %(u)s - %(s)s %(H)s'
+# errorlog = '-'  # log to stdout
\ No newline at end of file
diff --git a/predict_t5.py b/predict_t5.py
index da7146d..dce0f94 100644
--- a/predict_t5.py
+++ b/predict_t5.py
@@ -11,7 +11,7 @@ import os
 # os.environ["TF_KERAS"] = "1"
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 import glob
 from numpy import random
 random.seed(1001)
diff --git a/python_to_redis.py b/python_to_redis.py
new file mode 100644
index 0000000..693f53e
--- /dev/null
+++ b/python_to_redis.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2023/3/2 19:31
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+#
+# import redis
+#
+# redis_pool = redis.ConnectionPool(host='127.0.0.1', port=6379, password='', db=0)
+# redis_conn = redis.Redis(connection_pool=redis_pool)
+#
+#
+# name_dict = {
+#     'name_4' : 'Zarten_4',
+#     'name_5' : 'Zarten_5'
+# }
+# redis_conn.mset(name_dict)
+
+import flask
+import redis
+import uuid
+import json
+from threading import Thread
+import time
+
+app = flask.Flask(__name__)
+pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50)
+redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
+
+db_key_query = 'query'
+db_key_result = 'result'
+batch_size = 32
+
+
+def classify(batch_size):  # worker loop: call the model with a capped batch size
+    while True:
+        texts = []
+        query_ids = []
+        if redis_.llen(db_key_query) == 0:  # nothing queued yet, keep polling
+            continue
+        for i in range(min(redis_.llen(db_key_query), batch_size)):
+            query = redis_.lpop(db_key_query).decode('UTF-8')  # fetch one query
+            query_ids.append(json.loads(query)['id'])
+            texts.append(json.loads(query)['text'])  # collect texts into a batch
+        # Placeholder for the real model call: each text is wrapped in a dict
+        # so that res['score'] below is valid and JSON-serializable.
+        result = [{'text': t, 'score': 1.0} for t in texts]
+        time.sleep(20)
+        for (id_, res) in zip(query_ids, result):
+            res['score'] = str(res['score'])
+            redis_.set(id_, json.dumps(res))  # push the model result back
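+
+# Design note: unlike flask_predict_redis.py, handle_query below blocks and
+# polls redis_.get(id_) until classify() has written a result, so the HTTP
+# response already carries the model output and the uuid never reaches the
+# caller; the uuid-lookup flow lives in flask_predict_no_batch_t5.py plus
+# redis_check_uuid.py instead.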
+
+
+@app.route("/predict", methods=["POST"])
+def handle_query():
+    text = flask.request.json['text']  # the user's text, e.g. "I love you"
+    id_ = str(uuid.uuid1())  # unique id for this query
+    d = {'id': id_, 'text': text}  # bind the text to its query id
+    redis_.rpush(db_key_query, json.dumps(d))  # enqueue in redis
+    while True:
+        result = redis_.get(id_)  # poll for this query's model result
+        if result is not None:
+            redis_.delete(id_)
+            result_text = {'code': "200", 'data': result.decode('UTF-8')}
+            break
+    return flask.jsonify(result_text)  # return the result
+
+
+if __name__ == "__main__":
+    t = Thread(target=classify, args=(batch_size,))
+    t.start()
+    app.run(debug=False, host='127.0.0.1', port=9000)
\ No newline at end of file
diff --git a/redis_check_uuid.py b/redis_check_uuid.py
new file mode 100644
index 0000000..0d31c8d
--- /dev/null
+++ b/redis_check_uuid.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+
+"""
+@Time    : 2023/3/2 19:31
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+#
+# import redis
+#
+# redis_pool = redis.ConnectionPool(host='127.0.0.1', port=6379, password='', db=0)
+# redis_conn = redis.Redis(connection_pool=redis_pool)
+#
+#
+# name_dict = {
+#     'name_4' : 'Zarten_4',
+#     'name_5' : 'Zarten_5'
+# }
+# redis_conn.mset(name_dict)
+
+import flask
+import redis
+import uuid
+import json
+from threading import Thread
+import time
+
+app = flask.Flask(__name__)
+pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50, db=1)
+redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
+
+
+@app.route("/search", methods=["POST"])
+def handle_query():
+    id_ = flask.request.json['id']  # the query id issued by /predict
+    result = redis_.get(id_)  # fetch this query's model result
+    if result is not None:
+        redis_.delete(id_)
+        result_dict = result.decode('UTF-8')
+        result_dict = json.loads(result_dict)
+        texts = result_dict["texts"]
+        probabilities = result_dict["probabilities"]
+        status_code = result_dict["status_code"]
+        result_text = {'code': status_code, 'text': texts, 'probabilities': probabilities}
+    else:
+        result_text = {'code': "201", 'text': "", 'probabilities': None}
+    return flask.jsonify(result_text)  # return the result
+
+
+if __name__ == "__main__":
+    app.run(debug=False, host='0.0.0.0', port=14001)
\ No newline at end of file
diff --git a/request_drop.py b/request_drop.py
index 99f158e..240e385 100644
--- a/request_drop.py
+++ b/request_drop.py
@@ -40,24 +40,39 @@ def dialog_line_parse(url, text):
     return []
 
-ceshi_1 = [
-    "李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会。关于这件事,校长特别交代过了,我也非常认同。你这是见义勇为,是勇斗歹徒、义救同学的英雄,我们清江一中决不让英雄流血又流泪!”。",
-    "李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会。关于这件事,校长特别交代过了,我也非常认同。",
-    "李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会。关于这件事,校长特别交代过了,我也非常认同。"
-    "我" * 110
-    ]
-
-ceshi_2 = [
-    "李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会。关于这件事,校长特别交代过了,我也非常认同。你这是见义勇为,是勇斗歹徒、义救同学的英雄,我们清江一中决不让英雄流血又流泪!”。"
-    ]
-
-jishu = 0
-for i in ceshi_1:
-    for j in i:
-        jishu += 1
-print(jishu)
+# ceshi_1 = [
+#     "李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会。关于这件事,校长特别交代过了,我也非常认同。你这是见义勇为,是勇斗歹徒、义救同学的英雄,我们清江一中决不让英雄流血又流泪!”。",
+#     "李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会。关于这件事,校长特别交代过了,我也非常认同。",
+#     "李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会。关于这件事,校长特别交代过了,我也非常认同。"
+#     "我" * 110
+#     ]
+#
+# ceshi_2 = [
+#     "李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会。关于这件事,校长特别交代过了,我也非常认同。你这是见义勇为,是勇斗歹徒,你这是见义勇为,你这是见义勇为,你这是见义勇为,你这是见义勇为,你这是见义勇为,你这是见义勇为,你这是见义勇为,你这是见义勇为,、义救同学的英雄,我们清江一中决不让英雄流血又流泪。”"
+#     ]
+#
+# ceshi_3 = [
+#     "方中乌药疏肝行气、散寒止痛为君药"
+#     ]
+#
+# jishu = 0
+# for i in ceshi_1:
+#     for j in i:
+#         jishu += 1
+# print(jishu)
+#
+# t1 = time()
+# print(dialog_line_parse("http://192.168.31.116:14000/droprepeat/", {"texts": ceshi_2, "text_type": "focus"}))
+# t2 = time()
+# print(t2 - t1)
+
+ceshi_2 = {
+    "0": "李正旺你真是傻逼讪笑,挥手道",
+    "1": "李正旺你真是傻逼讪笑,挥手道",
+    "2": "李正旺你真是傻逼讪笑,挥手道"
+}
 
 t1 = time()
-print(dialog_line_parse("http://114.116.25.228:14000/droprepeat/",{"texts": ceshi_1, "text_type": "focus"}))
+print(dialog_line_parse("http://114.116.25.228:14000/predict", {"texts": ceshi_2, "text_type": "chapter"}))
 t2 = time()
 print(t2 - t1)
\ No newline at end of file
diff --git a/run_app_flask.sh b/run_app_flask.sh
new file mode 100644
index 0000000..ee35f0a
--- /dev/null
+++ b/run_app_flask.sh
@@ -0,0 +1 @@
+gunicorn flask_predict_no_batch_t5:app -c gunicorn_config.py
\ No newline at end of file
diff --git a/run_check_uuid_app.sh b/run_check_uuid_app.sh
new file mode 100644
index 0000000..20d5689
--- /dev/null
+++ b/run_check_uuid_app.sh
@@ -0,0 +1 @@
+gunicorn redis_check_uuid:app -c gunicorn_check_uuid_config.py
\ No newline at end of file
diff --git a/task_seq2seq_t5.py b/task_seq2seq_t5.py
index d9b7bf6..3a8e084 100644
--- a/task_seq2seq_t5.py
+++ b/task_seq2seq_t5.py
@@ -49,7 +49,7 @@ spm_path = 'mt5/mt5_base/sentencepiece_cn.model'
 keep_tokens_path = 'mt5/mt5_base/sentencepiece_cn_keep_tokens.json'
 
-file = "data/train_yy.txt"
+file = "data/train_yy_zong_sim_99.txt"
 try:
     with open(file, 'r', encoding="utf-8") as f:
         lines = [x.strip() for x in f if x.strip() != '']
@@ -205,7 +205,7 @@ class Evaluator(keras.callbacks.Callback):
         # keep the best checkpoint
         if logs['loss'] <= self.lowest:
             self.lowest = logs['loss']
-            model.save_weights('./output_t5/best_model_t5_dropout_0_3.weights')
+            model.save_weights('./output_t5/best_model_t5_zong_sim_99.weights')
 
         # show demo output
         just_show()