diff --git a/.gitignore b/.gitignore
index 1f593de..d4c02f6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,142 +1,4 @@
-# ---> Python
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-# For a library or package, you might want to ignore these files since the code is
-# intended to run in multiple environments; otherwise, check them in:
-# .python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-/chinese_roberta_wwm_ext_L-12_H-768_A-12/
 /data/
+/bert_base_script_fintune_tf/
+/chinese_roberta_wwm_ext_L-12_H-768_A-12/
+/output/
diff --git a/README.md b/README.md
index 14e372d..107260f 100644
--- a/README.md
+++ b/README.md
@@ -1,31 +1,29 @@
-# Rewriting project
-    A generation task based on the unilm and t5 models, using the keras framework; the data-processing scripts are in the data_do folder
-    Training data: train_yy.txt
+# Novel rewriting project
+
+A generation task based on the unilm model, using the keras framework; the data-processing scripts are in the data_do folder
+Training data: train_cat_data_4.txt
 
 ## Training
-    Train t5: python task_seq2seq_t5.py
-    Train simbert: python simbert_train.py
+    Training with quality checking added: bash train.sh
+    Training with quality checking added: bash train_sim.sh
 
 ## Prediction
-    simbert: python predict_sim.py
-    t5: python predict_t5.py
+
+    With quality checking: python predict_tf_sim.py
+    Without quality checking: python predict_tf.py
 
 ## API serve
-    Start the service that returns a uuid for each submitted sentence: bash run_app_nohub_t5.sh
-    Start the service that looks up rewriting results by uuid: bash run_app_nohub_search_redis.sh
-## Request/response examples
-    Request a sentence uuid: https://console-docs.apipost.cn/preview/e3717e390cbdb50e/f4479038c8015f34
-    Request a rewriting result: https://console-docs.apipost.cn/preview/6b9de12817e8ef08/b158334d2c9534d2
+    Current way to start: bash run_app.sh
+    One-command start: bash run_app_gunicorn.sh
 
-## Generating training data from the yy data
-    python data_do/处理yy数据原始数据.py
-    python data_do/进一步处理降重数据.py
-    python data_do/yy训练数据处理.py
-    python 筛选训练数据strsim.py
-    python 合并数据.py
+## Request example
+    requests.post(
+        "http://192.168.1.17:14000",
+        json={"texts": ["张三要爬上高位的,才能够翻云覆雨。"]},
+        timeout=1000
+    )
 
-## Testing the 11 articles
-## Checking the test data for bugs
-    python 测试10000篇数据.py
\ No newline at end of file
+## Response
+    {'probabilities': None, 'texts': ['张三要上了巅峰,他就可以为所欲为了。']}
\ No newline at end of file
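Note: the request/response pair in the new README is a plain JSON-over-HTTP call. A minimal client sketch — the host, port, payload shape, and response shape are taken from the README itself; the helper name `rewrite` and the error handling are illustrative assumptions, not code from this repo:

    # Hypothetical client for the rewriting service described in the README.
    # Assumes it accepts {"texts": [...]} and answers {"probabilities": ..., "texts": [...]}.
    import requests

    def rewrite(sentences, url="http://192.168.1.17:14000", timeout=1000):
        resp = requests.post(url, json={"texts": sentences}, timeout=timeout)
        resp.raise_for_status()  # surface HTTP errors instead of failing silently
        return resp.json()["texts"]

    print(rewrite(["张三要爬上高位的,才能够翻云覆雨。"]))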
diff --git a/config/predict_sim_config.py b/config/predict_sim_config.py
index d757455..7707ee6 100644
--- a/config/predict_sim_config.py
+++ b/config/predict_sim_config.py
@@ -12,7 +12,7 @@ import os
 pre_model_path = {
     "simbert": {
-        "linux": "/home/zc-nlp-zyp/work_file/ssd_data/模型库/预训练模型集合/keras/chinese_roberta_wwm_ext_L-12_H-768_A-12",
+        "linux": "/home/majiahui/project/drop_weight_rewrite/chinese_roberta_wwm_ext_L-12_H-768_A-12",
         "win32": r"E:\pycharm_workspace\premodel\keras\chinese_roberta_wwm_ext_L-12_H-768_A-12"
     },
 }
@@ -29,4 +29,4 @@ class DropSimBertConfig:
         self.savemodel_path = "./output_simbert_yy/best_simbertmodel_dropout_datasim_yinhao.weights"
         self.maxlen = 120
-        self.cuda_id = "1"
\ No newline at end of file
+        self.cuda_id = "0"
\ No newline at end of file
diff --git a/config/predict_t5_config.py b/config/predict_t5_config.py
index eff5338..445f442 100644
--- a/config/predict_t5_config.py
+++ b/config/predict_t5_config.py
@@ -12,8 +12,7 @@ import os
 pre_model_path = {
     "t5": {
-        # "linux": "/home/zc-nlp-zyp/work_file/ssd_data/模型库/预训练模型集合/keras/mt5/mt5_base",
-        "linux": "mt5/mt5_base",
+        "linux": "/home/majiahui/project/drop_weight_rewrite/mt5/mt5_base",
         "win32": r"E:\pycharm_workspace\premodel\keras\mt5\mt5_base"
     },
 
@@ -28,7 +27,7 @@ class DropT5Config:
         self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000')
         self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model')
         self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json')
-        self.savemodel_path = "./output_t5/best_model_t5_0724.weights"
+        self.savemodel_path = "./output_t5/best_model_t5_dropout_0_3.weights"
         self.maxlen = 256
         self.cuda_id = "0"
@@ -41,6 +40,6 @@ class MultipleResultsDropT5Config:
         self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000')
         self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model')
         self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json')
-        self.savemodel_path = "./output_t5/best_model_t5_0724.weights"
+        self.savemodel_path = "./output_t5/best_model_t5_dropout_0_3.weights"
         self.maxlen = 256
-        self.cuda_id = "0"
\ No newline at end of file
+        self.cuda_id = "0"
diff --git a/crontab_sh.sh b/crontab_sh.sh
index 7760044..d31a443 100644
--- a/crontab_sh.sh
+++ b/crontab_sh.sh
@@ -1,3 +1,2 @@
-rm -rf /home/majiahui/drop_weight_rewrite/request_data_logs/*
 rm -rf /home/majiahui/drop_weight_rewrite/old_data_logs/*
 mv /home/majiahui/drop_weight_rewrite/new_data_logs/* /home/majiahui/drop_weight_rewrite/old_data_logs/*
\ No newline at end of file
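Note: both config files key `pre_model_path` by platform name. A sketch of how such a table is presumably consumed — using `sys.platform` (which returns "linux" / "win32") as the key is an assumption; the lookup itself is not shown in this diff:

    # Hypothetical lookup: pick the pretrained-model directory for the current OS.
    import sys

    pre_model_path = {
        "t5": {
            "linux": "/home/majiahui/project/drop_weight_rewrite/mt5/mt5_base",
            "win32": r"E:\pycharm_workspace\premodel\keras\mt5\mt5_base",
        },
    }
    premodel_path = pre_model_path["t5"][sys.platform]  # KeyError on unlisted platforms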
"./output_t5/best_model_t5_0724.weights" + self.savemodel_path = "./output_t5/best_model_t5_dropout_0_3.weights" self.maxlen = 256 self.cuda_id = "0" @@ -41,6 +40,6 @@ class MultipleResultsDropT5Config: self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000') self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model') self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json') - self.savemodel_path = "./output_t5/best_model_t5_0724.weights" + self.savemodel_path = "./output_t5/best_model_t5_dropout_0_3.weights" self.maxlen = 256 - self.cuda_id = "0" \ No newline at end of file + self.cuda_id = "0" diff --git a/crontab_sh.sh b/crontab_sh.sh index 7760044..d31a443 100644 --- a/crontab_sh.sh +++ b/crontab_sh.sh @@ -1,3 +1,2 @@ -rm -rf /home/majiahui/drop_weight_rewrite/request_data_logs/* rm -rf /home/majiahui/drop_weight_rewrite/old_data_logs/* mv /home/majiahui/drop_weight_rewrite/new_data_logs/* /home/majiahui/drop_weight_rewrite/old_data_logs/* \ No newline at end of file diff --git a/evaluate_test.py b/evaluate_test.py index 3642347..376adad 100644 --- a/evaluate_test.py +++ b/evaluate_test.py @@ -1,8 +1,10 @@ +# -*- coding: utf-8 -*- + """ @Time : 2022/8/15 15:20 -@Author : -@FileName: -@Software: +@Author : +@FileName: +@Software: @Describe: """ import json @@ -19,13 +21,11 @@ from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder from keras.models import Model from rouge import Rouge # pip install rouge from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction -import difflib class Evaluator(keras.callbacks.Callback): """评估与保存 """ - def __init__(self): self.rouge = Rouge() self.smooth = SmoothingFunction().method1 @@ -69,79 +69,52 @@ class Evaluator(keras.callbacks.Callback): } def evaluate_t(self, data_1, data_2, topk=1): - - data_1_eval = ' '.join(data_1) - data_2_eval = ' '.join(data_2) total = 0 rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0 - scores = self.rouge.get_scores(hyps=[data_1_eval], refs=[data_2_eval]) + scores = self.rouge.get_scores(hyps=[data_1], refs=[data_2]) rouge_1 += scores[0]['rouge-1']['f'] rouge_2 += scores[0]['rouge-2']['f'] rouge_l += scores[0]['rouge-l']['f'] bleu += sentence_bleu( - references=[data_1_eval.split(' ')], - hypothesis=data_2_eval.split(' '), + references=[data_1.split(' ')], + hypothesis=data_2.split(' '), smoothing_function=self.smooth ) # rouge_1 /= total # rouge_2 /= total # rouge_l /= total # bleu /= total - str_sim = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio() - return [rouge_1, rouge_2, rouge_l, bleu, str_sim] + return [rouge_1, rouge_2, rouge_l, bleu] eval_class = Evaluator() # print(eval_class.evaluate_t("星 辰 的 话","星 辰 的 话 :")) -path = "data/700条效果对比.xlsx" -path_out = "data/700条效果对比测评结果_14.csv" +path = "data/一万字小说测试效果.xlsx" +path_out = "data/一万字小说测试效果测评.csv" data = pd.read_excel(path).values.tolist() +data_new = {"rouge_1": [0,0,0,0,0,0,0,0,0,0,0], + "rouge_2": [0,0,0,0,0,0,0,0,0,0,0], + "rouge_l": [0,0,0,0,0,0,0,0,0,0,0], + "bleu": [0,0,0,0,0,0,0,0,0,0,0]} +total = 0 -list_class = [0 for i in range(13)] -# print(list_class) -data_new = {"rouge_1": list_class.copy(), - "rouge_2": list_class.copy(), - "rouge_l": list_class.copy(), - "bleu": list_class.copy(), - "str_sim": list_class.copy()} -total = len(data) -print(len(data)) for i in data: - dan_list = [i[1], i[2], i[3], i[4], i[5], i[6], i[7], i[8], i[9], i[10], i[11], i[12], i[-1]] + dan_list = i[2:-1] for j in range(len(dan_list)): - eval_list = 
diff --git a/flask_drop_rewrite_request.py b/flask_drop_rewrite_request.py
index 1c058bc..9d8f93b 100644
--- a/flask_drop_rewrite_request.py
+++ b/flask_drop_rewrite_request.py
@@ -62,8 +62,8 @@ def get_host_ip():
     return ip
 
 
-chatgpt_url_predict = "http://{}:12006/predict".format(str(get_host_ip()))
-chatgpt_url_search = "http://{}:12006/search".format(str(get_host_ip()))
+chatgpt_url_predict = "http://{}:12003/predict".format(str(get_host_ip()))
+chatgpt_url_search = "http://{}:12003/search".format(str(get_host_ip()))
 
 
 def smtp_f(name):
diff --git a/flask_predict.py b/flask_predict.py
index 77dab49..8fba2d0 100644
--- a/flask_predict.py
+++ b/flask_predict.py
@@ -116,7 +116,7 @@ def batch_data_process(text_list):
     for sentence in text_list:
         sentence_batch_length += len(sentence[0])
         sentence_batch_one.append(sentence)
-        if sentence_batch_length > 10000:
+        if sentence_batch_length > 1000:
             sentence_batch_length = 0
             sentence_ = sentence_batch_one.pop(-1)
             sentence_batch_list.append(sentence_batch_one)
diff --git a/predict_drop_sim_sim.py b/predict_drop_sim_sim.py
index 2465548..18b52ef 100644
--- a/predict_drop_sim_sim.py
+++ b/predict_drop_sim_sim.py
@@ -713,7 +713,7 @@ if __name__ == '__main__':
     df_list_new = []
     print(len(df_list))
     for i in tqdm(df_list):
-        pre = just_show_sentence([i[1]])
+        pre = just_show_sentence([i[0]])
 
         df_list_new.append([i[0] , pre])
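Note: the flask_predict.py change lowers the per-batch character budget from 10000 to 1000. The surrounding function batches greedily: it appends sentences until the running length crosses the budget, then moves the sentence that crossed the line into the next batch. A simplified standalone sketch — names are mine, it operates on plain strings rather than the repo's (sentence, ...) tuples, and it carries the overflow length forward where the original resets its counter to 0:

    def batch_by_chars(sentences, budget=1000):
        batches, current, count = [], [], 0
        for s in sentences:
            count += len(s)
            current.append(s)
            if count > budget:
                overflow = current.pop()  # the sentence that crossed the budget
                if current:
                    batches.append(current)
                current, count = [overflow], len(overflow)
        if current:
            batches.append(current)  # flush the final partial batch
        return batches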
diff --git a/predict_drop_weight_sim.py b/predict_drop_weight_sim.py
index e5c5b06..155fcdb 100644
--- a/predict_drop_weight_sim.py
+++ b/predict_drop_weight_sim.py
@@ -85,7 +85,6 @@ class GenerateModel(object):
         self.checkpoint_path = r'./chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
         self.dict_path = r'./chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'
         self.maxlen = 120
-        self.novel_maxlen = 60
 
     def device_setup(self):
         token_dict, keep_tokens = load_vocab(
@@ -122,7 +121,7 @@ class GenerateModel(object):
 
         outputs = TotalLoss([2, 3])(bert.model.inputs + bert.model.outputs)
         model = keras.models.Model(bert.model.inputs, outputs)
-        path_model = './output_simbert_yy/best_simbertmodel.weights'
+        path_model = './output_simbert/best_simbertmodel.weights'
         model.load_weights(path_model)
 
         return encoder,seq2seq, tokenizer
@@ -535,6 +534,7 @@ class AutoTitle(AutoRegressiveDecoder):
     #     return self.last_token(self.model).predict([token_ids, segment_ids])
 
     def generate(self, text, topk=3):
+        text = text[0]
         token_ids, segment_ids = self.tokenizer.encode(text, maxlen=256)
         output_ids = self.beam_search([token_ids, segment_ids], topk=topk)  # based on beam search
@@ -649,12 +649,12 @@ def just_show(file):
     # pd.DataFrame(data,columns=["原始文本","生成文本"]).to_csv("data/text_测试一万字_unilm_修正数据_小说预训练_全部数据_epoch72_反向训练.csv")
 
 
-def just_show_sentence(file: list) -> object:
+def just_show_sentence(file: object) -> object:
     """
     @param file:list
     """
-    text = file[0]
-    pre = autotitle.generate(text)
+
+    pre = autotitle.generate(file)
     return pre
 
     # pre = autotitle.gen_synonyms(file)
@@ -701,95 +701,20 @@ def just_show_csv_beam(file):
     pd.DataFrame(data_new).to_csv("data/###第3章 非常尴尬_sim_topK_1.csv")
 
 
-def chulichangju_1(text, chulipangban_return_list):
-    fuhao = [",",",","?","!","…"]
-    text_1 = text[:60]
-    text_2 = text[60:]
-    text_1_new = ""
-    for i in range(len(text_1)-1, -1, -1):
-        if text_1[i] in fuhao:
-            text_1_new = text_1[:i]
-            text_1_new += text_1[i]
-            if len(text_1_new) > 10:
-                text_1_new_pre = autotitle.generate(text_1_new)
-            else:
-                text_1_new_pre = text_1_new
-            chulipangban_return_list.append(text_1_new_pre)
-            if text_2 != "":
-                if i+1 != 60:
-                    text_2 = text_1[i+1:] + text_2
-            break
-        # else:
-        #     chulipangban_return_list.append(text_1)
-    if text_1_new == "":
-        if len(text_1) > 10:
-            text_1_new_pre = autotitle.gen_synonyms_short(text_1)
-        else:
-            text_1_new_pre = text_1
-        chulipangban_return_list.append(text_1_new_pre)
-        if text_2 != "":
-            chulipangban_return_list = chulichangju_1(text_2, chulipangban_return_list)
-    return chulipangban_return_list
-
-
-def chulipangban_test_1(text):
-    sentence_list = text.split("。")
-    sentence_list_new = []
-    for i in sentence_list:
-        if i != "":
-            sentence_list_new.append(i)
-    sentence_list = sentence_list_new
-    return_list = []
-    for sentence in sentence_list:
-        if len(sentence) < 60:
-            if len(sentence) > 10:
-                sentence_pre = autotitle.generate(sentence)
-            else:
-                sentence_pre = sentence
-            return_list.append(sentence_pre)
-        else:
-            sentence_split_list = chulichangju_1(sentence,[])
-            sentence_split_text = "".join(sentence_split_list)
-            return_list.append(sentence_split_text)
-    return return_list
-
-
-def paragraph_test(text, text_new):
+if __name__ == '__main__':
 
-    text = chulipangban_test_1(text)
-    text = "。".join(text)
-    text_new.append(text)
+    # text = ["历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择,是时代对于青年们的深切呼唤"]
+    # print(just_show_sentence(text))
 
-    # text_new_str = "".join(text_new)
-    return text_new
+    path = "./data/700条论文测试.xlsx"
+    df_list = pd.read_excel(path).values.tolist()
 
+    df_list_new = []
+    print(len(df_list))
 
-if __name__ == '__main__':
-    text = ["所以对学生对应用仪器分析解决实际问题的能力要求很高。","随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]
-    print(just_show_sentence(text))
+    for i in tqdm(df_list):
+        pre = just_show_sentence([i[0]])
 
-    # is_novel = False
-    # path = "./data/700条论文测试.xlsx"
-    # df_list = pd.read_excel(path).values.tolist()
-    #
-    # if is_novel == False:
-    #     df_list_new = []
-    #     print(len(df_list))
-    #     for i in tqdm(df_list[:50]):
-    #         pre = just_show_sentence([i[0]])
-    #
-    #         df_list_new.append([i[0], i[1], pre])
-    #
+        df_list_new.append([i[0], i[1], pre])
 
-    #     df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "模型"])
-    #     df.to_excel("./data/700条论文测试_14.xlsx", index=None)
-    # else:
-    #     df_list_new = []
-    #     print(len(df_list))
-    #     for i in tqdm(df_list):
-    #         text_list = paragraph_test(i[0], [])
-    #         pre = "".join(text_list)
-    #         pre += "。"
-    #         df_list_new.append([i[0], i[1], pre])
-    #
-    #     df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "dropout_sim_模型"])
-    #     df.to_excel("./data/700条论文测试_5.xlsx", index=None)
\ No newline at end of file
+    df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "小说模型"])
+    df.to_excel("./data/700条论文测试_3.xlsx", index=None)
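Note: with the added `text = text[0]`, `AutoTitle.generate()` now unwraps its argument itself, and `just_show_sentence()` passes the list straight through. Callers must therefore wrap each input in a one-element list; a bare string would be "unwrapped" to its first character. A usage sketch (assumes the module's `just_show_sentence` is in scope):

    pre = just_show_sentence(["随着经济的发展,人们生活水平的提高,环境问题也日益突出。"])  # ok: list-wrapped
    # just_show_sentence("随着经济的发展...")  # wrong: would generate from the single character "随"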
diff --git a/predict_t5.py b/predict_t5.py
index 78a0ced..a326d33 100644
--- a/predict_t5.py
+++ b/predict_t5.py
@@ -8,11 +8,7 @@
 @Describe:
 """
 #! -*- coding: utf-8 -*-
-import os
 
-from config.predict_t5_config import DropT5Config
-config = DropT5Config()
-os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id
 import glob
 from numpy import random
 random.seed(1001)
@@ -414,10 +410,11 @@ class AutoTitle(AutoRegressiveDecoder):
         return output_str
 
 
-def just_show_sentence(text):
+def just_show_sentence(file):
     """
-    @param text:list
+    @param file:list
     """
+    text = file[0]
     pre = autotitle.generate(text)
     return pre
 
@@ -429,7 +426,10 @@ def just_show_sentence_batch(file: list) -> object:
 
 
 if __name__ == '__main__':
-
+    import os
+    from config.predict_t5_config import DropT5Config
+    config = DropT5Config()
+    os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id
     generatemodel = GenerateModel(config.config_path,
                                   config.checkpoint_path,
                                   config.spm_path,
@@ -462,32 +462,31 @@ if __name__ == '__main__':
 
     # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 
-    import os
-
-    file = "./data/11篇汇总txt_new.txt"
-    file_t5 = "./data/11篇汇总txt_new_predict_t5.txt"
-    file_t5_0724 = "./data/11篇汇总txt_new_predict_t5_0724.txt"
-
-    try:
-        with open(file, 'r', encoding="utf-8") as f:
-            lines = [x.strip() for x in f if x.strip() != '']
-    except:
-        with open(file, 'r', encoding="gbk") as f:
-            lines = [x.strip() for x in f if x.strip() != '']
-
-    zishu = 0
-    data = []
-    for i in tqdm(lines):
-
-        zishu += len(i)
-        pre = just_show_sentence([i])
-        data.append([i, pre])
-
-    with open(file_t5_0724, "w", encoding='utf-8') as file:
-        for i in data:
-            file.write("\t".join(i) + '\n')
-        file.close()
-    print(zishu)
+    # import os
+    #
+    # file = "./data/11篇汇总txt_new.txt"
+    # file_t5 = "./data/11篇汇总txt_new_predict_t5.txt"
+    #
+    # try:
+    #     with open(file, 'r', encoding="utf-8") as f:
+    #         lines = [x.strip() for x in f if x.strip() != '']
+    # except:
+    #     with open(file, 'r', encoding="gbk") as f:
+    #         lines = [x.strip() for x in f if x.strip() != '']
+    #
+    # zishu = 0
+    # data = []
+    # for i in tqdm(lines):
+    #
+    #     zishu += len(i)
+    #     pre = just_show_sentence([i])
+    #     data.append([i, pre])
+    #
+    # with open(file_t5, "w", encoding='utf-8') as file:
+    #     for i in data:
+    #         file.write("\t".join(i) + '\n')
+    #     file.close()
+    # print(zishu)
 
     #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 
@@ -497,13 +496,11 @@ if __name__ == '__main__':
     #         "强调轻资产经营, 更加重视经营风险的规避",
     #         "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择",
    #         "是时代对于青年们的深切呼唤"]
-    # text = ["随着经济的发展,人们生活水平的提高,环境:问题也日益突出。",
-    #         "环境问题中的化学污染是影响我国居民生活质量不可忽视的重要因素,而仪器分析作为化工专业课程中必不可少的一门课程也不例外。",
-    #         "所以对学生对应用仪器分析解决实际问题的能力要求很高。",
-    #         "随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]
-    #
-    # for i in text:
-    #     print(just_show_sentence(i))
+    text = ["随着经济的发展,人们生活水平的提高,环境:问题也日益突出。",
+            "环境问题中的化学污染是影响我国居民生活质量不可忽视的重要因素,而仪器分析作为化工专业课程中必不可少的一门课程也不例外。",
+            "所以对学生对应用仪器分析解决实际问题的能力要求很高。",
+            "随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]
+    print(just_show_sentence(text))
 
     # print(just_show_sentence_top(text))
     # print(just_show_chachong_random(text))
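Note: predict_t5.py now reads the config and sets CUDA_VISIBLE_DEVICES inside `__main__`. That only takes effect if it runs before TensorFlow initializes the GPU, so the assignment must stay ahead of `GenerateModel(...)` and any session creation. A minimal sketch of the required ordering, assuming the repo's config module:

    import os
    from config.predict_t5_config import DropT5Config

    config = DropT5Config()
    os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id  # must precede any TF GPU work
    # ...only then build the model / start the session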
diff --git a/predict_tf_sim.py b/predict_tf_sim.py
index 892dabe..a56c279 100644
--- a/predict_tf_sim.py
+++ b/predict_tf_sim.py
@@ -703,7 +703,7 @@ def just_show_csv_beam(file):
 if __name__ == '__main__':
     # file = "train_2842.txt"
     # just_show(file)
-    text = ["随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]
+    text = ["历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择,是时代对于青年们的深切呼唤"]
     just_show_sentence(text)
     # "简言之,她不好过,李四也别想好过!"
     # s = "张三的对话"
diff --git a/redis_check_uuid.py b/redis_check_uuid.py
index e5d3141..1237be5 100644
--- a/redis_check_uuid.py
+++ b/redis_check_uuid.py
@@ -84,4 +84,4 @@ def handle_query():
 
 
 if __name__ == "__main__":
-    app.run(debug=False, host='0.0.0.0', port=14001)
\ No newline at end of file
+    app.run(debug=False, host='0.0.0.0', port=14001)
diff --git a/run_app_nohub_search_redis.sh b/run_app_nohub_search_redis.sh
index 84b0dcb..349d74c 100644
--- a/run_app_nohub_search_redis.sh
+++ b/run_app_nohub_search_redis.sh
@@ -1 +1 @@
-nohup python redis_check_uuid_mistral.py > myout.redis_check_uuid_mistral.logs 2>&1 &
+nohup python redis_check_uuid.py > myout.redis_check_uuid.logs 2>&1 &
\ No newline at end of file
diff --git a/task_seq2seq_autotitle.py b/task_seq2seq_autotitle.py
index 93e8060..fdee89a 100644
--- a/task_seq2seq_autotitle.py
+++ b/task_seq2seq_autotitle.py
@@ -20,7 +20,6 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 
 import os
 
-# os.environ["TF_KERAS"] = "1"
 os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 from keras.backend.tensorflow_backend import set_session
 config = tf.ConfigProto()
@@ -30,17 +29,19 @@ set_session(tf.Session(config=config))  # this part differs
 
 # basic parameters
 maxlen = 256
-batch_size = 8
+batch_size = 32
 steps_per_epoch = 20000
 epochs = 10000
 
 # bert configuration
-config_path = 'bert_config_dropout_0_3.json'
-checkpoint_path = './chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
-dict_path = './chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'
+config_path = './bert_base_script_fintune_tf/config.json'
+checkpoint_path = './bert_base_script_fintune_tf/bert_base_script_fintune_tf.ckpt'
+dict_path = './bert_base_script_fintune_tf/vocab.txt'
 
+# # Training samples: THUCNews dataset, one txt file per sample.
+# txts = glob.glob('/root/thuctc/THUCNews/*/*.txt')
 
-file = "data/train_yy.txt"
+file = "data/train_cat_data_4.txt"
 try:
     with open(file, 'r', encoding="utf-8") as f:
         lines = [x.strip() for x in f if x.strip() != '']
diff --git a/task_seq2seq_t5.py b/task_seq2seq_t5.py
index 450fae9..3a8e084 100644
--- a/task_seq2seq_t5.py
+++ b/task_seq2seq_t5.py
@@ -16,7 +16,7 @@
 # added evaluation metrics: bleu, rouge-1, rouge-2, rouge-l
 import os
 # os.environ["TF_KERAS"] = "1"
-os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 import json
 import numpy as np
 from tqdm import tqdm
@@ -40,7 +40,7 @@ for gpu in gpus:
 max_c_len = 128
 max_t_len = 128
 batch_size = 28
-epochs = 10
+epochs = 10000
 
 # model paths
 config_path = 'mt5/mt5_base_dropout_0_3_config.json'
@@ -49,7 +49,7 @@ spm_path = 'mt5/mt5_base/sentencepiece_cn.model'
 keep_tokens_path = 'mt5/mt5_base/sentencepiece_cn_keep_tokens.json'
 
 
-file = "data/train_new/train_yy.txt"
+file = "data/train_yy_zong_sim_99.txt"
 try:
     with open(file, 'r', encoding="utf-8") as f:
         lines = [x.strip() for x in f if x.strip() != '']
@@ -205,7 +205,7 @@ class Evaluator(keras.callbacks.Callback):
         # save the best
         if logs['loss'] <= self.lowest:
             self.lowest = logs['loss']
-            model.save_weights('./output_t5/best_model_t5_0724.weights')
+            model.save_weights('./output_t5/best_model_t5_zong_sim_99.weights')
         # demo results 7
         just_show()
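Note: the final task_seq2seq_t5.py hunk shows the "save best by training loss" pattern used by the Evaluator callback: keep the lowest loss seen so far and overwrite a single weights file whenever it improves. A minimal standalone version — the class name is mine, and the repo's callback also calls just_show() afterwards:

    import keras

    class BestLossCheckpoint(keras.callbacks.Callback):
        # Overwrite one weights file whenever training loss reaches a new low.
        def __init__(self, path='./output_t5/best_model_t5_zong_sim_99.weights'):
            super().__init__()
            self.lowest = 1e10
            self.path = path

        def on_epoch_end(self, epoch, logs=None):
            if logs['loss'] <= self.lowest:
                self.lowest = logs['loss']
                self.model.save_weights(self.path)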