
Code cleanup, no functional changes: add config files and turn model loading into a class

master
majiahui@haimaqingfan.com 2 years ago
commit dc1e8fe963
Changed files:

1. config/predict_sim_config.py (32 changes)
2. config/predict_t5_config.py (45 changes)
3. flask_predict_no_batch_t5.py (61 changes)
4. predict_drop_weight_sim.py (5 changes)
5. predict_sim.py (72 changes)
6. predict_t5.py (95 changes)
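The commit replaces hard-coded paths and import-time model construction with config objects passed into the model classes. A minimal sketch of the new T5 loading flow, assembled from the hunks below (illustrative only; note that CUDA_VISIBLE_DEVICES must be set before TensorFlow is imported, which is why the new flask entry point reads the config first):

# Sketch of the loading pattern this commit introduces (T5 path).
# All names are taken from the diff below; treat as illustrative, not canonical.
import os
from config.predict_t5_config import DropT5Config

config = DropT5Config()
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id  # must precede any TF import

from predict_t5 import GenerateModel, AutoTitle

# GenerateModel now receives every path and hyperparameter via the config object
generatemodel = GenerateModel(config.config_path,
                              config.checkpoint_path,
                              config.spm_path,
                              config.keep_tokens_path,
                              config.maxlen,
                              config.savemodel_path)
encoder, decoder, model, tokenizer = generatemodel.device_setup()
autotitle = AutoTitle(encoder, decoder, model, tokenizer,
                      start_id=0, end_id=tokenizer._token_end_id, maxlen=120)

print(autotitle.generate("随着经济的发展,人们生活水平的提高,环境问题也日益突出。"))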

config/predict_sim_config.py (32 changes)

@@ -0,0 +1,32 @@
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/3/27 10:23
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+import sys
+import os
+
+pre_model_path = {
+    "simbert": {
+        "linux": "/home/zc-nlp-zyp/work_file/ssd_data/模型库/预训练模型集合/keras/chinese_roberta_wwm_ext_L-12_H-768_A-12",
+        "win32": r"E:\pycharm_workspace\premodel\keras\chinese_roberta_wwm_ext_L-12_H-768_A-12"
+    },
+}
+
+
+class DropSimBertConfig:
+    def __init__(self):
+        self.sys_platform = sys.platform
+        self.premodel_path = pre_model_path["simbert"][self.sys_platform]
+        self.config_path = os.path.join(self.premodel_path, 'bert_config.json')
+        self.checkpoint_path = os.path.join(self.premodel_path, 'bert_model.ckpt')
+        self.dict_path = './config/vocab_drop.txt'
+        self.savemodel_path = "./output_simbert_yy/best_simbertmodel_dropout_datasim_yinhao.weights"
+        self.maxlen = 120
+        self.cuda_id = "1"
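The config resolves the pretrained-model directory via sys.platform, so one file serves both the Linux server and the Windows dev box. Supporting another machine only needs a new key in pre_model_path; a hypothetical macOS entry for illustration (the "darwin" key and its path are not part of this commit):

# Hypothetical example: adding a macOS dev box. Only "linux" and "win32"
# exist in this commit; the "darwin" key and path are invented for illustration.
pre_model_path = {
    "simbert": {
        "linux": "/home/zc-nlp-zyp/work_file/ssd_data/模型库/预训练模型集合/keras/chinese_roberta_wwm_ext_L-12_H-768_A-12",
        "win32": r"E:\pycharm_workspace\premodel\keras\chinese_roberta_wwm_ext_L-12_H-768_A-12",
        "darwin": "/Users/me/premodel/keras/chinese_roberta_wwm_ext_L-12_H-768_A-12",
    },
}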

config/predict_t5_config.py (45 changes)

@@ -0,0 +1,45 @@
+# -*- coding: utf-8 -*-
+"""
+@Time    : 2023/3/27 10:23
+@Author  :
+@FileName:
+@Software:
+@Describe:
+"""
+import sys
+import os
+
+pre_model_path = {
+    "t5": {
+        "linux": "/home/zc-nlp-zyp/work_file/ssd_data/模型库/预训练模型集合/keras/mt5/mt5_base",
+        "win32": r"E:\pycharm_workspace\premodel\keras\mt5\mt5_base"
+    },
+}
+
+
+class DropT5Config:
+    def __init__(self):
+        self.sys_platform = sys.platform
+        self.premodel_path = pre_model_path["t5"][self.sys_platform]
+        self.config_path = os.path.join(self.premodel_path, 'mt5_base_config.json')
+        self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000')
+        self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model')
+        self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json')
+        self.savemodel_path = "./output_t5/best_model_t5_zong_sim_99.weights"
+        self.maxlen = 256
+        self.cuda_id = "1"
+
+
+class MultipleResultsDropT5Config:
+    def __init__(self):
+        self.sys_platform = sys.platform
+        self.premodel_path = pre_model_path["t5"][self.sys_platform]
+        self.config_path = os.path.join(self.premodel_path, 'mt5_base_config.json')
+        self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000')
+        self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model')
+        self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json')
+        self.savemodel_path = "./output_t5/best_model_t5_dropout_0_3.weights"
+        self.maxlen = 256
+        self.cuda_id = "1"
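DropT5Config and MultipleResultsDropT5Config are identical except for savemodel_path. If more weight variants appear, the duplication could be folded into a shared base class; a possible follow-up sketch, not part of this commit:

# Possible follow-up refactor (not in this commit): keep the shared fields in
# one base class and vary only the weights file per subclass.
# pre_model_path is the module-level dict defined above in this file.
import os
import sys


class BaseT5Config:
    savemodel_path = None  # overridden by each subclass

    def __init__(self):
        self.sys_platform = sys.platform
        self.premodel_path = pre_model_path["t5"][self.sys_platform]
        self.config_path = os.path.join(self.premodel_path, 'mt5_base_config.json')
        self.checkpoint_path = os.path.join(self.premodel_path, 'model.ckpt-1000000')
        self.spm_path = os.path.join(self.premodel_path, 'sentencepiece_cn.model')
        self.keep_tokens_path = os.path.join(self.premodel_path, 'sentencepiece_cn_keep_tokens.json')
        self.maxlen = 256
        self.cuda_id = "1"


class DropT5Config(BaseT5Config):
    savemodel_path = "./output_t5/best_model_t5_zong_sim_99.weights"


class MultipleResultsDropT5Config(BaseT5Config):
    savemodel_path = "./output_t5/best_model_t5_dropout_0_3.weights"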

flask_predict_no_batch_t5.py (61 changes)

@@ -1,13 +1,13 @@
 import os
-# os.environ["TF_KERAS"] = "1"
-# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+from config.predict_t5_config import DropT5Config
+
+config = DropT5Config()
+os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id
 from flask import Flask, jsonify
 from flask import request
 # from linshi import autotitle
 import requests
-from flask import request
-from predict_t5 import autotitle
+from predict_t5 import GenerateModel, AutoTitle
 import redis
 import uuid
 import json
@@ -15,7 +15,6 @@ from threading import Thread
 import time
 import re
-
 pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50, db=1)
 redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
@@ -27,13 +26,19 @@ app = Flask(__name__)
 app.config["JSON_AS_ASCII"] = False
 import logging

 pattern = r"[。]"
 RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
-fuhao_end_sentence = ["","","","",""]
+fuhao_end_sentence = ["", "", "", "", ""]

-config = {
-    "batch_szie": 1000
-}
+generatemodel = GenerateModel(config.config_path,
+                              config.checkpoint_path,
+                              config.spm_path,
+                              config.keep_tokens_path,
+                              config.maxlen,
+                              config.savemodel_path)
+encoder, decoder, model, tokenizer = generatemodel.device_setup()
+autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120)

 def get_dialogs_index(line: str):
@@ -56,7 +61,7 @@ def get_dialogs_index(line: str):
 def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
-    fuhao = ["","","",""]
+    fuhao = ["", "", "", ""]
     dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
     text_1 = text[:120]
     text_2 = text[120:]
@@ -64,7 +69,7 @@ def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
     if text_2 == "":
         chulipangban_return_list.append([text_1, snetence_id, short_num])
         return chulipangban_return_list
-    for i in range(len(text_1)-1, -1, -1):
+    for i in range(len(text_1) - 1, -1, -1):
         if text_1[i] in fuhao:
             if i in dialogs_index:
                 continue
@@ -72,8 +77,8 @@ def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
             text_1_new += text_1[i]
             chulipangban_return_list.append([text_1_new, snetence_id, short_num])
             if text_2 != "":
-                if i+1 != 120:
-                    text_2 = text_1[i+1:] + text_2
+                if i + 1 != 120:
+                    text_2 = text_1[i + 1:] + text_2
                 break
     # else:
     #     chulipangban_return_list.append(text_1)
@@ -121,18 +126,7 @@ def chulipangban_test_1(snetence_id, text):
     return sentence_batch_list

-def paragraph_test_(text:list, text_new:list):
-    for i in range(len(text)):
-        text = chulipangban_test_1(text, i)
-        text = "".join(text)
-        text_new.append(text)
-    # text_new_str = "".join(text_new)
-    return text_new
-
-def paragraph_test(texts:dict):
+def paragraph_test(texts: dict):
     text_new = []
     for i, text in texts.items():
         text_list = chulipangban_test_1(i, text)
@@ -159,6 +153,7 @@ def batch_data_process(text_list):
         sentence_batch_list.append(sentence_batch_one)
     return sentence_batch_list

+
 def batch_predict(batch_data_list):
     '''
     一个bacth数据预测
@@ -173,7 +168,7 @@ def batch_predict(batch_data_list):
         batch_data_snetence_id_list.append(i[1:])
     # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
     batch_pre_data_list = batch_data_text_list
-    for text,sentence_id in zip(batch_pre_data_list,batch_data_snetence_id_list):
+    for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list):
         batch_data_list_new.append([text] + sentence_id)
     return batch_data_list_new
@@ -238,6 +233,7 @@ def main(text: dict):
     return_list = predict_data_post_processing(text_list_new)
     return return_list

+
@app.route('/droprepeat/', methods=['POST'])
 def sentence():
     print(request.remote_addr)
@@ -246,7 +242,6 @@ def sentence():
     print("原始语句" + str(texts))
     # question = question.strip('。、!??')

     if isinstance(texts, dict):
-
         texts_list = []
         y_pred_label_list = []
@@ -264,12 +259,14 @@ def sentence():
         texts_list = main(texts)
         return_text = {"texts": texts_list, "probabilities": None, "status_code": True}
     else:
-        return_text = {"texts":"输入格式应该为list", "probabilities": None, "status_code":False}
+        return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False}
     return jsonify(return_text)

+
 def classify():  # 调用模型,设置最大batch_size
     while True:
         if redis_.llen(db_key_query) == 0:  # 若队列中没有元素就继续获取
+            time.sleep(3)
             continue
         query = redis_.lpop(db_key_query).decode('UTF-8')  # 获取query的text
         data_dict_path = json.loads(query)
@@ -291,8 +288,8 @@ def classify():  # 调用模型,设置最大batch_size
             texts_list = main(texts)
         else:
             texts_list = []
-            return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
+        return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
         redis_.srem(db_key_querying, query_id)
         load_result_path = "./new_data_logs/{}.json".format(query_id)
         with open(load_result_path, 'w', encoding='utf8') as f2:
@@ -319,9 +316,9 @@ def handle_query():
             # ensure_ascii=False才能输入中文,否则是Unicode字符
             # indent=2 JSON数据的缩进,美观
             json.dump(d, f2, ensure_ascii=False, indent=4)
-        redis_.rpush(db_key_query, json.dumps({"id":id_, "path": load_request_path}))  # 加入redis
+        redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path}))  # 加入redis
         redis_.sadd(db_key_querying, id_)
-        return_text = {"texts": {'id': id_,}, "probabilities": None, "status_code": 200}
+        return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200}
         print("ok")
     else:
         return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": 401}

predict_drop_weight_sim.py (5 changes)

@@ -765,9 +765,8 @@ def paragraph_test(text, text_new):

 if __name__ == '__main__':
-    text = ["在malpezzi研究的日本租赁市场中[23] , 可以看出日本的租赁住房市场主要以规范化经营为主, 强调轻资产经营, 更加重视经营风险的规避"]
-    print(type(just_show_sentence(text)))
+    text = ["所以对学生对应用仪器分析解决实际问题的能力要求很高。", "随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]
+    print(just_show_sentence(text))

     # is_novel = False
     # path = "./data/700条论文测试.xlsx"

predict_sim.py (72 changes)

@@ -1,8 +1,11 @@
 #! -*- coding: utf-8 -*-
 import os
+from config.predict_sim_config import DropSimBertConfig
+
+config = DropSimBertConfig()
 # os.environ["TF_KERAS"] = "1"
-os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id
 import glob
 import random
 from tqdm import tqdm
@@ -21,9 +24,9 @@ import tensorflow as tf
 from keras.backend import set_session

-config = tf.ConfigProto()
-config.gpu_options.allow_growth = True
-set_session(tf.Session(config=config))  # 此处不同
+tfconfig = tf.ConfigProto()
+tfconfig.gpu_options.allow_growth = True
+set_session(tf.Session(config=tfconfig))  # 此处不同
 global graph
 graph = tf.get_default_graph()
 sess = tf.Session(graph=graph)
@@ -78,14 +81,12 @@ class TotalLoss(Loss):

 class GenerateModel(object):
-    def __init__(self):
-
-        self.epoch_acc_vel = 0
-        self.config_path = r'./chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_config.json'
-        self.checkpoint_path = r'./chinese_roberta_wwm_ext_L-12_H-768_A-12/bert_model.ckpt'
-        self.dict_path = r'./chinese_roberta_wwm_ext_L-12_H-768_A-12/vocab_drop.txt'
-        self.maxlen = 120
-        self.novel_maxlen = 60
+    def __init__(self, config_path, checkpoint_path, dict_path, maxlen, savemodel_path):
+        self.config_path = config_path
+        self.checkpoint_path = checkpoint_path
+        self.dict_path = dict_path
+        self.maxlen = maxlen
+        self.savemodel_path = savemodel_path

     def device_setup(self):
         token_dict, keep_tokens = load_vocab(
@@ -122,7 +123,7 @@ class GenerateModel(object):
         outputs = TotalLoss([2, 3])(bert.model.inputs + bert.model.outputs)
         model = keras.models.Model(bert.model.inputs, outputs)
-        path_model = './output_simbert_yy/best_simbertmodel_datasim_yinhao.weights'
+        path_model = self.savemodel_path
         model.load_weights(path_model)

         return encoder,seq2seq, tokenizer
@@ -626,12 +627,6 @@ class AutoTitle(AutoRegressiveDecoder):
         else:
             return sentence_list[0]

-generatemodel = GenerateModel()
-encoder,seq2seq, tokenizer = generatemodel.device_setup()
-autotitle = AutoTitle(seq2seq, tokenizer, start_id=None, end_id=tokenizer._token_end_id, maxlen=120)

 def just_show(file):
     data = []
@@ -708,27 +703,34 @@ def just_show_csv_beam(file):

 if __name__ == '__main__':
-    # text = ["强调轻资产“经营”, 更加重视“营风险”的规避", "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择", "是时代对于青年们的深切呼唤"]
-    # print(just_show_sentence(text))
+    generatemodel = GenerateModel(config.config_path,
+                                  config.checkpoint_path,
+                                  config.dict_path,
+                                  config.maxlen,
+                                  config.savemodel_path)
+    encoder, seq2seq, tokenizer = generatemodel.device_setup()
+    autotitle = AutoTitle(seq2seq, tokenizer, start_id=None, end_id=tokenizer._token_end_id, maxlen=120)
+    text = ["随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]
+    print(just_show_sentence(text))
     #
     # print(just_show_sentence_batch(text))
     # print(type(just_show_sentence_batch(text)))
-    path = "./data/700条论文测试.xlsx"
-    df_list = pd.read_excel(path).values.tolist()
-
-    df_list_new = []
-    print(len(df_list))
-    for i in tqdm(df_list):
-        try:
-            pre = just_show_sentence([i[0]])
-            df_list_new.append([i[0], i[1]] + [pre])
-        except:
-            print(i[0])
-            continue
-    df = pd.DataFrame(df_list_new)
-    df.to_excel("./data/700条论文测试_19.xlsx", index=None)
+    # path = "./data/700条论文测试.xlsx"
+    # df_list = pd.read_excel(path).values.tolist()
+    #
+    # df_list_new = []
+    # print(len(df_list))
+    # for i in tqdm(df_list):
+    #     try:
+    #         pre = just_show_sentence([i[0]])
+    #         df_list_new.append([i[0], i[1]] + [pre])
+    #     except:
+    #         print(i[0])
+    #         continue
+    # df = pd.DataFrame(df_list_new)
+    # df.to_excel("./data/700条论文测试_19.xlsx", index=None)

predict_t5.py (95 changes)

@@ -9,9 +9,6 @@
 """
 #! -*- coding: utf-8 -*-

-import os
-# os.environ["TF_KERAS"] = "1"
-os.environ["CUDA_VISIBLE_DEVICES"] = "1"
 import glob
 from numpy import random
 random.seed(1001)
@@ -19,7 +16,6 @@ from tqdm import tqdm
 import numpy as np
 import pandas as pd
 import json
-import numpy as np
 from tqdm import tqdm
 from bert4keras.backend import keras, K
 from bert4keras.layers import Loss
@@ -32,12 +28,10 @@ from keras.models import Model
 # from rouge import Rouge  # pip install rouge
 # from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 import tensorflow as tf
 from keras.backend import set_session

-config = tf.ConfigProto()
-config.gpu_options.allow_growth = True
-set_session(tf.Session(config=config))  # 此处不同
+tfconfig = tf.ConfigProto()
+tfconfig.gpu_options.allow_growth = True
+set_session(tf.Session(config=tfconfig))  # 此处不同
 global graph
 graph = tf.get_default_graph()
 sess = tf.Session(graph=graph)
@@ -52,14 +46,13 @@ set_session(sess)
 # 基本参数

 class GenerateModel(object):
-    def __init__(self):
-
-        self.epoch_acc_vel = 0
-        self.config_path = 'mt5/mt5_base/mt5_base_config.json'
-        self.checkpoint_path = 'mt5/mt5_base/model.ckpt-1000000'
-        self.spm_path = 'mt5/mt5_base/sentencepiece_cn.model'
-        self.keep_tokens_path = 'mt5/mt5_base/sentencepiece_cn_keep_tokens.json'
-        self.maxlen = 256
+    def __init__(self, config_path, checkpoint_path, spm_path, keep_tokens_path, maxlen, savemodel_path):
+        self.config_path = config_path
+        self.checkpoint_path = checkpoint_path
+        self.spm_path = spm_path
+        self.keep_tokens_path = keep_tokens_path
+        self.maxlen = maxlen
+        self.savemodel_path = savemodel_path

     def device_setup(self):
         tokenizer = SpTokenizer(self.spm_path, token_start=None, token_end='</s>')
@@ -85,7 +78,7 @@ class GenerateModel(object):
         output = CrossEntropy(1)([model.inputs[1], model.outputs[0]])
         model = Model(model.inputs, output)
-        path_model = "output_t5/best_model_t5.weights"
+        path_model = self.savemodel_path
         model.load_weights(path_model)

         return encoder, decoder, model, tokenizer
@@ -131,7 +124,7 @@ class Beamdataone(object):
         self.inputs_vector = 0

     def text_2_textids(self,text):
-        token_ids, segment_ids = self.tokenizer.encode(text[0], maxlen=120)
+        token_ids, segment_ids = self.tokenizer.encode(text[0], maxlen=self.maxlen)
         self.text_ids = [token_ids]

     def add_data(self, step, output_scores):
@@ -217,6 +210,11 @@ class AutoTitle(AutoRegressiveDecoder):
         self.end_id = end_id
         self.minlen = minlen
         self.models = {}
+        self.chinese_sign = {
+            ",": "，",
+            ":": "：",
+            ";": "；",
+        }
         if start_id is None:
             self.first_output_ids = np.empty((1, 0), dtype=int)
         else:
@@ -246,7 +244,7 @@ class AutoTitle(AutoRegressiveDecoder):
         c_encoded = inputs[0]
         with graph.as_default():
             K.set_session(sess)
-            nodes = self.last_token(decoder).predict([c_encoded, output_ids])
+            nodes = self.last_token(self.decoder).predict([c_encoded, output_ids])
         return nodes

     def predict_batch(self, inputs):
@@ -259,14 +257,6 @@ class AutoTitle(AutoRegressiveDecoder):
             nodes = self.decoder.predict([token_ids, output_ids])
         return nodes

-    def data_generator(self, token_ids, output_ids):
-        batch_token_ids = []
-        for i,j in zip(token_ids, output_ids):
-            batch_token_ids = sequence_padding(token_ids)
-            batch_segment_ids = sequence_padding(output_ids)
-        return batch_token_ids, batch_segment_ids

     def beam_search_batch(
         self,
@@ -395,15 +385,17 @@ class AutoTitle(AutoRegressiveDecoder):

     def generate(self, text, topk=5):
-        c_token_ids, _ = tokenizer.encode(text, maxlen=120)
+        c_token_ids, _ = self.tokenizer.encode(text, maxlen=self.maxlen)
         with graph.as_default():
             K.set_session(sess)
-            c_encoded = encoder.predict(np.array([c_token_ids]))[0]
+            c_encoded = self.encoder.predict(np.array([c_token_ids]))[0]
         output_ids = self.beam_search([c_encoded], topk=topk)  # 基于beam search
-        return tokenizer.decode([int(i) for i in output_ids])
+        return_text = self.tokenizer.decode([int(i) for i in output_ids])
+        return_text = "".join([self.chinese_sign[i] if i in self.chinese_sign else i for i in return_text])
+        return return_text

     def generate_random(self, text, n=30, topp=0.9):
-        c_token_ids, _ = self.tokenizer.encode(text, maxlen=120)
+        c_token_ids, _ = self.tokenizer.encode(text, maxlen=self.maxlen)
         with graph.as_default():
             K.set_session(sess)
             c_encoded = self.encoder.predict(np.array([c_token_ids]))[0]
@@ -418,13 +410,6 @@ class AutoTitle(AutoRegressiveDecoder):
         return output_str

-generatemodel = GenerateModel()
-encoder, decoder, model, tokenizer = generatemodel.device_setup()
-autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120)

 def just_show_sentence(file):
     """
     @param file:list
@@ -441,6 +426,18 @@ def just_show_sentence_batch(file: list) -> object:

 if __name__ == '__main__':
+    import os
+    from config.predict_t5_config import DropT5Config
+
+    config = DropT5Config()
+    os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id
+    generatemodel = GenerateModel(config.config_path,
+                                  config.checkpoint_path,
+                                  config.spm_path,
+                                  config.keep_tokens_path,
+                                  config.maxlen,
+                                  config.savemodel_path)
+    encoder, decoder, model, tokenizer = generatemodel.device_setup()
+    autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=256)
     # file = "train_2842.txt"
     # just_show(file)
     # text = ["历史和当下都证明,创新是民族生存、发展的不竭源泉,是自身发展的必然选择,是时代对于青年们的深切呼唤"]
@@ -493,16 +490,20 @@ if __name__ == '__main__':
     #++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
-    text = ["'李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会",
-            "历史和当下都证明,创新是民族生存、“发展的不竭源泉”,是是自身发展的必然选择",
-            "自身发展的必然选择",
-            "强调轻资产经营, 更加重视经营风险的规避",
-            "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择",
-            "是时代对于青年们的深切呼唤"]
-    # text = ["基本消除“热桥”影响。"]
+    # text = ["'李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会",
+    #         "历史和当下都证明,创新是民族生存、“发展的不竭源泉”,是是自身发展的必然选择",
+    #         "自身发展的必然选择",
+    #         "强调轻资产经营, 更加重视经营风险的规避",
+    #         "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择",
+    #         "是时代对于青年们的深切呼唤"]
+    text = ["随着经济的发展,人们生活水平的提高,环境:问题也日益突出。",
+            "环境问题中的化学污染是影响我国居民生活质量不可忽视的重要因素,而仪器分析作为化工专业课程中必不可少的一门课程也不例外。",
+            "所以对学生对应用仪器分析解决实际问题的能力要求很高。",
+            "随着经济的发展,人们生活水平的提高,环境问题也日益突出。"]

     print(just_show_sentence(text))
     # print(just_show_sentence_top(text))
     # print(just_show_chachong_random(text))

     # print(tokenizer.encode("\"", maxlen=120))
     # print(just_show_sentence_batch(text))
+    # myout.flask_predict_no_batch_t5.logs
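One detail worth noting in the new generate(): decoded output is passed through chinese_sign, mapping half-width punctuation back to its full-width form. The full-width values were illegible in the rendered diff, so the mapping shown above (and reproduced standalone below) is an assumption that it restores Chinese punctuation:

# Standalone sketch of the punctuation post-processing added to
# AutoTitle.generate(). The full-width values ("，", "：", "；") are assumed;
# they did not survive this page render.
chinese_sign = {
    ",": "，",
    ":": "：",
    ";": "；",
}

def to_fullwidth(text: str) -> str:
    """Replace half-width punctuation with its Chinese full-width form."""
    return "".join(chinese_sign.get(ch, ch) for ch in text)

print(to_fullwidth("环境:问题,也日益突出;"))  # -> 环境：问题，也日益突出；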