
Standard-version rewrite (similarity-reduction) optimization: use the vllm model, fix English-text handling, and fix an empty-string bug

master
majiahui@haimaqingfan.com · 8 months ago
commit 6c119a3f20
16 changed files:
144  .gitignore
 40  README.md
  4  config/predict_sim_config.py
  9  config/predict_t5_config.py
  1  crontab_sh.sh
 83  evaluate_test.py
  4  flask_drop_rewrite_request.py
  2  flask_predict.py
  2  predict_drop_sim_sim.py
109  predict_drop_weight_sim.py
 77  predict_t5.py
  2  predict_tf_sim.py
  2  redis_check_uuid.py
  2  run_app_nohub_search_redis.sh
 13  task_seq2seq_autotitle.py
  8  task_seq2seq_t5.py

144
.gitignore

@@ -1,142 +1,4 @@
# ---> Python
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
/chinese_roberta_wwm_ext_L-12_H-768_A-12/
/data/
/bert_base_script_fintune_tf/
/chinese_roberta_wwm_ext_L-12_H-768_A-12/
/output/

40
README.md

@@ -1,31 +1,29 @@
# Rewriting project
Generative task based on the unilm and t5 models, using the keras framework; data-processing scripts are in the data_do folder
Training data: train_yy.txt
# Novel rewriting project
Generative task based on the unilm model, using the keras framework; data-processing scripts are in the data_do folder
Training data: train_cat_data_4.txt
## Training
Train t5: python task_seq2seq_t5.py
Train simbert: python simbert_train.py
With quality-check training added: bash train.sh
With quality-check training added: bash train_sim.sh
## Prediction
simbert: python predict_sim.py
t5: python predict_t5.py
With quality check: python predict_tf_sim.py
Without quality check: python predict_tf.py
## API serve
Start the service that accepts sentences and returns a uuid: bash run_app_nohub_t5.sh
Start the service that looks up rewrite results by uuid: bash run_app_nohub_search_redis.sh
## Request/response examples
Request a sentence uuid: https://console-docs.apipost.cn/preview/e3717e390cbdb50e/f4479038c8015f34
Fetch the rewrite result: https://console-docs.apipost.cn/preview/6b9de12817e8ef08/b158334d2c9534d2
Current startup method: bash run_app.sh
One-command startup: bash run_app_gunicorn.sh
## Generating training data from yy data
python data_do/处理yy数据原始数据.py
python data_do/进一步处理降重数据.py
python data_do/yy训练数据处理.py
python 筛选训练数据strsim.py
python 合并数据.py
## Request example
requests.post(
"http://192.168.1.17:14000",
json={"texts": ["张三要爬上高位的,才能够翻云覆雨。"]},
timeout=1000
)
## Test on 11 articles
## Check the test data for bugs
python 测试10000篇数据.py
## Response
{'probabilities': None, 'texts': ['张三要上了巅峰,他就可以为所欲为了。']}
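A rough Python sketch of the two-step flow documented above (submit sentences for a uuid, then poll the redis lookup service for the finished rewrite). The endpoint paths /predict and /search, the query_id field, and the 14000/14001 port pairing are assumptions pieced together from the scripts in this commit, not a confirmed API contract; the synchronous example above works as shown.

import time
import requests

PREDICT_URL = "http://192.168.1.17:14000/predict"  # assumed path on the uuid-request service
SEARCH_URL = "http://192.168.1.17:14001/search"    # assumed path on the redis lookup service

def rewrite(texts, poll_interval=2, timeout=300):
    """Submit sentences, then poll until the rewrite result is ready (hypothetical payload shape)."""
    # Step 1: submit and receive an id for the queued job (field name assumed).
    query_id = requests.post(PREDICT_URL, json={"texts": texts}, timeout=100).json().get("query_id")
    # Step 2: poll the redis-backed lookup service until the worker has written the result.
    deadline = time.time() + timeout
    while time.time() < deadline:
        result = requests.post(SEARCH_URL, json={"query_id": query_id}, timeout=100).json()
        if result.get("texts"):
            return result["texts"]
        time.sleep(poll_interval)
    raise TimeoutError("rewrite result not ready")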

4
config/predict_sim_config.py

@@ -12,7 +12,7 @@ import os
pre_model_path = {
"simbert": {
"linux": "/home/zc-nlp-zyp/work_file/ssd_data/模型库/预训练模型集合/keras/chinese_roberta_wwm_ext_L-12_H-768_A-12",
"linux": "/home/majiahui/project/drop_weight_rewrite/chinese_roberta_wwm_ext_L-12_H-768_A-12",
"win32": r"E:\pycharm_workspace\premodel\keras\chinese_roberta_wwm_ext_L-12_H-768_A-12"
},
}
@@ -29,4 +29,4 @@ class DropSimBertConfig:
self.savemodel_path = "./output_simbert_yy/best_simbertmodel_dropout_datasim_yinhao.weights"
self.maxlen = 120
self.cuda_id = "1"
self.cuda_id = "0"

9
config/predict_t5_config.py

@@ -12,8 +12,7 @@ import os
pre_model_path = {
"t5": {
# "linux": "/home/zc-nlp-zyp/work_file/ssd_data/模型库/预训练模型集合/keras/mt5/mt5_base",
"linux": "mt5/mt5_base",
"linux": "/home/majiahui/project/drop_weight_rewrite/mt5/mt5_base",
"win32": r"E:\pycharm_workspace\premodel\keras\mt5\mt5_base"
},
@@ -28,7 +27,7 @@ class DropT5Config:
self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000')
self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model')
self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json')
self.savemodel_path = "./output_t5/best_model_t5_0724.weights"
self.savemodel_path = "./output_t5/best_model_t5_dropout_0_3.weights"
self.maxlen = 256
self.cuda_id = "0"
@@ -41,6 +40,6 @@ class MultipleResultsDropT5Config:
self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000')
self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model')
self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json')
self.savemodel_path = "./output_t5/best_model_t5_0724.weights"
self.savemodel_path = "./output_t5/best_model_t5_dropout_0_3.weights"
self.maxlen = 256
self.cuda_id = "0"
self.cuda_id = "0"

1
crontab_sh.sh

@@ -1,3 +1,2 @@
rm -rf /home/majiahui/drop_weight_rewrite/request_data_logs/*
rm -rf /home/majiahui/drop_weight_rewrite/old_data_logs/*
mv /home/majiahui/drop_weight_rewrite/new_data_logs/* /home/majiahui/drop_weight_rewrite/old_data_logs/*

83
evaluate_test.py

@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
"""
@Time : 2022/8/15 15:20
@Author :
@FileName:
@Software:
@Author :
@FileName:
@Software:
@Describe:
"""
import json
@@ -19,13 +21,11 @@ from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder
from keras.models import Model
from rouge import Rouge # pip install rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import difflib
class Evaluator(keras.callbacks.Callback):
"""评估与保存
"""
def __init__(self):
self.rouge = Rouge()
self.smooth = SmoothingFunction().method1
@@ -69,79 +69,52 @@ class Evaluator(keras.callbacks.Callback):
}
def evaluate_t(self, data_1, data_2, topk=1):
data_1_eval = ' '.join(data_1)
data_2_eval = ' '.join(data_2)
total = 0
rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0
scores = self.rouge.get_scores(hyps=[data_1_eval], refs=[data_2_eval])
scores = self.rouge.get_scores(hyps=[data_1], refs=[data_2])
rouge_1 += scores[0]['rouge-1']['f']
rouge_2 += scores[0]['rouge-2']['f']
rouge_l += scores[0]['rouge-l']['f']
bleu += sentence_bleu(
references=[data_1_eval.split(' ')],
hypothesis=data_2_eval.split(' '),
references=[data_1.split(' ')],
hypothesis=data_2.split(' '),
smoothing_function=self.smooth
)
# rouge_1 /= total
# rouge_2 /= total
# rouge_l /= total
# bleu /= total
str_sim = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio()
return [rouge_1, rouge_2, rouge_l, bleu, str_sim]
return [rouge_1, rouge_2, rouge_l, bleu]
eval_class = Evaluator()
# print(eval_class.evaluate_t("星 辰 的 话","星 辰 的 话 :"))
path = "data/700条效果对比.xlsx"
path_out = "data/700条效果对比测评结果_14.csv"
path = "data/一万字小说测试效果.xlsx"
path_out = "data/一万字小说测试效果测评.csv"
data = pd.read_excel(path).values.tolist()
data_new = {"rouge_1": [0,0,0,0,0,0,0,0,0,0,0],
"rouge_2": [0,0,0,0,0,0,0,0,0,0,0],
"rouge_l": [0,0,0,0,0,0,0,0,0,0,0],
"bleu": [0,0,0,0,0,0,0,0,0,0,0]}
total = 0
list_class = [0 for i in range(13)]
# print(list_class)
data_new = {"rouge_1": list_class.copy(),
"rouge_2": list_class.copy(),
"rouge_l": list_class.copy(),
"bleu": list_class.copy(),
"str_sim": list_class.copy()}
total = len(data)
print(len(data))
for i in data:
dan_list = [i[1], i[2], i[3], i[4], i[5], i[6], i[7], i[8], i[9], i[10], i[11], i[12], i[-1]]
dan_list = i[2:-1]
for j in range(len(dan_list)):
eval_list = eval_class.evaluate_t(dan_list[j], i[0])
try:
data_new["rouge_1"][j] += eval_list[0]
data_new["rouge_2"][j] += eval_list[1]
data_new["rouge_l"][j] += eval_list[2]
data_new["bleu"][j] += eval_list[3]
data_new["str_sim"][j] += eval_list[4]
except:
pass
eval_list = eval_class.evaluate_t(' '.join(dan_list[j]), ' '.join(i[-1]))
data_new["rouge_1"][j] += eval_list[0]
data_new["rouge_2"][j] += eval_list[1]
data_new["rouge_l"][j] += eval_list[2]
data_new["bleu"][j] += eval_list[3]
data = {}
'''
生成文本t5_未修正数据 生成文本unilm未修正数据 生成文本unilm修正数据 生成文本unilm修正数据_预训练 生成文本240w/24H 生成文本(240W/48H) 生成文本(240W/24H/) 生成文本全部数据/72H/) 生成文本全部数据/72H/未修 生成文本t5修正数据 生成文本t5修正数据_190epoch
def fune(x):
return x/total
for i in data_new:
data[i] = list(map(fune, data_new[i]))
'''
pd.DataFrame(data,
index=["simbert_5day",
"simbert_simdata4day",
"simbert_simdata5day",
"simbert_random20_5day",
"simbert_simdata4day_yinhao",
"simbert_simdata4day_yinhao_dropout",
"simsim模型",
"dropout_sim_03模型",
"dropout_sim_04模型",
"t5",
"t5_dropout",
"小说模型",
"yy"]
).to_csv(
path_out)
pd.DataFrame(data_new,index=["生成文本(t5_未修正数据)","生成文本(unilm未修正数据)","生成文本(unilm修正数据)",
"生成文本(unilm修正数据_预训练)","生成文本(240w/24H)","生成文本(240W/48H)","生成文本(240W/24H/修)",
"生成文本(全部数据/72H/修)", "生成文本(全部数据/72H/未修)", "生成文本(t5修正数据)", "生成文本(t5修正数据_190epoch)"]).to_csv(path_out)

4
flask_drop_rewrite_request.py

@@ -62,8 +62,8 @@ def get_host_ip():
return ip
chatgpt_url_predict = "http://{}:12006/predict".format(str(get_host_ip()))
chatgpt_url_search = "http://{}:12006/search".format(str(get_host_ip()))
chatgpt_url_predict = "http://{}:12003/predict".format(str(get_host_ip()))
chatgpt_url_search = "http://{}:12003/search".format(str(get_host_ip()))
def smtp_f(name):

2
flask_predict.py

@@ -116,7 +116,7 @@ def batch_data_process(text_list):
for sentence in text_list:
sentence_batch_length += len(sentence[0])
sentence_batch_one.append(sentence)
if sentence_batch_length > 10000:
if sentence_batch_length > 1000:
sentence_batch_length = 0
sentence_ = sentence_batch_one.pop(-1)
sentence_batch_list.append(sentence_batch_one)
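The hunk above lowers the per-batch character budget from 10000 to 1000. A minimal sketch of the batching idea, with illustrative names rather than the project's exact implementation:

def batch_by_length(sentences, max_chars=1000):
    """Group sentences into batches whose combined character count stays at or under max_chars."""
    batches, current, current_len = [], [], 0
    for sentence in sentences:
        if current and current_len + len(sentence) > max_chars:
            batches.append(current)        # close the batch before it overflows
            current, current_len = [], 0
        current.append(sentence)
        current_len += len(sentence)
    if current:
        batches.append(current)
    return batches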

2
predict_drop_sim_sim.py

@@ -713,7 +713,7 @@ if __name__ == '__main__':
df_list_new = []
print(len(df_list))
for i in tqdm(df_list):
pre = just_show_sentence([i[1]])
pre = just_show_sentence([i[0]])
df_list_new.append([i[0] , pre])

109
predict_drop_weight_sim.py

@@ -85,7 +85,6 @@ class GenerateModel(object):
self.checkpoint_path = r'./chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
self.dict_path = r'./chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'
self.maxlen = 120
self.novel_maxlen = 60
def device_setup(self):
token_dict, keep_tokens = load_vocab(
@@ -122,7 +121,7 @@ class GenerateModel(object):
outputs = TotalLoss([2, 3])(bert.model.inputs + bert.model.outputs)
model = keras.models.Model(bert.model.inputs, outputs)
path_model = './output_simbert_yy/best_simbertmodel.weights'
path_model = './output_simbert/best_simbertmodel.weights'
model.load_weights(path_model)
return encoder,seq2seq, tokenizer
@@ -535,6 +534,7 @@ class AutoTitle(AutoRegressiveDecoder):
# return self.last_token(self.model).predict([token_ids, segment_ids])
def generate(self, text, topk=3):
text = text[0]
token_ids, segment_ids = self.tokenizer.encode(text, maxlen=256)
output_ids = self.beam_search([token_ids, segment_ids],
topk=topk) # 基于beam search
@@ -649,12 +649,12 @@ def just_show(file):
# pd.DataFrame(data,columns=["原始文本","生成文本"]).to_csv("data/text_测试一万字_unilm_修正数据_小说预训练_全部数据_epoch72_反向训练.csv")
def just_show_sentence(file: list) -> object:
def just_show_sentence(file: object) -> object:
"""
@param file:list
"""
text = file[0]
pre = autotitle.generate(text)
pre = autotitle.generate(file)
return pre
# pre = autotitle.gen_synonyms(file)
@@ -701,95 +701,20 @@ def just_show_csv_beam(file):
pd.DataFrame(data_new).to_csv("data/###第3章 非常尴尬_sim_topK_1.csv")
def chulichangju_1(text, chulipangban_return_list):
fuhao = [",","","","",""]
text_1 = text[:60]
text_2 = text[60:]
text_1_new = ""
for i in range(len(text_1)-1, -1, -1):
if text_1[i] in fuhao:
text_1_new = text_1[:i]
text_1_new += text_1[i]
if len(text_1_new) > 10:
text_1_new_pre = autotitle.generate(text_1_new)
else:
text_1_new_pre = text_1_new
chulipangban_return_list.append(text_1_new_pre)
if text_2 != "":
if i+1 != 60:
text_2 = text_1[i+1:] + text_2
break
# else:
# chulipangban_return_list.append(text_1)
if text_1_new == "":
if len(text_1) > 10:
text_1_new_pre = autotitle.gen_synonyms_short(text_1)
else:
text_1_new_pre = text_1
chulipangban_return_list.append(text_1_new_pre)
if text_2 != "":
chulipangban_return_list = chulichangju_1(text_2, chulipangban_return_list)
return chulipangban_return_list
def chulipangban_test_1(text):
sentence_list = text.split("")
sentence_list_new = []
for i in sentence_list:
if i != "":
sentence_list_new.append(i)
sentence_list = sentence_list_new
return_list = []
for sentence in sentence_list:
if len(sentence) < 60:
if len(sentence) > 10:
sentence_pre = autotitle.generate(sentence)
else:
sentence_pre = sentence
return_list.append(sentence_pre)
else:
sentence_split_list = chulichangju_1(sentence,[])
sentence_split_text = "".join(sentence_split_list)
return_list.append(sentence_split_text)
return return_list
def paragraph_test(text, text_new):
if __name__ == '__main__':
text = chulipangban_test_1(text)
text = "".join(text)
text_new.append(text)
# text = ["历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择,是时代对于青年们的深切呼唤"]
# print(just_show_sentence(text))
# text_new_str = "".join(text_new)
return text_new
path = "./data/700条论文测试.xlsx"
df_list = pd.read_excel(path).values.tolist()
df_list_new = []
print(len(df_list))
for i in tqdm(df_list):
pre = just_show_sentence([i[0]])
if __name__ == '__main__':
text = ["所以对学生对应用仪器分析解决实际问题的能力要求很高。","随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]
print(just_show_sentence(text))
df_list_new.append([i[0], i[1], pre])
# is_novel = False
# path = "./data/700条论文测试.xlsx"
# df_list = pd.read_excel(path).values.tolist()
#
# if is_novel == False:
# df_list_new = []
# print(len(df_list))
# for i in tqdm(df_list[:50]):
# pre = just_show_sentence([i[0]])
#
# df_list_new.append([i[0], i[1], pre])
#
# df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "模型"])
# df.to_excel("./data/700条论文测试_14.xlsx", index=None)
# else:
# df_list_new = []
# print(len(df_list))
# for i in tqdm(df_list):
# text_list = paragraph_test(i[0], [])
# pre = "".join(text_list)
# pre += "。"
# df_list_new.append([i[0], i[1], pre])
#
# df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "dropout_sim_模型"])
# df.to_excel("./data/700条论文测试_5.xlsx", index=None)
df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "小说模型"])
df.to_excel("./data/700条论文测试_3.xlsx", index=None)
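The removed chulichangju_1 helper cut sentences longer than 60 characters at the last punctuation mark inside the 60-character window, then rewrote each chunk with autotitle.generate. A simplified standalone sketch of just the splitting step (punctuation set abbreviated, rewriting left to the caller):

def split_long_sentence(text, max_len=60, puncts=(",", "，", "；")):
    """Recursively cut text into chunks of at most max_len characters,
    preferring to cut at the last punctuation mark inside the window."""
    if len(text) <= max_len:
        return [text]
    window = text[:max_len]
    cut = max((window.rfind(p) for p in puncts), default=-1)
    if cut <= 0:                         # no usable punctuation: hard cut at the window edge
        cut = max_len - 1
    return [text[:cut + 1]] + split_long_sentence(text[cut + 1:], max_len, puncts)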

77
predict_t5.py

@@ -8,11 +8,7 @@
@Describe:
"""
#! -*- coding: utf-8 -*-
import os
from config.predict_t5_config import DropT5Config
config = DropT5Config()
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id
import glob
from numpy import random
random.seed(1001)
@@ -414,10 +410,11 @@ class AutoTitle(AutoRegressiveDecoder):
return output_str
def just_show_sentence(text):
def just_show_sentence(file):
"""
@param text:list
@param file:list
"""
text = file[0]
pre = autotitle.generate(text)
return pre
@@ -429,7 +426,10 @@ def just_show_sentence_batch(file: list) -> object:
if __name__ == '__main__':
import os
from config.predict_t5_config import DropT5Config
config = DropT5Config()
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id
generatemodel = GenerateModel(config.config_path,
config.checkpoint_path,
config.spm_path,
@@ -462,32 +462,31 @@ if __name__ == '__main__':
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
file = "./data/11篇汇总txt_new.txt"
file_t5 = "./data/11篇汇总txt_new_predict_t5.txt"
file_t5_0724 = "./data/11篇汇总txt_new_predict_t5_0724.txt"
try:
with open(file, 'r', encoding="utf-8") as f:
lines = [x.strip() for x in f if x.strip() != '']
except:
with open(file, 'r', encoding="gbk") as f:
lines = [x.strip() for x in f if x.strip() != '']
zishu = 0
data = []
for i in tqdm(lines):
zishu += len(i)
pre = just_show_sentence([i])
data.append([i, pre])
with open(file_t5_0724, "w", encoding='utf-8') as file:
for i in data:
file.write("\t".join(i) + '\n')
file.close()
print(zishu)
# import os
#
# file = "./data/11篇汇总txt_new.txt"
# file_t5 = "./data/11篇汇总txt_new_predict_t5.txt"
#
# try:
# with open(file, 'r', encoding="utf-8") as f:
# lines = [x.strip() for x in f if x.strip() != '']
# except:
# with open(file, 'r', encoding="gbk") as f:
# lines = [x.strip() for x in f if x.strip() != '']
#
# zishu = 0
# data = []
# for i in tqdm(lines):
#
# zishu += len(i)
# pre = just_show_sentence([i])
# data.append([i, pre])
#
# with open(file_t5, "w", encoding='utf-8') as file:
# for i in data:
# file.write("\t".join(i) + '\n')
# file.close()
# print(zishu)
#++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
@@ -497,13 +496,11 @@ if __name__ == '__main__':
# "强调轻资产经营, 更加重视经营风险的规避",
# "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择",
# "是时代对于青年们的深切呼唤"]
# text = ["随着经济的发展,人们生活水平的提高,环境:问题也日益突出。",
# "环境问题中的化学污染是影响我国居民生活质量不可忽视的重要因素,而仪器分析作为化工专业课程中必不可少的一门课程也不例外。",
# "所以对学生对应用仪器分析解决实际问题的能力要求很高。",
# "随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]
#
# for i in text:
# print(just_show_sentence(i))
text = ["随着经济的发展,人们生活水平的提高,环境:问题也日益突出。",
"环境问题中的化学污染是影响我国居民生活质量不可忽视的重要因素,而仪器分析作为化工专业课程中必不可少的一门课程也不例外。",
"所以对学生对应用仪器分析解决实际问题的能力要求很高。",
"随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]
print(just_show_sentence(text))
# print(just_show_sentence_top(text))
# print(just_show_chachong_random(text))
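The block commented out above reads the test file with UTF-8 first and GBK as a fallback, then writes (original, rewritten) pairs tab-separated. A compact sketch of that I/O pattern, with the rewrite call left to the caller:

def read_lines(path):
    """Read non-empty stripped lines, trying UTF-8 first and falling back to GBK."""
    for encoding in ("utf-8", "gbk"):
        try:
            with open(path, "r", encoding=encoding) as f:
                return [x.strip() for x in f if x.strip()]
        except UnicodeDecodeError:
            continue
    raise ValueError(f"could not decode {path} as utf-8 or gbk")

def write_pairs(path, pairs):
    """Write (original, rewritten) pairs as tab-separated lines, one pair per line."""
    with open(path, "w", encoding="utf-8") as f:
        for original, rewritten in pairs:
            f.write(f"{original}\t{rewritten}\n")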

2
predict_tf_sim.py

@@ -703,7 +703,7 @@ def just_show_csv_beam(file):
if __name__ == '__main__':
# file = "train_2842.txt"
# just_show(file)
text = ["随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]
text = ["历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择,是时代对于青年们的深切呼唤"]
just_show_sentence(text)
# "简言之,她不好过,李四也别想好过!"
# s = "张三的对话"

2
redis_check_uuid.py

@@ -84,4 +84,4 @@ def handle_query():
if __name__ == "__main__":
app.run(debug=False, host='0.0.0.0', port=14001)
app.run(debug=False, host='0.0.0.0', port=14001)

2
run_app_nohub_search_redis.sh

@@ -1 +1 @@
nohup python redis_check_uuid_mistral.py > myout.redis_check_uuid_mistral.logs 2>&1 &
nohup python redis_check_uuid.py > myout.redis_check_uuid.logs 2>&1 &

13
task_seq2seq_autotitle.py

@@ -20,7 +20,6 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import os
# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from keras.backend.tensorflow_backend import set_session
config = tf.ConfigProto()
@@ -30,17 +29,19 @@ set_session(tf.Session(config=config)) # 此处不同
# 基本参数
maxlen = 256
batch_size = 8
batch_size = 32
steps_per_epoch = 20000
epochs = 10000
# bert配置
config_path = 'bert_config_dropout_0_3.json'
checkpoint_path = './chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
dict_path = './chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab.txt'
config_path = './bert_base_script_fintune_tf/config.json'
checkpoint_path = './bert_base_script_fintune_tf/bert_base_script_fintune_tf.ckpt'
dict_path = './bert_base_script_fintune_tf/vocab.txt'
# # 训练样本。THUCNews数据集,每个样本保存为一个txt。
# txts = glob.glob('/root/thuctc/THUCNews/*/*.txt')
file = "data/train_yy.txt"
file = "data/train_cat_data_4.txt"
try:
with open(file, 'r', encoding="utf-8") as f:
lines = [x.strip() for x in f if x.strip() != '']
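The top of this script pins the GPU and registers a TF1 session with Keras before the model is built. A typical version of that setup; the allow_growth flag is an assumption, since the diff only shows the surrounding lines:

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

config = tf.ConfigProto()
config.gpu_options.allow_growth = True      # assumed: allocate GPU memory on demand
set_session(tf.Session(config=config))      # register the session with Keras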

8
task_seq2seq_t5.py

@@ -16,7 +16,7 @@
# 补充了评测指标bleu、rouge-1、rouge-2、rouge-l
import os
# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import numpy as np
from tqdm import tqdm
@@ -40,7 +40,7 @@ for gpu in gpus:
max_c_len = 128
max_t_len = 128
batch_size = 28
epochs = 10
epochs = 10000
# 模型路径
config_path = 'mt5/mt5_base_dropout_0_3_config.json'
@@ -49,7 +49,7 @@ spm_path = 'mt5/mt5_base/sentencepiece_cn.model'
keep_tokens_path = 'mt5/mt5_base/sentencepiece_cn_keep_tokens.json'
file = "data/train_new/train_yy.txt"
file = "data/train_yy_zong_sim_99.txt"
try:
with open(file, 'r', encoding="utf-8") as f:
lines = [x.strip() for x in f if x.strip() != '']
@@ -205,7 +205,7 @@ class Evaluator(keras.callbacks.Callback):
# 保存最优
if logs['loss'] <= self.lowest:
self.lowest = logs['loss']
model.save_weights('./output_t5/best_model_t5_0724.weights')
model.save_weights('./output_t5/best_model_t5_zong_sim_99.weights')
# 演示效果7
just_show()
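The Evaluator callback above keeps the weights with the lowest training loss seen so far. The pattern, reduced to its essentials (class name and save path are placeholders; the project itself obtains keras via bert4keras):

from tensorflow import keras

class SaveLowestLoss(keras.callbacks.Callback):
    """Save model weights whenever the epoch loss reaches a new minimum."""
    def __init__(self, save_path="./output_t5/best_model.weights"):
        super().__init__()
        self.save_path = save_path
        self.lowest = 1e10

    def on_epoch_end(self, epoch, logs=None):
        loss = (logs or {}).get("loss")
        if loss is not None and loss <= self.lowest:
            self.lowest = loss
            self.model.save_weights(self.save_path)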
