
29 changed files with 2014 additions and 100 deletions
@ -0,0 +1,9 @@
def has_numbers(input_string):
    return any(char.isdigit() for char in input_string)


# Example usage
input_str = "Hello, 123!"
if has_numbers(input_str):
    print("字符串中包含数字")
else:
    print("字符串中不包含数字")
@ -0,0 +1,19 @@
def is_contains_(text):
    # Check whether brackets and paired quotes in the text are correctly matched.
    stack = []
    pairs = {"]": "[", "}": "{", ")": "(", "”": "“", "’": "‘", "》": "《"}
    for char in text:
        if char in pairs.values():
            stack.append(char)
        elif char in pairs.keys():
            if stack == [] or pairs[char] != stack.pop():
                return False
        else:
            continue
    if stack == []:
        return True
    else:
        return False


a = "d(a)a"

print(is_contains_(a))
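
# Illustrative sanity checks (not part of the original file): balanced brackets and
# paired quotes should return True, mismatched ones False; the curly-quote case
# relies on the "”" -> "“" mapping above.
assert is_contains_("《标题》(备注)") is True
assert is_contains_("“好。”") is True
assert is_contains_("d(a]a") is False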
@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

"""
@Time    : 2023/3/27 16:51
@Author  :
@FileName:
@Software:
@Describe:
"""
@ -0,0 +1,299 @@
# -*- coding: utf-8 -*-

"""
@Time    : 2023/1/31 19:02
@Author  :
@FileName:
@Software:
@Describe:
"""
import os
# os.environ["TF_KERAS"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import json
import numpy as np
from bert4keras.backend import keras, set_gelu
from bert4keras.tokenizers import Tokenizer, load_vocab
from bert4keras.models import build_transformer_model
from bert4keras.optimizers import Adam, extend_with_piecewise_linear_lr
from bert4keras.snippets import sequence_padding, DataGenerator
from bert4keras.snippets import open
from keras.layers import Lambda, Dense
import tensorflow as tf
from keras.backend import set_session
from sklearn.metrics.pairwise import cosine_similarity
from rouge import Rouge  # pip install rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from tqdm import tqdm
import jieba
from gensim.models import KeyedVectors, word2vec, Word2Vec
import random
import difflib
import re

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
set_session(tf.Session(config=config))  # TF1-style session setup; this part differs from the usual template


class Word2vecModel:
    def __init__(self):
        self.path = r"E:\pycharm_workspace\查重分析\word2vec_model\word2vec_add_new_18.model"
        self.model = Word2Vec.load(self.path)

    def word2vec_res(self, seg_0_list, seg_1_list):
        sentence_0_list = []
        sentence_1_list = []
        for i in seg_0_list:
            a = self.model.wv[i]
            sentence_0_list.append(a)

        for i in seg_1_list:
            a = self.model.wv[i]
            sentence_1_list.append(a)

        return sentence_0_list, sentence_1_list


class Evaluator(keras.callbacks.Callback):
    """Evaluate and save the model."""

    def __init__(self):
        self.rouge = Rouge()
        self.smooth = SmoothingFunction().method1
        self.best_bleu = 0.

    # def on_epoch_end(self, epoch, logs=None):
    #     metrics = self.evaluate(valid_data)  # evaluate the model
    #     if metrics['bleu'] > self.best_bleu:
    #         self.best_bleu = metrics['bleu']
    #         model.save_weights('./best_model.weights')  # save the weights
    #     metrics['best_bleu'] = self.best_bleu
    #     print('valid_data:', metrics)

    def evaluate_t(self, data_1, data_2, topk=1):
        total = 0
        rouge_1, rouge_2, rouge_l, bleu = 0, 0, 0, 0

        scores = self.rouge.get_scores(hyps=[data_1], refs=[data_2])
        rouge_1 += scores[0]['rouge-1']['f']
        rouge_2 += scores[0]['rouge-2']['f']
        rouge_l += scores[0]['rouge-l']['f']
        bleu += sentence_bleu(
            references=[data_1.split(' ')],
            hypothesis=data_2.split(' '),
            smoothing_function=self.smooth
        )
        # rouge_1 /= total
        # rouge_2 /= total
        # rouge_l /= total
        # bleu /= total
        return [rouge_1, rouge_2, rouge_l, bleu]


class bertModel:
    def __init__(self):
        # modelpath = "E:\pycharm_workspace\premodel\keras\chinese_simbert_L-12_H-768_A-12"
        # modelpath = "E:\pycharm_workspace\premodel\keras\chinese_roberta_wwm_ext_L-12_H-768_A-12"
        # modelpath = "E:\pycharm_workspace\premodel\keras\chinese_L-12_H-768_A-12"
        modelpath = "/home/majiahui/project/models-llm/keras/chinese_L-12_H-768_A-12"
        self.config_path = modelpath + r'/bert_config.json'
        self.checkpoint_path = modelpath + r'/bert_model.ckpt'
        self.dict_path = modelpath + r'/vocab.txt'
        self.token_dict, self.keep_tokens = load_vocab(
            dict_path=self.dict_path,
            simplified=True,
            startswith=['[PAD]', '[UNK]', '[CLS]', '[SEP]'],
        )
        self.tokenizer = Tokenizer(self.token_dict, do_lower_case=True)
        self.buildmodel()

    def buildmodel(self):
        bert = build_transformer_model(
            config_path=self.config_path,
            checkpoint_path=self.checkpoint_path,
            return_keras_model=False,
        )

        output = Lambda(lambda x: x[:, 0], name='CLS-token')(bert.model.output)
        self.model = keras.models.Model(bert.model.input, output)
        self.model.summary()

    def predict(self, text):
        batch_token_ids, batch_segment_ids = [], []
        token_ids, segment_ids = self.tokenizer.encode(text, maxlen=256)
        batch_token_ids.append(token_ids)
        batch_segment_ids.append(segment_ids)
        return self.model.predict([batch_token_ids, batch_segment_ids])

    def predict_batch(self, text_list):
        batch_token_ids, batch_segment_ids = [], []

        for t in text_list:
            token_ids, segment_ids = self.tokenizer.encode(t, maxlen=256)
            batch_token_ids.append(token_ids)
            batch_segment_ids.append(segment_ids)

        batch_token_ids = sequence_padding(batch_token_ids)
        batch_segment_ids = sequence_padding(batch_segment_ids)
        return self.model.predict([batch_token_ids, batch_segment_ids])


def simbert(data_1, data_2):
    pass


def word2vec():
    pass


def bleu():
    pass


def bool_len_strsim(data_1, data_2):
    # Character-level similarity check; returns (True, score) when the pair should be
    # dropped because the adjusted similarity falls below 0.65.
    str_sim_value = difflib.SequenceMatcher(None, data_1, data_2).quick_ratio()
    if len(data_2) - len(data_1) < 0:
        if len(data_2) / len(data_1) > 0.8:
            num_yu = 1 - len(data_2) / len(data_1)
            str_sim_value = 1 - str_sim_value * num_yu
        else:
            return False, ""

    if str_sim_value < 0.65:
        return True, str_sim_value
    else:
        return False, ""


def has_numbers(input_string):
    return any(char.isdigit() for char in input_string)


def bool_num(data_1, data_2):
    # True only when both sentences contain digits.
    bool_1 = has_numbers(data_1)
    bool_2 = has_numbers(data_2)
    if bool_1 == True and bool_2 == True:
        return True
    else:
        return False


def is_contains_english(text):
    my_re = re.compile(r'[A-Za-z]', re.S)
    res = re.findall(my_re, text)
    if len(res):
        return True
    else:
        return False


def is_contains_kongge(text):
    if " " in text or "\t" in text:
        return True
    else:
        return False
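
# Illustrative behaviour of the filter helpers above (invented inputs, not from the
# training data):
#   bool_num("共12人", "总共十二人")  -> False  (the rewritten side has no Arabic digit)
#   bool_num("共12人", "合计12人")    -> True   (both sides contain digits)
#   is_contains_kongge("有 空格")     -> True   (the source contains whitespace, so the pair is dropped)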


if __name__ == '__main__':
    file = "../data/train_yy_pre.txt"
    # file = "../data/train_yy_zong_sim_99.txt"
    model = bertModel()
    eval_class = Evaluator()
    data_new = []

    data_1_list = []
    data_2_list = []

    # word2vecmodel = Word2vecModel()
    try:
        with open(file, 'r', encoding="utf-8") as f:
            lines = [x.strip() for x in f if x.strip() != '']
    except:
        with open(file, 'r', encoding="gbk") as f:
            lines = [x.strip() for x in f if x.strip() != '']

    bertsim_list = []
    bleusim_list = []
    word2vecsim_list = []
    data_train_text = []

    # random.shuffle(lines)
    print(len(lines))
    for txt in tqdm(lines):

        text = txt.split('\t')
        if len(text) == 3:
            data_1 = text[0]
            data_2 = text[2]

            # keep only pairs where both sides contain digits
            bool_num_ = bool_num(data_1, data_2)
            if bool_num_ == False:
                continue

            # optionally drop pairs whose source contains English letters
            # data_english_bool = is_contains_english(data_1)
            # if data_english_bool == True:
            #     continue

            # drop pairs whose source contains whitespace
            data_kongge_bool = is_contains_kongge(data_1)
            if data_kongge_bool == True:
                continue

            # drop pairs that fail the character-similarity criterion
            bool_len_strsim_v, strsim = bool_len_strsim(data_1, data_2)
            if bool_len_strsim_v == True:
                continue

            # # First method: encode each sentence separately
            # y1 = model.predict(data_1)[0]
            # y2 = model.predict(data_2)[0]
            # cos_sim = cosine_similarity(y1.reshape(1, -1), y2.reshape(1, -1))
            # # bertsim_list.append((cos_sim[0][0], strsim, data_1, data_2))
            # if cos_sim[0][0] > 0.9:
            #     cos_sim_bool = True
            # else:
            #     cos_sim_bool = False
            #
            # if cos_sim_bool == False:
            #     continue
            #
            # data_new.append("\t".join([data_1, "to", data_2]))

            # data_train_text.append("\t".join([data_1, "to", data_2]))

            # Second method: encode both sentences in one batch
            y = model.predict_batch([data_1, data_2])
            y1 = y[0]
            y2 = y[1]
            cos_sim = cosine_similarity(y1.reshape(1, -1), y2.reshape(1, -1))
            # bertsim_list.append((cos_sim[0][0], strsim, data_1, data_2))
            if cos_sim[0][0] > 0.9:
                cos_sim_bool = True
            else:
                cos_sim_bool = False

            if cos_sim_bool == False:
                continue

            data_new.append("\t".join([data_1, "to", data_2]))

    # bertsim_list.sort(reverse=True)
    # with open("../data/tongji_len_strsim_nertsim_1.txt", "w", encoding="utf-8") as f:
    #     for i in bertsim_list:
    #         f.write(str(i[0]))
    #         f.write(str("\t"))
    #         f.write(str(i[1]))
    #         f.write(str("\t"))
    #         f.write(str(i[2]))
    #         f.write(str("\t"))
    #         f.write(str(i[3]))
    #         f.write("\n")
    # print(len(data_train_text))
    fileName = '../data/train_new/train_yy_1.txt'
    # fileName = '../data/train_new/train_yy.txt'
    with open(fileName, 'w', encoding='utf-8') as f:
        for i in data_new:
            f.write(str(i) + '\n')
    f.close()

@ -0,0 +1,293 @@ |
|||
import os |
|||
from config.predict_t5_config import MultipleResultsDropT5Config |
|||
t5config = MultipleResultsDropT5Config() |
|||
from config.predict_sim_config import DropSimBertConfig |
|||
simbertconfig = DropSimBertConfig() |
|||
os.environ["CUDA_VISIBLE_DEVICES"] = t5config.cuda_id |
|||
from flask import Flask, jsonify |
|||
from flask import request |
|||
from predict_t5 import (GenerateModel as T5GenerateModel, |
|||
AutoTitle as T5AutoTitle) |
|||
from predict_sim import (GenerateModel as SimBertGenerateModel, |
|||
AutoTitle as SimBertT5AutoTitle) |
|||
import json |
|||
from threading import Thread |
|||
import time |
|||
import re |
|||
import requests |
|||
|
|||
|
|||
app = Flask(__name__) |
|||
app.config["JSON_AS_ASCII"] = False |
|||
|
|||
import logging |
|||
pattern = r"[。]" |
|||
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”") |
|||
fuhao_end_sentence = ["。",",","?","!","…"] |
|||
|
|||
t5generatemodel = T5GenerateModel(t5config.config_path, |
|||
t5config.checkpoint_path, |
|||
t5config.spm_path, |
|||
t5config.keep_tokens_path, |
|||
t5config.maxlen, |
|||
t5config.savemodel_path) |
|||
|
|||
encoder, decoder, model, tokenizer = t5generatemodel.device_setup() |
|||
t5autotitle = T5AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120) |
|||
|
|||
simbertgeneratemodel = SimBertGenerateModel(simbertconfig.config_path, |
|||
simbertconfig.checkpoint_path, |
|||
simbertconfig.dict_path, |
|||
simbertconfig.maxlen, |
|||
simbertconfig.savemodel_path) |
|||
encoder, seq2seq, tokenizer = simbertgeneratemodel.device_setup() |
|||
simbertautotitle = SimBertT5AutoTitle(seq2seq, tokenizer, start_id=None, end_id=tokenizer._token_end_id, maxlen=120) |
|||
|
|||
|
|||
def requests_chatGPT(data):
    res = requests.post('http://98.142.138.229:9999/chatgpt', data=data)
    return res.json()['res']


def get_dialogs_index(line: str):
    """
    Extract quoted dialogue spans and their indices.
    :param line: the input text
    :return: dialogs_text: the quoted dialogue fragments
             dialogs_index: character indices covered by dialogue
             other_index: character indices of the remaining text
    """
    dialogs = re.finditer(RE_DIALOG, line)
    dialogs_text = re.findall(RE_DIALOG, line)
    dialogs_index = []
    for dialog in dialogs:
        all_ = [i for i in range(dialog.start(), dialog.end())]
        dialogs_index.extend(all_)
    other_index = [i for i in range(len(line)) if i not in dialogs_index]

    return dialogs_text, dialogs_index, other_index
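
# Illustrative example (the sample sentence is invented): for
# get_dialogs_index('他说:“今天不去了。”然后离开。') the quoted fragment
# ['“今天不去了。”'] is returned as dialogs_text, the character positions covered by
# that fragment as dialogs_index, and the remaining positions as other_index.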


def chulichangju_1(text, chulipangban_return_list, short_num):
    fuhao = [",", "?", "!", "…"]
    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    text_1 = text[:120]
    text_2 = text[120:]
    text_1_new = ""
    if text_2 == "":
        chulipangban_return_list.append([text_1, short_num])
        return chulipangban_return_list
    for i in range(len(text_1) - 1, -1, -1):
        if text_1[i] in fuhao:
            if i in dialogs_index:
                continue
            text_1_new = text_1[:i]
            text_1_new += text_1[i]
            chulipangban_return_list.append([text_1_new, short_num])
            if text_2 != "":
                if i + 1 != 120:
                    text_2 = text_1[i + 1:] + text_2
            break
        # else:
        #     chulipangban_return_list.append(text_1)
    if text_1_new == "":
        chulipangban_return_list.append([text_1, short_num])
    if text_2 != "":
        short_num += 1
        chulipangban_return_list = chulichangju_1(text_2, chulipangban_return_list, short_num)
    return chulipangban_return_list


def chulipangban_test_1(text):
    # Quote handling: protect "。" inside quoted dialogue so it does not split a sentence

    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    for dialogs_text_dan in dialogs_text:
        text_dan_list = text.split(dialogs_text_dan)
        if "。" in dialogs_text_dan:
            dialogs_text_dan = str(dialogs_text_dan).replace("。", "&")
        text = dialogs_text_dan.join(text_dan_list)

    # text_new_str = "".join(text_new)

    sentence_list = text.split("。")
    # sentence_list_new = []
    # for i in sentence_list:
    #     if i != "":
    #         sentence_list_new.append(i)
    # sentence_list = sentence_list_new
    sentence_batch_list = []
    sentence_batch_one = []
    sentence_batch_length = 0
    return_list = []
    for sentence in sentence_list:
        if len(sentence) < 120:
            sentence_batch_length += len(sentence)
            sentence_batch_list.append([sentence, 0])
            # sentence_pre = autotitle.gen_synonyms_short(sentence)
            # return_list.append(sentence_pre)
        else:
            sentence_split_list = chulichangju_1(sentence, [], 0)
            for sentence_short in sentence_split_list:
                sentence_batch_list.append(sentence_short)
    return sentence_batch_list
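
# Illustrative walk-through (invented input): for '他说:“好。走吧。”然后离开了。天黑了。'
# the "。" inside the quoted dialogue is first replaced by "&", so splitting on "。"
# yields ['他说:“好&走吧&”然后离开了', '天黑了', ''], each shorter than 120 characters
# and therefore emitted as [sentence, 0]; sentences of 120+ characters are instead
# broken up by chulichangju_1. The "&" is restored to "。" later in one_predict.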


def paragraph_test(texts: str):

    text_list = chulipangban_test_1(texts)

    # text_new_str = "".join(text_new)
    return text_list


def batch_data_process(text_list):
    sentence_batch_length = 0
    sentence_batch_one = []
    sentence_batch_list = []

    for sentence in text_list:
        sentence_batch_length += len(sentence[0])
        sentence_batch_one.append(sentence)
        if sentence_batch_length > 500:
            sentence_batch_length = 0
            sentence_ = sentence_batch_one.pop(-1)
            sentence_batch_list.append(sentence_batch_one)
            sentence_batch_one = []
            sentence_batch_one.append(sentence_)
    sentence_batch_list.append(sentence_batch_one)
    return sentence_batch_list


def batch_predict(batch_data_list):
    '''
    Predict one batch of data
    @param data_text:
    @return:
    '''
    batch_data_list_new = []
    batch_data_text_list = []
    batch_data_snetence_id_list = []
    for i in batch_data_list:
        batch_data_text_list.append(i[0])
        batch_data_snetence_id_list.append(i[1:])
    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
    batch_pre_data_list = batch_data_text_list
    for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list):
        batch_data_list_new.append([text] + sentence_id)

    return batch_data_list_new


def one_predict(data_text):
    '''
    Predict a single piece of data: query the external ChatGPT service with each
    prompt, then append the T5 and SimBERT rewrites.
    @param data_text:
    @return:
    '''
    return_data_list = []
    if data_text[0] != "":
        data_inputs = data_text[0].replace("&", "。")
        prompt_list = ["请帮我改写一下这个句子", "请帮美化一下下面句子", "请帮我修改下面句子让这句话更完美"]
        pre_data_list = []
        for i in prompt_list:
            pre_data = requests_chatGPT(
                data={
                    'prompt': i,
                    'text': data_inputs
                }
            )
            pre_data_list.append(pre_data)
        modelclass_list = [t5autotitle, simbertautotitle]
        for model in modelclass_list:
            pre_data_list.append(model.generate(data_inputs))
    else:
        pre_data_list = [""] * 5
    for pre_data in pre_data_list:
        return_data_list.append([pre_data] + data_text[1:])

    return return_data_list


def predict_data_post_processing(text_list, index):
    text_list_sentence = []
    # text_list_sentence.append([text_list[0][0], text_list[0][1]])

    for i in range(len(text_list)):
        if text_list[i][index][2] != 0:
            text_list_sentence[-1][0] += text_list[i][index][0]
        else:
            text_list_sentence.append([text_list[i][0], text_list[i][1]])

    return_list = {}
    sentence_one = []
    sentence_id = text_list_sentence[0][1]
    for i in text_list_sentence:
        if i[1] == sentence_id:
            sentence_one.append(i[0])
        else:
            return_list[sentence_id] = "。".join(sentence_one)
            sentence_id = i[1]
            sentence_one = []
            sentence_one.append(i[0])
    if sentence_one != []:
        return_list[sentence_id] = "。".join(sentence_one)
    return return_list


# def main(text:list):
#     # text_list = paragraph_test(text)
#     # batch_data = batch_data_process(text_list)
#     # text_list = []
#     # for i in batch_data:
#     #     text_list.extend(i)
#     # return_list = predict_data_post_processing(text_list)
#     # return return_list


def main(text: str):
    text_list = paragraph_test(text)
    text_list_new = []
    return_list = []
    for i in text_list:
        pre_list = one_predict(i)
        text_list_new.append(pre_list)

    for index in range(len(text_list_new[0])):
        return_list.append(predict_data_post_processing(text_list_new, index))
    return return_list


@app.route('/multiple_results_droprepeat/', methods=['POST'])
def sentence():
    print(request.remote_addr)
    texts = request.json["texts"]
    print("原始语句" + str(texts))
    # question = question.strip('。、!??')

    if isinstance(texts, str):
        texts_list = []
        y_pred_label_list = []
        position_list = []

        # texts = texts.replace('\'', '\"')
        if texts is None:
            return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False}
            return jsonify(return_text)
        else:
            texts_list = main(texts)
            return_text = {"texts": texts_list, "probabilities": None, "status_code": True}
    else:
        return_text = {"texts": "输入格式应该为str", "probabilities": None, "status_code": False}
    return jsonify(return_text)


if __name__ == "__main__":
    fh = logging.FileHandler(mode='a', encoding='utf-8', filename='chitchat.log')
    logging.basicConfig(
        handlers=[fh],
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
    )
    app.run(host="0.0.0.0", port=14000, threaded=True, debug=False)
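
# A minimal client sketch for the service above (illustrative, not part of the
# original diff; kept commented out so it does not run with the service). It assumes
# the service is reachable on localhost:14000; the "texts" payload key and the
# response fields come from the route handler.
#
# import requests
#
# resp = requests.post(
#     "http://127.0.0.1:14000/multiple_results_droprepeat/",
#     json={"texts": "历史和当下都证明,创新是民族生存、发展的不竭源泉。"},
# )
# result = resp.json()
# if result["status_code"]:
#     # one entry per candidate rewrite (three ChatGPT prompts, T5, SimBERT)
#     for candidate in result["texts"]:
#         print(candidate)
# else:
#     print("request rejected:", result["texts"])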
@ -1,15 +0,0 @@ |
|||
{ |
|||
"texts": { |
|||
"0": "李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,李正旺你真是傻逼讪笑,挥手道,", |
|||
"1": "李正旺你真是傻逼讪笑,挥手道", |
|||
"2": "李正旺你真是傻逼讪笑,挥手道", |
|||
"3": "李正旺你真是傻逼讪笑,挥手道", |
|||
"4": "李正旺你真是傻逼讪笑,挥手道", |
|||
"5": "李正旺你真是傻逼讪笑,挥手道", |
|||
"6": "李正旺你真是傻逼讪笑,挥手道", |
|||
"7": "李正旺你真是傻逼讪笑,挥手道", |
|||
"8": "李正旺你真是傻逼讪笑,挥手李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手道李正旺你真是傻逼讪笑,挥手" |
|||
}, |
|||
"probabilities": null, |
|||
"status_code": 200 |
|||
} |
@ -0,0 +1,343 @@ |
|||
import os |
|||
from config.predict_t5_config import DropT5Config |
|||
|
|||
config = DropT5Config() |
|||
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id |
|||
from flask import Flask, jsonify |
|||
from flask import request |
|||
# from linshi import autotitle |
|||
import requests |
|||
from predict_t5 import GenerateModel, AutoTitle |
|||
import redis |
|||
import uuid |
|||
import json |
|||
from threading import Thread |
|||
import time |
|||
import re |
|||
import logging |
|||
|
|||
logging.basicConfig(level=logging.DEBUG,  # log level written to the file
                    filename='rewrite.log',
                    filemode='a',  # 'w' rewrites the log file on every run, 'a' (the default) appends
                    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'  # log format
                    )
|||
|
|||
pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=100, db=1) |
|||
redis_ = redis.Redis(connection_pool=pool, decode_responses=True) |
|||
|
|||
db_key_query = 'query' |
|||
db_key_querying = 'querying' |
|||
db_key_queryset = 'queryset' |
|||
batch_size = 32 |
|||
|
|||
app = Flask(__name__) |
|||
app.config["JSON_AS_ASCII"] = False |
|||
|
|||
import logging |
|||
|
|||
pattern = r"[。]" |
|||
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”") |
|||
fuhao_end_sentence = ["。", ",", "?", "!", "…"] |
|||
|
|||
generatemodel = GenerateModel(config.config_path, |
|||
config.checkpoint_path, |
|||
config.spm_path, |
|||
config.keep_tokens_path, |
|||
config.maxlen, |
|||
config.savemodel_path) |
|||
encoder, decoder, model, tokenizer = generatemodel.device_setup() |
|||
autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120) |
|||
|
|||
|
|||
class log: |
|||
def __init__(self): |
|||
pass |
|||
|
|||
def log(*args, **kwargs): |
|||
format = '%Y/%m/%d-%H:%M:%S' |
|||
format_h = '%Y-%m-%d' |
|||
value = time.localtime(int(time.time())) |
|||
dt = time.strftime(format, value) |
|||
dt_log_file = time.strftime(format_h, value) |
|||
log_file = 'log_file/access-%s' % dt_log_file + ".log" |
|||
if not os.path.exists(log_file): |
|||
with open(os.path.join(log_file), 'w', encoding='utf-8') as f: |
|||
print(dt, *args, file=f, **kwargs) |
|||
else: |
|||
with open(os.path.join(log_file), 'a+', encoding='utf-8') as f: |
|||
print(dt, *args, file=f, **kwargs) |
|||
|
|||
|
|||
def get_dialogs_index(line: str): |
|||
""" |
|||
获取对话及其索引 |
|||
:param line 文本 |
|||
:return dialogs 对话内容 |
|||
dialogs_index: 对话位置索引 |
|||
other_index: 其他内容位置索引 |
|||
""" |
|||
dialogs = re.finditer(RE_DIALOG, line) |
|||
dialogs_text = re.findall(RE_DIALOG, line) |
|||
dialogs_index = [] |
|||
for dialog in dialogs: |
|||
all_ = [i for i in range(dialog.start(), dialog.end())] |
|||
dialogs_index.extend(all_) |
|||
other_index = [i for i in range(len(line)) if i not in dialogs_index] |
|||
|
|||
return dialogs_text, dialogs_index, other_index |
|||
|
|||
|
|||
def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num): |
|||
fuhao = [",", "?", "!", "…"] |
|||
dialogs_text, dialogs_index, other_index = get_dialogs_index(text) |
|||
text_1 = text[:120] |
|||
text_2 = text[120:] |
|||
text_1_new = "" |
|||
if text_2 == "": |
|||
chulipangban_return_list.append([text_1, snetence_id, short_num]) |
|||
return chulipangban_return_list |
|||
for i in range(len(text_1) - 1, -1, -1): |
|||
if text_1[i] in fuhao: |
|||
if i in dialogs_index: |
|||
continue |
|||
text_1_new = text_1[:i] |
|||
text_1_new += text_1[i] |
|||
chulipangban_return_list.append([text_1_new, snetence_id, short_num]) |
|||
if text_2 != "": |
|||
if i + 1 != 120: |
|||
text_2 = text_1[i + 1:] + text_2 |
|||
break |
|||
# else: |
|||
# chulipangban_return_list.append(text_1) |
|||
if text_1_new == "": |
|||
chulipangban_return_list.append([text_1, snetence_id, short_num]) |
|||
if text_2 != "": |
|||
short_num += 1 |
|||
chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num) |
|||
return chulipangban_return_list |
|||
|
|||
|
|||
def chulipangban_test_1(snetence_id, text): |
|||
# 引号处理 |
|||
|
|||
dialogs_text, dialogs_index, other_index = get_dialogs_index(text) |
|||
for dialogs_text_dan in dialogs_text: |
|||
text_dan_list = text.split(dialogs_text_dan) |
|||
if "。" in dialogs_text_dan: |
|||
dialogs_text_dan = str(dialogs_text_dan).replace("。", "&") |
|||
text = dialogs_text_dan.join(text_dan_list) |
|||
|
|||
# text_new_str = "".join(text_new) |
|||
|
|||
sentence_list = text.split("。") |
|||
# sentence_list_new = [] |
|||
# for i in sentence_list: |
|||
# if i != "": |
|||
# sentence_list_new.append(i) |
|||
# sentence_list = sentence_list_new |
|||
sentence_batch_list = [] |
|||
sentence_batch_one = [] |
|||
sentence_batch_length = 0 |
|||
return_list = [] |
|||
for sentence in sentence_list: |
|||
if len(sentence) < 120: |
|||
sentence_batch_length += len(sentence) |
|||
sentence_batch_list.append([sentence, snetence_id, 0]) |
|||
# sentence_pre = autotitle.gen_synonyms_short(sentence) |
|||
# return_list.append(sentence_pre) |
|||
else: |
|||
|
|||
sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0) |
|||
for sentence_short in sentence_split_list: |
|||
sentence_batch_list.append(sentence_short) |
|||
return sentence_batch_list |
|||
|
|||
|
|||
def paragraph_test(texts: dict): |
|||
text_new = [] |
|||
for i, text in texts.items(): |
|||
text_list = chulipangban_test_1(i, text) |
|||
text_new.extend(text_list) |
|||
|
|||
# text_new_str = "".join(text_new) |
|||
return text_new |
|||
|
|||
|
|||
def batch_data_process(text_list): |
|||
sentence_batch_length = 0 |
|||
sentence_batch_one = [] |
|||
sentence_batch_list = [] |
|||
|
|||
for sentence in text_list: |
|||
sentence_batch_length += len(sentence[0]) |
|||
sentence_batch_one.append(sentence) |
|||
if sentence_batch_length > 500: |
|||
sentence_batch_length = 0 |
|||
sentence_ = sentence_batch_one.pop(-1) |
|||
sentence_batch_list.append(sentence_batch_one) |
|||
sentence_batch_one = [] |
|||
sentence_batch_one.append(sentence_) |
|||
sentence_batch_list.append(sentence_batch_one) |
|||
return sentence_batch_list |
|||
|
|||
|
|||
def batch_predict(batch_data_list): |
|||
''' |
|||
一个bacth数据预测 |
|||
@param data_text: |
|||
@return: |
|||
''' |
|||
batch_data_list_new = [] |
|||
batch_data_text_list = [] |
|||
batch_data_snetence_id_list = [] |
|||
for i in batch_data_list: |
|||
batch_data_text_list.append(i[0]) |
|||
batch_data_snetence_id_list.append(i[1:]) |
|||
# batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list) |
|||
batch_pre_data_list = batch_data_text_list |
|||
for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list): |
|||
batch_data_list_new.append([text] + sentence_id) |
|||
|
|||
return batch_data_list_new |
|||
|
|||
|
|||
def one_predict(data_text): |
|||
''' |
|||
一个条数据预测 |
|||
@param data_text: |
|||
@return: |
|||
''' |
|||
if data_text[0] != "": |
|||
data_inputs = data_text[0].replace("&", "。") |
|||
pre_data = autotitle.generate(data_inputs) |
|||
else: |
|||
pre_data = "" |
|||
data_new = [pre_data] + data_text[1:] |
|||
return data_new |
|||
|
|||
|
|||
def predict_data_post_processing(text_list): |
|||
text_list_sentence = [] |
|||
# text_list_sentence.append([text_list[0][0], text_list[0][1]]) |
|||
|
|||
for i in range(len(text_list)): |
|||
if text_list[i][2] != 0: |
|||
text_list_sentence[-1][0] += text_list[i][0] |
|||
else: |
|||
text_list_sentence.append([text_list[i][0], text_list[i][1]]) |
|||
|
|||
return_list = {} |
|||
sentence_one = [] |
|||
sentence_id = text_list_sentence[0][1] |
|||
for i in text_list_sentence: |
|||
if i[1] == sentence_id: |
|||
sentence_one.append(i[0]) |
|||
else: |
|||
return_list[sentence_id] = "。".join(sentence_one) |
|||
sentence_id = i[1] |
|||
sentence_one = [] |
|||
sentence_one.append(i[0]) |
|||
if sentence_one != []: |
|||
return_list[sentence_id] = "。".join(sentence_one) |
|||
return return_list |
|||
|
|||
|
|||
# def main(text:list): |
|||
# # text_list = paragraph_test(text) |
|||
# # batch_data = batch_data_process(text_list) |
|||
# # text_list = [] |
|||
# # for i in batch_data: |
|||
# # text_list.extend(i) |
|||
# # return_list = predict_data_post_processing(text_list) |
|||
# # return return_list |
|||
|
|||
def main(text: dict): |
|||
text_list = paragraph_test(text) |
|||
text_list_new = [] |
|||
for i in text_list: |
|||
pre = one_predict(i) |
|||
text_list_new.append(pre) |
|||
return_list = predict_data_post_processing(text_list_new) |
|||
return return_list |
|||
|
|||
|
|||
@app.route('/droprepeat/', methods=['POST']) |
|||
def sentence(): |
|||
print(request.remote_addr) |
|||
texts = request.json["texts"] |
|||
text_type = request.json["text_type"] |
|||
print("原始语句" + str(texts)) |
|||
# question = question.strip('。、!??') |
|||
|
|||
if isinstance(texts, dict): |
|||
texts_list = [] |
|||
y_pred_label_list = [] |
|||
position_list = [] |
|||
|
|||
# texts = texts.replace('\'', '\"') |
|||
if texts is None: |
|||
return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False} |
|||
return jsonify(return_text) |
|||
else: |
|||
assert text_type in ['focus', 'chapter'] |
|||
if text_type == 'focus': |
|||
texts_list = main(texts) |
|||
if text_type == 'chapter': |
|||
texts_list = main(texts) |
|||
return_text = {"texts": texts_list, "probabilities": None, "status_code": True} |
|||
else: |
|||
return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False} |
|||
return jsonify(return_text) |
|||
|
|||
|
|||
def classify():  # call the model; batch_size above is the intended cap on batch size
    while True:
        if redis_.llen(db_key_query) == 0:  # nothing queued yet, keep polling
            time.sleep(3)
            continue
        query = redis_.lpop(db_key_query).decode('UTF-8')  # take one query off the queue
        data_dict_path = json.loads(query)
        path = data_dict_path['path']
        # text_type = data_dict["text_type"]

        with open(path, encoding='utf8') as f1:
            # load the request payload written by the producer
            data_dict = json.load(f1)

        query_id = data_dict['id']
        texts = data_dict["text"]
        text_type = data_dict["text_type"]

        assert text_type in ['focus', 'chapter']
        if text_type == 'focus':
            texts_list = main(texts)
        elif text_type == 'chapter':
            texts_list = main(texts)
        else:
            texts_list = []

        return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
        load_result_path = "./new_data_logs/{}.json".format(query_id)

        print("query_id: ", query_id)
        print("load_result_path: ", load_result_path)

        with open(load_result_path, 'w', encoding='utf8') as f2:
            # ensure_ascii=False keeps Chinese readable instead of \u escapes,
            # indent=4 pretty-prints the JSON
            json.dump(return_text, f2, ensure_ascii=False, indent=4)
        debug_id_1 = 1
        redis_.set(query_id, load_result_path, 86400)
        debug_id_2 = 2
        redis_.srem(db_key_querying, query_id)
        debug_id_3 = 3
        log.log('start at',
                'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format(
                    query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3))
|||
|
|||
if __name__ == '__main__': |
|||
classify() |
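
# Illustrative producer for the classify() worker above (not part of the original
# diff). It writes the request payload to a JSON file, pushes the file path onto the
# 'query' list in the same redis db, and then polls redis for the result path that
# classify() stores under the query id. The file location and the uuid-based id are
# assumptions made for this sketch; it is kept commented out so it does not run
# alongside the worker.
#
# import json
# import time
# import uuid
# import redis
#
# pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=10, db=1)
# r = redis.Redis(connection_pool=pool)
#
# query_id = str(uuid.uuid1())
# payload = {
#     "id": query_id,
#     "text": {"0": "历史和当下都证明,创新是民族生存、发展的不竭源泉。"},
#     "text_type": "focus",
# }
# request_path = "./new_data_logs/{}_request.json".format(query_id)
# with open(request_path, "w", encoding="utf-8") as f:
#     json.dump(payload, f, ensure_ascii=False)
#
# r.sadd('querying', query_id)              # classify() later removes the id from this set
# r.rpush('query', json.dumps({"path": request_path}))
#
# while not r.exists(query_id):             # classify() sets <id> -> result file path
#     time.sleep(1)
# result_path = r.get(query_id).decode('utf-8')
# with open(result_path, encoding="utf-8") as f:
#     print(json.load(f)["texts"])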
|||
|
@ -0,0 +1,343 @@ |
|||
import os |
|||
from config.predict_t5_config import DropT5Config |
|||
|
|||
config = DropT5Config() |
|||
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id |
|||
from flask import Flask, jsonify |
|||
from flask import request |
|||
# from linshi import autotitle |
|||
import requests |
|||
from predict_t5 import GenerateModel, AutoTitle |
|||
import redis |
|||
import uuid |
|||
import json |
|||
from threading import Thread |
|||
import time |
|||
import re |
|||
import logging |
|||
|
|||
logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 |
|||
filename='rewrite.log', |
|||
filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 |
|||
# a是追加模式,默认如果不写的话,就是追加模式 |
|||
format= |
|||
'%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' |
|||
# 日志格式 |
|||
) |
|||
|
|||
pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=100, db=1) |
|||
redis_ = redis.Redis(connection_pool=pool, decode_responses=True) |
|||
|
|||
db_key_query = 'query' |
|||
db_key_querying = 'querying' |
|||
db_key_queryset = 'queryset' |
|||
batch_size = 32 |
|||
|
|||
app = Flask(__name__) |
|||
app.config["JSON_AS_ASCII"] = False |
|||
|
|||
import logging |
|||
|
|||
pattern = r"[。]" |
|||
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”") |
|||
fuhao_end_sentence = ["。", ",", "?", "!", "…"] |
|||
|
|||
generatemodel = GenerateModel(config.config_path, |
|||
config.checkpoint_path, |
|||
config.spm_path, |
|||
config.keep_tokens_path, |
|||
config.maxlen, |
|||
config.savemodel_path) |
|||
encoder, decoder, model, tokenizer = generatemodel.device_setup() |
|||
autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120) |
|||
|
|||
|
|||
class log: |
|||
def __init__(self): |
|||
pass |
|||
|
|||
def log(*args, **kwargs): |
|||
format = '%Y/%m/%d-%H:%M:%S' |
|||
format_h = '%Y-%m-%d' |
|||
value = time.localtime(int(time.time())) |
|||
dt = time.strftime(format, value) |
|||
dt_log_file = time.strftime(format_h, value) |
|||
log_file = 'log_file/access-%s' % dt_log_file + ".log" |
|||
if not os.path.exists(log_file): |
|||
with open(os.path.join(log_file), 'w', encoding='utf-8') as f: |
|||
print(dt, *args, file=f, **kwargs) |
|||
else: |
|||
with open(os.path.join(log_file), 'a+', encoding='utf-8') as f: |
|||
print(dt, *args, file=f, **kwargs) |
|||
|
|||
|
|||
def get_dialogs_index(line: str): |
|||
""" |
|||
获取对话及其索引 |
|||
:param line 文本 |
|||
:return dialogs 对话内容 |
|||
dialogs_index: 对话位置索引 |
|||
other_index: 其他内容位置索引 |
|||
""" |
|||
dialogs = re.finditer(RE_DIALOG, line) |
|||
dialogs_text = re.findall(RE_DIALOG, line) |
|||
dialogs_index = [] |
|||
for dialog in dialogs: |
|||
all_ = [i for i in range(dialog.start(), dialog.end())] |
|||
dialogs_index.extend(all_) |
|||
other_index = [i for i in range(len(line)) if i not in dialogs_index] |
|||
|
|||
return dialogs_text, dialogs_index, other_index |
|||
|
|||
|
|||
def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num): |
|||
fuhao = [",", "?", "!", "…"] |
|||
dialogs_text, dialogs_index, other_index = get_dialogs_index(text) |
|||
text_1 = text[:120] |
|||
text_2 = text[120:] |
|||
text_1_new = "" |
|||
if text_2 == "": |
|||
chulipangban_return_list.append([text_1, snetence_id, short_num]) |
|||
return chulipangban_return_list |
|||
for i in range(len(text_1) - 1, -1, -1): |
|||
if text_1[i] in fuhao: |
|||
if i in dialogs_index: |
|||
continue |
|||
text_1_new = text_1[:i] |
|||
text_1_new += text_1[i] |
|||
chulipangban_return_list.append([text_1_new, snetence_id, short_num]) |
|||
if text_2 != "": |
|||
if i + 1 != 120: |
|||
text_2 = text_1[i + 1:] + text_2 |
|||
break |
|||
# else: |
|||
# chulipangban_return_list.append(text_1) |
|||
if text_1_new == "": |
|||
chulipangban_return_list.append([text_1, snetence_id, short_num]) |
|||
if text_2 != "": |
|||
short_num += 1 |
|||
chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num) |
|||
return chulipangban_return_list |
|||
|
|||
|
|||
def chulipangban_test_1(snetence_id, text): |
|||
# 引号处理 |
|||
|
|||
dialogs_text, dialogs_index, other_index = get_dialogs_index(text) |
|||
for dialogs_text_dan in dialogs_text: |
|||
text_dan_list = text.split(dialogs_text_dan) |
|||
if "。" in dialogs_text_dan: |
|||
dialogs_text_dan = str(dialogs_text_dan).replace("。", "&") |
|||
text = dialogs_text_dan.join(text_dan_list) |
|||
|
|||
# text_new_str = "".join(text_new) |
|||
|
|||
sentence_list = text.split("。") |
|||
# sentence_list_new = [] |
|||
# for i in sentence_list: |
|||
# if i != "": |
|||
# sentence_list_new.append(i) |
|||
# sentence_list = sentence_list_new |
|||
sentence_batch_list = [] |
|||
sentence_batch_one = [] |
|||
sentence_batch_length = 0 |
|||
return_list = [] |
|||
for sentence in sentence_list: |
|||
if len(sentence) < 120: |
|||
sentence_batch_length += len(sentence) |
|||
sentence_batch_list.append([sentence, snetence_id, 0]) |
|||
# sentence_pre = autotitle.gen_synonyms_short(sentence) |
|||
# return_list.append(sentence_pre) |
|||
else: |
|||
|
|||
sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0) |
|||
for sentence_short in sentence_split_list: |
|||
sentence_batch_list.append(sentence_short) |
|||
return sentence_batch_list |
|||
|
|||
|
|||
def paragraph_test(texts: dict): |
|||
text_new = [] |
|||
for i, text in texts.items(): |
|||
text_list = chulipangban_test_1(i, text) |
|||
text_new.extend(text_list) |
|||
|
|||
# text_new_str = "".join(text_new) |
|||
return text_new |
|||
|
|||
|
|||
def batch_data_process(text_list): |
|||
sentence_batch_length = 0 |
|||
sentence_batch_one = [] |
|||
sentence_batch_list = [] |
|||
|
|||
for sentence in text_list: |
|||
sentence_batch_length += len(sentence[0]) |
|||
sentence_batch_one.append(sentence) |
|||
if sentence_batch_length > 500: |
|||
sentence_batch_length = 0 |
|||
sentence_ = sentence_batch_one.pop(-1) |
|||
sentence_batch_list.append(sentence_batch_one) |
|||
sentence_batch_one = [] |
|||
sentence_batch_one.append(sentence_) |
|||
sentence_batch_list.append(sentence_batch_one) |
|||
return sentence_batch_list |
|||
|
|||
|
|||
def batch_predict(batch_data_list): |
|||
''' |
|||
一个bacth数据预测 |
|||
@param data_text: |
|||
@return: |
|||
''' |
|||
batch_data_list_new = [] |
|||
batch_data_text_list = [] |
|||
batch_data_snetence_id_list = [] |
|||
for i in batch_data_list: |
|||
batch_data_text_list.append(i[0]) |
|||
batch_data_snetence_id_list.append(i[1:]) |
|||
# batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list) |
|||
batch_pre_data_list = batch_data_text_list |
|||
for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list): |
|||
batch_data_list_new.append([text] + sentence_id) |
|||
|
|||
return batch_data_list_new |
|||
|
|||
|
|||
def one_predict(data_text): |
|||
''' |
|||
一个条数据预测 |
|||
@param data_text: |
|||
@return: |
|||
''' |
|||
if data_text[0] != "": |
|||
data_inputs = data_text[0].replace("&", "。") |
|||
pre_data = autotitle.generate(data_inputs) |
|||
else: |
|||
pre_data = "" |
|||
data_new = [pre_data] + data_text[1:] |
|||
return data_new |
|||
|
|||
|
|||
def predict_data_post_processing(text_list): |
|||
text_list_sentence = [] |
|||
# text_list_sentence.append([text_list[0][0], text_list[0][1]]) |
|||
|
|||
for i in range(len(text_list)): |
|||
if text_list[i][2] != 0: |
|||
text_list_sentence[-1][0] += text_list[i][0] |
|||
else: |
|||
text_list_sentence.append([text_list[i][0], text_list[i][1]]) |
|||
|
|||
return_list = {} |
|||
sentence_one = [] |
|||
sentence_id = text_list_sentence[0][1] |
|||
for i in text_list_sentence: |
|||
if i[1] == sentence_id: |
|||
sentence_one.append(i[0]) |
|||
else: |
|||
return_list[sentence_id] = "。".join(sentence_one) |
|||
sentence_id = i[1] |
|||
sentence_one = [] |
|||
sentence_one.append(i[0]) |
|||
if sentence_one != []: |
|||
return_list[sentence_id] = "。".join(sentence_one) |
|||
return return_list |
|||
|
|||
|
|||
# def main(text:list): |
|||
# # text_list = paragraph_test(text) |
|||
# # batch_data = batch_data_process(text_list) |
|||
# # text_list = [] |
|||
# # for i in batch_data: |
|||
# # text_list.extend(i) |
|||
# # return_list = predict_data_post_processing(text_list) |
|||
# # return return_list |
|||
|
|||
def main(text: dict): |
|||
text_list = paragraph_test(text) |
|||
text_list_new = [] |
|||
for i in text_list: |
|||
pre = one_predict(i) |
|||
text_list_new.append(pre) |
|||
return_list = predict_data_post_processing(text_list_new) |
|||
return return_list |
|||
|
|||
|
|||
@app.route('/droprepeat/', methods=['POST']) |
|||
def sentence(): |
|||
print(request.remote_addr) |
|||
texts = request.json["texts"] |
|||
text_type = request.json["text_type"] |
|||
print("原始语句" + str(texts)) |
|||
# question = question.strip('。、!??') |
|||
|
|||
if isinstance(texts, dict): |
|||
texts_list = [] |
|||
y_pred_label_list = [] |
|||
position_list = [] |
|||
|
|||
# texts = texts.replace('\'', '\"') |
|||
if texts is None: |
|||
return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False} |
|||
return jsonify(return_text) |
|||
else: |
|||
assert text_type in ['focus', 'chapter'] |
|||
if text_type == 'focus': |
|||
texts_list = main(texts) |
|||
if text_type == 'chapter': |
|||
texts_list = main(texts) |
|||
return_text = {"texts": texts_list, "probabilities": None, "status_code": True} |
|||
else: |
|||
return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False} |
|||
return jsonify(return_text) |
|||
|
|||
|
|||
def classify(): # 调用模型,设置最大batch_size |
|||
while True: |
|||
if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 |
|||
time.sleep(3) |
|||
continue |
|||
query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text |
|||
data_dict_path = json.loads(query) |
|||
path = data_dict_path['path'] |
|||
# text_type = data_dict["text_type"] |
|||
|
|||
with open(path, encoding='utf8') as f1: |
|||
# 加载文件的对象 |
|||
data_dict = json.load(f1) |
|||
|
|||
query_id = data_dict['id'] |
|||
texts = data_dict["text"] |
|||
text_type = data_dict["text_type"] |
|||
|
|||
assert text_type in ['focus', 'chapter'] |
|||
if text_type == 'focus': |
|||
texts_list = main(texts) |
|||
elif text_type == 'chapter': |
|||
texts_list = main(texts) |
|||
else: |
|||
texts_list = [] |
|||
|
|||
return_text = {"texts": texts_list, "probabilities": None, "status_code": 200} |
|||
load_result_path = "./new_data_logs/{}.json".format(query_id) |
|||
|
|||
print("query_id: ", query_id) |
|||
print("load_result_path: ", load_result_path) |
|||
|
|||
with open(load_result_path, 'w', encoding='utf8') as f2: |
|||
# ensure_ascii=False才能输入中文,否则是Unicode字符 |
|||
# indent=2 JSON数据的缩进,美观 |
|||
json.dump(return_text, f2, ensure_ascii=False, indent=4) |
|||
debug_id_1 = 1 |
|||
redis_.set(query_id, load_result_path, 86400) |
|||
debug_id_2 = 2 |
|||
redis_.srem(db_key_querying, query_id) |
|||
debug_id_3 = 3 |
|||
log.log('start at', |
|||
'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format( |
|||
query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3)) |
|||
|
|||
if __name__ == '__main__': |
|||
classify() |
|||
|
@ -0,0 +1,511 @@ |
|||
# -*- coding: utf-8 -*- |
|||
|
|||
""" |
|||
@Time : 2023/1/16 14:59 |
|||
@Author : |
|||
@FileName: |
|||
@Software: |
|||
@Describe: |
|||
""" |
|||
#! -*- coding: utf-8 -*- |
|||
|
|||
import os |
|||
from config.predict_t5_config import MultipleResultsDropT5Config |
|||
config = MultipleResultsDropT5Config() |
|||
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id |
|||
import glob |
|||
from numpy import random |
|||
random.seed(1001) |
|||
from tqdm import tqdm |
|||
import numpy as np |
|||
import pandas as pd |
|||
import json |
|||
from tqdm import tqdm |
|||
from bert4keras.backend import keras, K |
|||
from bert4keras.layers import Loss |
|||
from bert4keras.models import build_transformer_model |
|||
from bert4keras.tokenizers import SpTokenizer |
|||
from bert4keras.optimizers import Adam |
|||
from bert4keras.snippets import sequence_padding, open |
|||
from bert4keras.snippets import DataGenerator, AutoRegressiveDecoder |
|||
from keras.models import Model |
|||
# from rouge import Rouge # pip install rouge |
|||
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction |
|||
import tensorflow as tf |
|||
from keras.backend import set_session |
|||
tfconfig = tf.ConfigProto() |
|||
tfconfig.gpu_options.allow_growth = True |
|||
set_session(tf.Session(config=tfconfig)) # 此处不同 |
|||
global graph |
|||
graph = tf.get_default_graph() |
|||
sess = tf.Session(graph=graph) |
|||
set_session(sess) |
|||
|
|||
# global graph,model |
|||
# graph = tf.get_default_graph() |
|||
# sess = tf.Session(graph=graph) |
|||
# K.set_session(sess) |
|||
|
|||
|
|||
# 基本参数 |
|||
|
|||
class GenerateModel(object): |
|||
def __init__(self): |
|||
|
|||
self.config_path = config.config_path |
|||
self.checkpoint_path = config.checkpoint_path |
|||
self.spm_path = config.spm_path |
|||
self.keep_tokens_path = config.keep_tokens_path |
|||
self.maxlen = config.maxlen |
|||
|
|||
def device_setup(self): |
|||
tokenizer = SpTokenizer(self.spm_path, token_start=None, token_end='</s>') |
|||
keep_tokens = json.load(open(self.keep_tokens_path)) |
|||
|
|||
t5 = build_transformer_model( |
|||
config_path=self.config_path, |
|||
checkpoint_path=self.checkpoint_path, |
|||
keep_tokens=keep_tokens, |
|||
model='mt5.1.1', |
|||
return_keras_model=False, |
|||
name='T5', |
|||
) |
|||
|
|||
# output = CrossEntropy(2)(model.inputs + model.outputs) |
|||
# |
|||
# model = Model(model.inputs, output) |
|||
encoder = t5.encoder |
|||
decoder = t5.decoder |
|||
model = t5.model |
|||
model.summary() |
|||
|
|||
output = CrossEntropy(1)([model.inputs[1], model.outputs[0]]) |
|||
|
|||
model = Model(model.inputs, output) |
|||
path_model = config.savemodel_path |
|||
model.load_weights(path_model) |
|||
|
|||
return encoder, decoder, model, tokenizer |
|||
|
|||
|
|||
class CrossEntropy(Loss):
    """Cross-entropy loss with the non-target (input) positions masked out."""

    def compute_loss(self, inputs, mask=None):
        y_true, y_pred = inputs
        y_true = y_true[:, 1:]  # target token_ids, shifted left by one
        y_mask = K.cast(mask[1], K.floatx())[:, 1:]  # mask provided by the decoder
        y_pred = y_pred[:, :-1]  # predictions, offset by one position
        loss = K.sparse_categorical_crossentropy(y_true, y_pred)
        loss = K.sum(loss * y_mask) / K.sum(y_mask)
        return loss
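
# Shape sketch (illustrative): for a decoder input of token ids [s, a, b, e] the loss
# compares y_true = [a, b, e] with the predictions made at positions [s, a, b], i.e.
# each step is trained to predict the next token, while y_mask (the decoder mask,
# shifted the same way) zeroes out padding positions so they do not contribute.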
|||
|
|||
|
|||
class Beamdataone(object): |
|||
def __init__(self, num_beams, batch_id, text, end_id, minlen, min_ends, tokenizer, output_ids): |
|||
""" |
|||
Initialize n-best list of hypotheses. |
|||
""" |
|||
self.num_beams = num_beams |
|||
self.batch_id = batch_id |
|||
self.beams = [] |
|||
self.minlen = minlen |
|||
self.min_ends = min_ends |
|||
self.end_id = end_id |
|||
self.text = text |
|||
self.output_scores = np.zeros(1) |
|||
self.output_ids = [output_ids] |
|||
self.return_str = "" |
|||
self.over = False |
|||
self.tokenizer = tokenizer |
|||
# self.data() |
|||
self.output_str = "" |
|||
self.text_2_textids( |
|||
self.text |
|||
) |
|||
self.scores = np.zeros(1) |
|||
self.inputs_vector = 0 |
|||
|
|||
def text_2_textids(self,text): |
|||
token_ids, segment_ids = self.tokenizer.encode(text[0], maxlen=120) |
|||
self.text_ids = [token_ids] |
|||
|
|||
def add_data(self, step, output_scores): |
|||
''' |
|||
还存有的数据,直接可以被迭代, |
|||
@param text: |
|||
@return: |
|||
''' |
|||
# inputs = [np.array([i]) for i in inputs] |
|||
# output_ids, output_scores = self.first_output_ids, np.zeros(1) |
|||
# |
|||
# scores, states = self.predict( |
|||
# inputs, output_ids, states, temperature, 'logits' |
|||
# ) # 计算当前得分 |
|||
# if step == 0: # 第1步预测后将输入重复topk次 |
|||
# inputs = [np.repeat(i, self.num_beams, axis=0) for i in self.inputs] |
|||
# inputs = [self.token_ids, self.segment_ids] |
|||
# inputs = [np.array([i]) for i in inputs] |
|||
self.output_ids = np.array(self.output_ids) |
|||
if step == 0: # 第1步预测后将输入重复topk次 |
|||
self.text_ids = [np.repeat(i, self.num_beams, axis=0) for i in self.text_ids] |
|||
scores = output_scores.reshape((-1, 1)) + self.scores # 综合累积得分 |
|||
# scores = output_probas |
|||
scores = self.output_scores.reshape((-1, 1)) + scores # 综合累积得分 |
|||
indices = scores.argpartition(-self.num_beams, axis=None)[-self.num_beams:] # 仅保留topk |
|||
indices_1 = indices // scores.shape[1] # 行索引 |
|||
indices_2 = (indices % scores.shape[1]).reshape((-1, 1)) # 列索引 |
|||
self.output_ids = np.concatenate([self.output_ids[indices_1], indices_2], |
|||
1) # 更新输出 |
|||
self.output_scores = np.take_along_axis( |
|||
scores, indices, axis=None |
|||
) # 更新得分 |
|||
|
|||
is_end = self.output_ids[:, -1] == self.end_id # 标记是否以end标记结束 |
|||
self.end_counts = (self.output_ids == self.end_id).sum(1) # 统计出现的end标记 |
|||
if self.output_ids.shape[1] >= self.minlen: # 最短长度判断 |
|||
best = self.output_scores.argmax() # 得分最大的那个 |
|||
if is_end[best] and self.end_counts[best] >= self.min_ends: # 如果已经终止 |
|||
# return output_ids[best] # 直接输出 |
|||
self.return_str_main(self.output_ids, best) |
|||
self.over = True |
|||
else: # 否则,只保留未完成部分 |
|||
flag = ~is_end | (self.end_counts < self.min_ends) # 标记未完成序列 |
|||
if not flag.all(): # 如果有已完成的 |
|||
self.output_ids = self.output_ids[flag] # 扔掉已完成序列 |
|||
self.output_scores = self.output_scores[flag] # 扔掉已完成序列 |
|||
self.end_counts = self.end_counts[flag] # 扔掉已完成end计数 |
|||
self.num_beams = flag.sum() # topk相应变化 |
|||
self.output_ids = self.output_ids.tolist() |
|||
self.output_str = [tokenizer.decode(ids) for ids in self.output_ids] |
|||
self.text_ids = [self.text_ids[0] for i in range(len(self.output_ids))] |
|||
|
|||
|
|||
# # 达到长度直接输出 |
|||
# return output_ids[output_scores.argmax()] |
|||
|
|||
|
|||
# def data(self): |
|||
# token_ids, segment_ids = self.tokenizer.encode(self.text, maxlen=256) |
|||
# self.token_ids = token_ids |
|||
# self.segment_ids = segment_ids |
|||
|
|||
|
|||
# input_str = [text for i in range(self.num_beams)] |
|||
# output_str = self.output_str |
|||
# return input_str, output_str |
|||
|
|||
def return_str_main(self, output_ids, best): |
|||
output_ids_best = output_ids[best] |
|||
self.return_str = self.tokenizer.decode(output_ids_best) |
|||
|
|||
|
|||
class AutoTitle(AutoRegressiveDecoder): |
|||
"""seq2seq解码器 |
|||
""" |
|||
def __init__(self, encoder, decoder, model, tokenizer, start_id, end_id, maxlen, minlen=1): |
|||
super(AutoTitle, self).__init__(start_id, end_id, maxlen, minlen) |
|||
self.encoder = encoder |
|||
self.decoder = decoder |
|||
self.model = model |
|||
self.tokenizer = tokenizer |
|||
self.start_id = start_id |
|||
self.end_id = end_id |
|||
self.minlen = minlen |
|||
self.models = {} |
|||
if start_id is None: |
|||
self.first_output_ids = np.empty((1, 0), dtype=int) |
|||
else: |
|||
self.first_output_ids = np.array([[self.start_id]]) |
|||
|
|||
# @AutoRegressiveDecoder.wraps(default_rtype='probas') |
|||
# def predict(self, inputs, output_ids, states): |
|||
# token_ids, segment_ids = inputs |
|||
# token_ids = np.concatenate([token_ids, output_ids], 1) |
|||
# segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1) |
|||
# with graph.as_default(): |
|||
# K.set_session(sess) |
|||
# nodes = self.last_token(self.model).predict([token_ids, segment_ids]) |
|||
# return nodes |
|||
# # return self.last_token(self.model).predict([token_ids, segment_ids]) |
|||
|
|||
# @AutoRegressiveDecoder.wraps(default_rtype='probas') |
|||
# def predict(self, inputs, output_ids, states): |
|||
# c_encoded = inputs[0] |
|||
# with graph.as_default(): |
|||
# K.set_session(sess) |
|||
# nodes = self.last_token(self.decoder).predict([c_encoded, output_ids]) |
|||
# return nodes |
|||
|
|||
@AutoRegressiveDecoder.wraps(default_rtype='probas') |
|||
def predict(self, inputs, output_ids, states): |
|||
c_encoded = inputs[0] |
|||
with graph.as_default(): |
|||
K.set_session(sess) |
|||
nodes = self.last_token(decoder).predict([c_encoded, output_ids]) |
|||
return nodes |
|||
|
|||
def predict_batch(self, inputs): |
|||
# inputs, output_ids, states, temperature, 'probas' |
|||
token_ids, output_ids = inputs |
|||
# token_ids = np.concatenate([token_ids, output_ids], 1) |
|||
# segment_ids = np.concatenate([segment_ids, np.ones_like(output_ids)], 1) |
|||
with graph.as_default(): |
|||
K.set_session(sess) |
|||
nodes = self.decoder.predict([token_ids, output_ids]) |
|||
return nodes |
|||
|
|||
def data_generator(self, token_ids, output_ids): |
|||
|
|||
batch_token_ids = [] |
|||
for i,j in zip(token_ids, output_ids): |
|||
|
|||
batch_token_ids = sequence_padding(token_ids) |
|||
batch_segment_ids = sequence_padding(output_ids) |
|||
return batch_token_ids, batch_segment_ids |
|||
|
|||
def beam_search_batch( |
|||
self, |
|||
inputs_str, |
|||
states=None, |
|||
temperature=1, |
|||
min_ends=1, |
|||
num_beam=3 |
|||
): |
|||
"""随机采样n个结果 |
|||
说明:非None的topk表示每一步只从概率最高的topk个中采样;而非None的topp |
|||
表示每一步只从概率最高的且概率之和刚好达到topp的若干个token中采样。 |
|||
返回:n个解码序列组成的list。 |
|||
""" |
|||
output_str = [] |
|||
# token_ids, segment_ids = self.data_generator(inputs, output_ids) |
|||
batch_nums = len(inputs_str) |
|||
return_str_batch = [0] * batch_nums |
|||
# output_ids = np.empty((batch_nums, 0), dtype=int) |
|||
output_ids = np.array([self.start_id]) |
|||
generated = [Beamdataone(num_beam, i, [inputs_str[i]], self.end_id, self.minlen, min_ends, self.tokenizer, output_ids) for i in range(batch_nums)] |
|||
# index_data = [i for i in range(batch_nums)] |
|||
|
|||
c_token_ids = [] |
|||
for i in generated: |
|||
text_ids = i.text_ids |
|||
c_token_ids.extend(text_ids) |
|||
c_token_ids = sequence_padding(c_token_ids) |
|||
c_encoded = encoder.predict(np.array(c_token_ids)) |
|||
|
|||
# probas_bool = np.array(token_ids, dtype=bool) |
|||
# # np.array(np.where(probas_bool == True)) |
|||
# for i, sentence in enumerate(probas_bool): |
|||
# lie = np.array(np.where(sentence == True))[0] |
|||
# probas_new.append(probas[i, lie[-1]]) |
|||
|
|||
for i in range(len(generated)): |
|||
probas_bool = np.array(generated[i].text_ids[0], dtype=bool) |
|||
lie = np.array(np.where(probas_bool == True))[0] |
|||
# c_encoded_dan = c_encoded[i, lie[-1]] |
|||
c_encoded_dan = c_encoded[np.ix_([i], lie)] |
|||
generated[i].inputs_vector = c_encoded_dan[0] |
|||
|
|||
|
|||
for step in range(self.maxlen): |
|||
# if step == 0: |
|||
# token_ids, segment_ids = self.data_generator(inputs_str, output_str) |
|||
# else: |
|||
# inputs_str, output_str = [], [] |
|||
inputs_vector_batch, output_ids_batch = [], [] |
|||
batch_input_num_beam_num = [] |
|||
for i in generated: |
|||
inputs_vector = i.inputs_vector |
|||
# if step != 0: |
|||
# output_ids_batch.extend(i.output_ids) |
|||
# text_ids_batch.extend(text_ids) |
|||
# else: |
|||
inputs_vector_batch.append(inputs_vector) |
|||
output_ids_batch.extend(i.output_ids) |
|||
if step != 0: |
|||
batch_input_num_beam_num.append(i.num_beams) |
|||
|
|||
# token_ids, output_ids_batch = self.data_generator(inputs_vector_batch, output_ids_batch) |
|||
|
|||
# token_ids_batch = sequence_padding(token_ids_batch) |
|||
# segment_ids_batch = sequence_padding(segment_ids_batch) |
|||
# output_ids_batch = np.array(output_ids_batch) |
|||
# if step == 0: |
|||
|
|||
inputs = [inputs_vector_batch, output_ids_batch] |
|||
|
|||
probas = self.predict_batch( |
|||
inputs |
|||
) # 计算当前概率 |
|||
|
|||
probas_new = [] |
|||
probas_bool = np.array(inputs_vector_batch, dtype=bool) |
|||
# np.array(np.where(probas_bool == True)) |
|||
for i, sentence in enumerate(probas_bool): |
|||
lie = np.array(np.where(sentence == True))[0] |
|||
probas_new.append(probas[i, lie[-1]]) |
|||
probas = np.array(probas_new) |
|||
|
|||
|
|||
if step != 0: |
|||
num = 0 |
|||
if len(generated) > 1: |
|||
index = 0 |
|||
for index in range(len(batch_input_num_beam_num)-1): |
|||
cc = num |
|||
num += batch_input_num_beam_num[index] |
|||
generated[index].add_data(step, probas[cc:num,:]) |
|||
generated[index+1].add_data(step, probas[num:,:]) |
|||
else: |
|||
generated[0].add_data(step, probas[:,:]) |
|||
|
|||
else: |
|||
for index in range(len(generated)): |
|||
generated[index].add_data(step, probas[index,:]) |
|||
# i = 0 |
|||
# while True: |
|||
# bool_ = generated[i].over |
|||
# if bool_ == True: |
|||
# one_sentence = generated.pop(i) |
|||
# return_str_batch[i] = one_sentence.return_str |
|||
# if i > len(generated) - 1: |
|||
# break |
|||
# else: |
|||
# i += 1 |
|||
# if i > len(generated) - 1: |
|||
# break |
|||
|
|||
generated_new = [] |
|||
for i in range(len(generated)): |
|||
bool_ = generated[i].over |
|||
if bool_ == False: |
|||
generated_new.append(generated[i]) |
|||
else: |
|||
return_str_batch[generated[i].batch_id] = generated[i].return_str |
|||
generated = generated_new |
|||
|
|||
|
|||
if generated == []: |
|||
return return_str_batch |
|||
return return_str_batch |
|||
|
|||
|
|||
    def generate(self, text, topk=5):
        c_token_ids, _ = tokenizer.encode(text, maxlen=120)
        with graph.as_default():
            K.set_session(sess)
            c_encoded = encoder.predict(np.array([c_token_ids]))[0]
        output_ids = self.beam_search([c_encoded], topk=topk)  # beam search decoding
        return_text = tokenizer.decode([int(i) for i in output_ids])
        return_text = return_text.replace(",", ",")  # normalise ASCII commas back to Chinese commas
        return return_text

    def generate_random(self, text, n=30, topp=0.9):
        c_token_ids, _ = self.tokenizer.encode(text, maxlen=120)
        with graph.as_default():
            K.set_session(sess)
            c_encoded = self.encoder.predict(np.array([c_token_ids]))[0]
        output_ids = self.random_sample([c_encoded], n, topp=topp)  # random (top-p) sampling
        text = []
        for ids in output_ids:
            text.append(tokenizer.decode([int(i) for i in ids]))
        return text

    def generate_beam_search_batch(self, text):
        output_str = self.beam_search_batch(text)  # batched beam search
        return output_str
|||
|
|||
|
|||
generatemodel = GenerateModel() |
|||
encoder, decoder, model, tokenizer = generatemodel.device_setup() |
|||
autotitle = AutoTitle(encoder, decoder, model, tokenizer, start_id=0, end_id=tokenizer._token_end_id, maxlen=120) |
|||
|
|||
|
|||
|
|||
|
|||
def just_show_sentence(file):
    """
    @param file: list
    """
    text = file[0]
    pre = autotitle.generate(text)
    return pre


def just_show_sentence_batch(file: list) -> object:
    text = file
    pre = autotitle.generate_beam_search_batch(text)
    return pre
|||
|
|||
|
|||
if __name__ == '__main__': |
|||
# file = "train_2842.txt" |
|||
# just_show(file) |
|||
# text = ["历史和当下都证明,创新是民族生存、发展的不竭源泉,是自身发展的必然选择,是时代对于青年们的深切呼唤"] |
|||
# a = just_show_sentence(text) |
|||
# print(a) |
|||
# print(type(a)) |
|||
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|||
# is_novel = False |
|||
# path = "./data/700条论文测试.xlsx" |
|||
# df_list = pd.read_excel(path).values.tolist() |
|||
# |
|||
# |
|||
# df_list_new = [] |
|||
# print(len(df_list)) |
|||
# for i in tqdm(df_list): |
|||
# pre = just_show_sentence([i[0]]) |
|||
# |
|||
# df_list_new.append([i[0], i[1], pre]) |
|||
# |
|||
# df = pd.DataFrame(df_list_new, columns=["原文", "yy降重", "t5模型"]) |
|||
# df.to_excel("./data/700条论文测试_7.xlsx", index=None) |
|||
|
|||
# ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|||
|
|||
import os |
|||
|
|||
file = "./data/11篇汇总txt_new.txt" |
|||
file_t5 = "./data/11篇汇总txt_new_predict_t5.txt" |
|||
file_t5_0724 = "./data/11篇汇总txt_new_predict_t5_0724.txt" |
|||
|
|||
try: |
|||
with open(file, 'r', encoding="utf-8") as f: |
|||
lines = [x.strip() for x in f if x.strip() != ''] |
|||
except: |
|||
with open(file, 'r', encoding="gbk") as f: |
|||
lines = [x.strip() for x in f if x.strip() != ''] |
|||
|
|||
zishu = 0 |
|||
data = [] |
|||
for i in tqdm(lines): |
|||
|
|||
zishu += len(i) |
|||
pre = just_show_sentence([i]) |
|||
data.append([i, pre]) |
|||
|
|||
with open(file_t5_0724, "w", encoding='utf-8') as file: |
|||
for i in data: |
|||
file.write("\t".join(i) + '\n') |
|||
file.close() |
|||
print(zishu) |
|||
|
|||
#++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ |
|||
|
|||
# text = ["'李正旺你真是傻逼讪笑,挥手道:“不不不,你千万别误会", |
|||
# "历史和当下都证明,创新是民族生存、“发展的不竭源泉”,是是自身发展的必然选择", |
|||
# "自身发展的必然选择", |
|||
# "强调轻资产经营, 更加重视经营风险的规避", |
|||
# "历史和当下都证明,创新是民族生存、发展的不竭源泉,是是自身发展的必然选择", |
|||
# "是时代对于青年们的深切呼唤"] |
|||
# text = ["随着经济的发展,人们生活水平的提高,环境问题也日益突出。", |
|||
# "环境问题中的化学污染是影响我国居民生活质量不可忽视的重要因素,而仪器分析作为化工专业课程中必不可少的一门课程也不例外。", |
|||
# "所以对学生对应用仪器分析解决实际问题的能力要求很高。", |
|||
# "随着经济的发展,人们生活水平的提高,环境问题也日益突出。"] |
|||
# print(just_show_sentence(text)) |
|||
# print(just_show_sentence_top(text)) |
|||
# print(just_show_chachong_random(text)) |
|||
|
|||
# print(tokenizer.encode("\"", maxlen=120)) |
|||
# print(just_show_sentence_batch(text)) |
@ -1 +0,0 @@
nohup python predict_flask.py > myout.file 2>&1 &
@ -1 +0,0 @@
gunicorn flask_predict_no_batch_t5:app -c gunicorn_config.py
@ -0,0 +1,9 @@
import redis

pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50, db=2)
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)


api_key_list_ip_1 = "api_key_192.168.1.17"
for i in range(10):
    redis_.rpush(api_key_list_ip_1, i)
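
# Illustrative consumer sketch (an assumption, not part of the original diff): a
# worker that needs one of the 10 slots pushed above would block-pop a token from
# the list and push it back when finished, capping concurrent use at 10. A real
# consumer would run in its own process with its own redis connection.
slot = redis_.blpop(api_key_list_ip_1, timeout=30)   # wait for a free slot
if slot is not None:
    try:
        pass  # call the rate-limited API here
    finally:
        redis_.rpush(api_key_list_ip_1, slot[1])     # return the slot to the pool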
@ -0,0 +1,37 @@
import requests


def dialog_line_parse(url, text):
    """
    Send the data to the model service and return its response.
    :param url: model service url
    :param text: request payload sent to the model
    :return: parsed model response
    """

    response = requests.post(
        url,
        json=text,
        timeout=100000
    )
    if response.status_code == 200:
        return response.json()
    else:
        # logger.error(
        #     "【{}】 Failed to get a proper response from remote "
        #     "server. Status Code: {}. Response: {}"
        #     "".format(url, response.status_code, response.text)
        # )
        print("【{}】 Failed to get a proper response from remote "
              "server. Status Code: {}. Response: {}"
              "".format(url, response.status_code, response.text))
        print(text)
        return {}


with open("data/drop_weight_data.txt", encoding="utf-8") as f:
    text_list = [i for i in f.read().split("\n")]

for i in text_list[:-1]:
    text = dialog_line_parse("http://192.168.31.74:19000", {"texts": f"改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{i}"})
    print("原文", i)
    print("模型预测", text)