"""Flask service that returns several rewritten versions of a Chinese text.

Each sentence of the input is rewritten by three ChatGPT prompts (remote HTTP
service) and two local generators (T5 and SimBERT). Long sentences are split
into pieces before prediction and stitched back afterwards.
"""
import json
import logging
import os
import re
import time
from threading import Thread

import requests
from flask import Flask, jsonify, request

from config.predict_t5_config import MultipleResultsDropT5Config

t5config = MultipleResultsDropT5Config()

from config.predict_sim_config import DropSimBertConfig  # noqa: E402

simbertconfig = DropSimBertConfig()

# Set before the model modules below are imported.
# NOTE(review): presumably they initialise the GPU backend on import; keep this ordering.
os.environ["CUDA_VISIBLE_DEVICES"] = t5config.cuda_id

from predict_t5 import (GenerateModel as T5GenerateModel,  # noqa: E402
                        AutoTitle as T5AutoTitle)
from predict_sim import (GenerateModel as SimBertGenerateModel,  # noqa: E402
                         AutoTitle as SimBertT5AutoTitle)

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

pattern = r"[。]"  # NOTE(review): not referenced in the code shown here
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")  # quoted dialogue spans
fuhao_end_sentence = ["。", ",", "?", "!", "…"]  # NOTE(review): not referenced in the code shown here

# NOTE(review): hard-coded plain-HTTP endpoint; consider moving it into config.
CHATGPT_URL = 'http://98.142.138.229:9999/chatgpt'
CHATGPT_TIMEOUT = 60  # seconds; prevents a stalled backend from hanging a worker forever
CHATGPT_PROMPTS = ["请帮我改写一下这个句子",
                   "请帮美化一下下面句子",
                   "请帮我修改下面句子让这句话更完美"]

MAX_SENTENCE_LEN = 120  # sentences this long or longer are split into pieces
MAX_BATCH_CHARS = 500  # character budget per batch in batch_data_process
_CLAUSE_PUNCTUATION = (",", "?", "!", "…")  # allowed cut points inside a long sentence
# Temporarily replaces "。" inside quoted dialogue so the sentence split does not
# cut a quote apart. A private-use code point cannot clash with real input,
# unlike the previous "&".
_PERIOD_PLACEHOLDER = "\ue000"

t5generatemodel = T5GenerateModel(t5config.config_path,
                                  t5config.checkpoint_path,
                                  t5config.spm_path,
                                  t5config.keep_tokens_path,
                                  t5config.maxlen,
                                  t5config.savemodel_path)
encoder, decoder, model, tokenizer = t5generatemodel.device_setup()
t5autotitle = T5AutoTitle(encoder, decoder, model, tokenizer, start_id=0,
                          end_id=tokenizer._token_end_id, maxlen=120)

simbertgeneratemodel = SimBertGenerateModel(simbertconfig.config_path,
                                            simbertconfig.checkpoint_path,
                                            simbertconfig.dict_path,
                                            simbertconfig.maxlen,
                                            simbertconfig.savemodel_path)
# NOTE: rebinds the module-level `encoder` / `tokenizer` names to the SimBERT
# objects; the T5 ones were already handed to t5autotitle above.
encoder, seq2seq, tokenizer = simbertgeneratemodel.device_setup()
simbertautotitle = SimBertT5AutoTitle(seq2seq, tokenizer, start_id=None,
                                      end_id=tokenizer._token_end_id, maxlen=120)


def requests_chatGPT(data, timeout=CHATGPT_TIMEOUT):
    """POST ``data`` to the ChatGPT proxy and return its ``'res'`` field.

    :param data: form fields sent to the service (``prompt`` and ``text``).
    :param timeout: seconds to wait for the service before giving up.
    :raises requests.RequestException: on network errors, timeouts or a non-2xx status.
    """
    res = requests.post(CHATGPT_URL, data=data, timeout=timeout)
    res.raise_for_status()
    return res.json()['res']


def get_dialogs_index(line: str):
    """Locate quoted dialogue in ``line``.

    :param line: text to scan.
    :return: ``(dialogs_text, dialogs_index, other_index)``.
        ``dialogs_text`` is the list of matched quoted spans, quotes included.
        ``dialogs_index`` lists every character position covered by those spans.
        ``other_index`` lists the remaining positions in ascending order.
    """
    matches = list(RE_DIALOG.finditer(line))
    dialogs_text = [m.group() for m in matches]
    dialogs_index = [i for m in matches for i in range(m.start(), m.end())]
    covered = set(dialogs_index)  # O(1) membership instead of scanning the list
    other_index = [i for i in range(len(line)) if i not in covered]
    return dialogs_text, dialogs_index, other_index


def _last_clause_break(segment, dialog_positions):
    """Return the index of the last clause punctuation outside dialogue, or None."""
    for i in range(len(segment) - 1, -1, -1):
        if segment[i] in _CLAUSE_PUNCTUATION and i not in dialog_positions:
            return i
    return None


def chulichangju_1(text, chulipangban_return_list, short_num, max_len=MAX_SENTENCE_LEN):
    """Split an over-long sentence into pieces of at most ``max_len`` characters.

    Each window of ``max_len`` characters is cut just after its last clause
    punctuation that lies outside quoted dialogue. If there is no such
    punctuation, the whole window is taken as one piece. Iterative, so very
    long input cannot hit the recursion limit.

    :param text: the sentence to split.
    :param chulipangban_return_list: list that ``[piece, piece_no]`` pairs are
        appended to (mutated in place and also returned).
    :param short_num: piece number for the first piece; it increases by one per piece.
    :param max_len: maximum piece length (must be >= 1).
    :return: ``chulipangban_return_list``.
    """
    if max_len < 1:
        raise ValueError("max_len must be >= 1")
    remaining = text
    while True:
        # Recomputed on the remaining text each round, as the original recursion did.
        _, dialogs_index, _ = get_dialogs_index(remaining)
        head, tail = remaining[:max_len], remaining[max_len:]
        if not tail:
            chulipangban_return_list.append([head, short_num])
            return chulipangban_return_list
        cut = _last_clause_break(head, set(dialogs_index))
        if cut is None:
            chulipangban_return_list.append([head, short_num])
            remaining = tail
        else:
            chulipangban_return_list.append([head[:cut + 1], short_num])
            remaining = remaining[cut + 1:]
        short_num += 1


def _mask_dialog_periods(match):
    """re.sub callback: hide "。" inside a quoted span behind the placeholder."""
    return match.group().replace("。", _PERIOD_PLACEHOLDER)


def chulipangban_test_1(text):
    """Split a paragraph into sentences at "。", keeping quoted dialogue intact.

    :return: list of ``[sentence, piece_no]``. Short sentences get piece 0.
        Sentences of ``MAX_SENTENCE_LEN`` characters or more are split by
        :func:`chulichangju_1` into pieces numbered 0, 1, 2, ...
        A trailing "。" produces a final empty sentence, so re-joining with
        "。" restores it.
    """
    text = RE_DIALOG.sub(_mask_dialog_periods, text)
    sentence_batch_list = []
    for sentence in text.split("。"):
        if len(sentence) < MAX_SENTENCE_LEN:
            sentence_batch_list.append([sentence, 0])
        else:
            sentence_batch_list.extend(chulichangju_1(sentence, [], 0))
    return sentence_batch_list


def paragraph_test(texts: str):
    """Return the ``[sentence, piece_no]`` list for a paragraph (see chulipangban_test_1)."""
    return chulipangban_test_1(texts)


def batch_data_process(text_list, max_chars=MAX_BATCH_CHARS):
    """Group ``[sentence, ...]`` items into batches of at most ``max_chars`` characters.

    A single item longer than ``max_chars`` gets a batch of its own. Empty
    input yields an empty list.
    """
    batches = []
    current = []
    current_len = 0
    for sentence in text_list:
        length = len(sentence[0])
        if current and current_len + length > max_chars:
            batches.append(current)
            current, current_len = [], 0
        current.append(sentence)
        current_len += length  # count the item that opens a new batch too
    if current:
        batches.append(current)
    return batches


def batch_predict(batch_data_list):
    """Predict one batch of ``[text, *meta]`` items.

    NOTE: batch model inference is currently disabled. The texts are returned
    unchanged, each re-paired with its metadata.
    """
    batch_texts = [item[0] for item in batch_data_list]
    batch_meta = [item[1:] for item in batch_data_list]
    batch_pre_data_list = batch_texts
    return [[text] + meta for text, meta in zip(batch_pre_data_list, batch_meta)]


def one_predict(data_text):
    """Produce every candidate rewrite for one ``[sentence, *meta]`` item.

    The candidates come in this order: one per entry of ``CHATGPT_PROMPTS``
    (remote service), then T5, then SimBERT. An empty sentence yields empty
    candidates without calling any model.

    :return: list of ``[candidate_text, *meta]``, one per candidate.
    """
    generators = [t5autotitle, simbertautotitle]
    meta = list(data_text[1:])
    if data_text[0] != "":
        data_inputs = data_text[0].replace(_PERIOD_PLACEHOLDER, "。")
        pre_data_list = [requests_chatGPT(data={'prompt': prompt, 'text': data_inputs})
                         for prompt in CHATGPT_PROMPTS]
        pre_data_list.extend(generator.generate(data_inputs) for generator in generators)
    else:
        # Derived rather than hard-coded so it stays aligned with the real candidate count.
        pre_data_list = [""] * (len(CHATGPT_PROMPTS) + len(generators))
    return [[pre_data] + meta for pre_data in pre_data_list]


def predict_data_post_processing(text_list, index):
    """Reassemble candidate number ``index`` into full text.

    :param text_list: one entry per sentence piece, each being the list
        returned by :func:`one_predict`.
    :param index: which candidate to reassemble.
    :return: dict mapping group id to the rebuilt text. Items produced in this
        file are ``[text, piece_no]``, so the only group is ``0``. For
        ``[text, group_id, piece_no]`` items the middle field is used as the key.

    Pieces with a non-zero piece number are glued onto the preceding piece.
    Whole sentences within a group are joined with "。".
    """
    merged = []  # [text, group_id]
    for candidates in text_list:
        item = candidates[index]
        text = item[0]
        piece_no = item[-1] if len(item) > 1 else 0
        if piece_no != 0 and merged:
            merged[-1][0] += text
        else:
            group_id = item[1] if len(item) > 2 else 0
            merged.append([text, group_id])
    grouped = {}
    for text, group_id in merged:
        grouped.setdefault(group_id, []).append(text)
    return {group_id: "。".join(parts) for group_id, parts in grouped.items()}


def main(text: str):
    """Return one reassembled-text dict per candidate model/prompt for ``text``."""
    text_list = paragraph_test(text)
    text_list_new = [one_predict(item) for item in text_list]
    if not text_list_new:
        return []
    return [predict_data_post_processing(text_list_new, index)
            for index in range(len(text_list_new[0]))]


@app.route('/multiple_results_droprepeat/', methods=['POST'])
def sentence():
    """HTTP endpoint: expects JSON ``{"texts": "<paragraph>"}``.

    Responds with ``{"texts", "probabilities", "status_code"}``. ``status_code``
    is False when ``texts`` is missing/null or is not a string.
    """
    print(request.remote_addr)
    payload = request.get_json(silent=True)
    texts = payload.get("texts") if isinstance(payload, dict) else None
    print("原始语句" + str(texts))
    if texts is None:
        return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False}
    elif isinstance(texts, str):
        return_text = {"texts": main(texts), "probabilities": None, "status_code": True}
    else:
        return_text = {"texts": "输入格式应该为str", "probabilities": None, "status_code": False}
    return jsonify(return_text)


if __name__ == "__main__":
    fh = logging.FileHandler(mode='a', encoding='utf-8', filename='chitchat.log')
    logging.basicConfig(
        handlers=[fh],
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
    )
    app.run(host="0.0.0.0", port=14000, threaded=True, debug=False)