"""Flask service that splits Chinese text into model-sized sentence chunks,
runs them through a T5 rewriting model (``predict_t5.autotitle``), and
reassembles the results.

Two entry points:
  * POST /droprepeat/ — synchronous prediction.
  * POST /predict     — enqueue onto a redis list; a background worker
    (``classify``) pops, predicts, and stores the result under the query id.
"""
import os
# os.environ["TF_KERAS"] = "1"
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
import json
import logging
import re
import time
import uuid
from threading import Thread

import redis
import requests
from flask import Flask, jsonify, request

# from linshi import autotitle
from predict_t5 import autotitle

# NOTE(review): decode_responses=True is ignored when a connection_pool is
# passed in (the pool's connections were created without it), so redis
# replies are still bytes — hence the explicit .decode() in classify().
pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50, db=1)
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

db_key_query = 'query'    # redis list used as the work queue
db_key_result = 'result'
batch_size = 32

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False  # return raw UTF-8 JSON, not \u escapes

pattern = r"[。]"
# Quoted dialog: "...", '...' or Chinese full-width “...” (all non-greedy).
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
fuhao_end_sentence = ["。", ",", "?", "!", "…"]

config = {
    "batch_size": 1000  # fixed typo: key was "batch_szie" (dict is unused elsewhere)
}


def get_dialogs_index(line: str):
    """Locate quoted dialog spans in ``line``.

    :param line: text to scan
    :return: ``(dialogs_text, dialogs_index, other_index)`` —
        the quoted substrings, the character positions inside any quote,
        and the character positions outside all quotes.
    """
    dialogs = re.finditer(RE_DIALOG, line)
    dialogs_text = re.findall(RE_DIALOG, line)
    dialogs_index = []
    for dialog in dialogs:
        dialogs_index.extend(range(dialog.start(), dialog.end()))
    other_index = [i for i in range(len(line)) if i not in dialogs_index]
    return dialogs_text, dialogs_index, other_index


def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
    """Recursively split an over-long sentence into chunks of <= 120 chars.

    The cut point is the last clause punctuation before position 120 that is
    not inside a quoted dialog; if none exists the window is hard-cut at 120.
    The remainder recurses with an incremented ``short_num`` so the pieces can
    be glued back together by ``predict_data_post_processing``.

    :param text: sentence to split
    :param snetence_id: id of the originating paragraph (parameter name kept
        misspelled for caller compatibility)
    :param chulipangban_return_list: accumulator of ``[chunk, id, short_num]``
    :param short_num: 0 for the first chunk, +1 for each further chunk
    :return: the accumulator list
    """
    fuhao = [",", "?", "!", "…"]
    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    text_1 = text[:120]
    text_2 = text[120:]
    text_1_new = ""
    if text_2 == "":
        # Already short enough: emit as-is.
        chulipangban_return_list.append([text_1, snetence_id, short_num])
        return chulipangban_return_list
    # Scan backwards for the last usable punctuation in the 120-char window.
    for i in range(len(text_1) - 1, -1, -1):
        if text_1[i] in fuhao:
            if i in dialogs_index:
                continue  # never split inside a quotation
            text_1_new = text_1[:i + 1]
            chulipangban_return_list.append([text_1_new, snetence_id, short_num])
            if text_2 != "" and i + 1 != 120:
                # Push the unused tail of the window back onto the remainder.
                text_2 = text_1[i + 1:] + text_2
            break
    if text_1_new == "":
        # No punctuation found: hard cut at 120 characters.
        chulipangban_return_list.append([text_1, snetence_id, short_num])
    if text_2 != "":
        short_num += 1
        chulipangban_return_list = chulichangju_1(
            text_2, snetence_id, chulipangban_return_list, short_num)
    return chulipangban_return_list


def chulipangban_test_1(snetence_id, text):
    """Split one paragraph into model-sized pieces.

    Full stops inside quoted dialog are masked as "&" so they survive the
    "。"-split (``one_predict`` restores them); sentences of 120+ chars are
    further split by ``chulichangju_1``.

    :param snetence_id: paragraph id propagated into every piece
    :param text: raw paragraph text
    :return: list of ``[sentence, snetence_id, short_num]`` triples
    """
    # Quote handling: mask "。" inside each quoted span.
    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    for dialogs_text_dan in dialogs_text:
        text_dan_list = text.split(dialogs_text_dan)
        if "。" in dialogs_text_dan:
            dialogs_text_dan = str(dialogs_text_dan).replace("。", "&")
        text = dialogs_text_dan.join(text_dan_list)
    sentence_list = text.split("。")
    sentence_batch_list = []
    for sentence in sentence_list:
        if len(sentence) < 120:
            sentence_batch_list.append([sentence, snetence_id, 0])
        else:
            sentence_batch_list.extend(chulichangju_1(sentence, snetence_id, [], 0))
    return sentence_batch_list


def paragraph_test_(text: list, text_new: list):
    """Legacy list-based variant of :func:`paragraph_test`.

    Bug fix: the original passed ``(text, i)`` to ``chulipangban_test_1``
    (arguments swapped) and clobbered the input list while iterating it; it
    now splits and rejoins each paragraph as evidently intended.

    :param text: list of paragraph strings
    :param text_new: accumulator the joined results are appended to
    :return: ``text_new``
    """
    for i, paragraph in enumerate(text):
        pieces = chulipangban_test_1(i, paragraph)
        text_new.append("。".join(piece[0] for piece in pieces))
    return text_new


def paragraph_test(texts: dict):
    """Split every paragraph of ``texts`` ({id: paragraph}) into pieces.

    :return: flat list of ``[sentence, paragraph_id, short_num]`` triples
    """
    text_new = []
    for i, text in texts.items():
        text_new.extend(chulipangban_test_1(i, text))
    return text_new


def batch_data_process(text_list):
    """Group sentence triples into batches of roughly <= 500 characters.

    Bug fix: when the element that crossed the limit is carried into the next
    batch, the running length now starts from that element's length (the
    original reset it to 0, undercounting every subsequent batch).

    :param text_list: list of ``[sentence, id, short_num]`` triples
    :return: list of batches (each a list of triples)
    """
    sentence_batch_length = 0
    sentence_batch_one = []
    sentence_batch_list = []
    for sentence in text_list:
        sentence_batch_length += len(sentence[0])
        sentence_batch_one.append(sentence)
        if sentence_batch_length > 500:
            # The element that crossed the limit opens the next batch.
            sentence_ = sentence_batch_one.pop(-1)
            sentence_batch_list.append(sentence_batch_one)
            sentence_batch_one = [sentence_]
            sentence_batch_length = len(sentence_[0])
    sentence_batch_list.append(sentence_batch_one)
    return sentence_batch_list


def batch_predict(batch_data_list):
    """Predict one batch of triples.

    The model call is currently commented out, so the text passes through
    unchanged — kept that way deliberately.

    :param batch_data_list: list of ``[sentence, id, short_num]`` triples
    :return: same-shaped list carrying the (would-be) predictions
    """
    batch_data_list_new = []
    batch_data_text_list = []
    batch_data_snetence_id_list = []
    for item in batch_data_list:
        batch_data_text_list.append(item[0])
        batch_data_snetence_id_list.append(item[1:])
    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
    batch_pre_data_list = batch_data_text_list
    for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list):
        batch_data_list_new.append([text] + sentence_id)
    return batch_data_list_new


def one_predict(data_text):
    """Run the model on a single ``[sentence, id, short_num]`` triple.

    Restores the "。" that ``chulipangban_test_1`` masked as "&"; empty
    sentences skip the model entirely.

    :return: ``[prediction, id, short_num]``
    """
    if data_text[0] != "":
        data_inputs = data_text[0].replace("&", "。")
        pre_data = autotitle.generate(data_inputs)
    else:
        pre_data = ""
    return [pre_data] + data_text[1:]


def predict_data_post_processing(text_list):
    """Reassemble predicted pieces into ``{paragraph_id: text}``.

    Pieces with ``short_num != 0`` are glued onto the preceding piece, then
    consecutive pieces sharing an id are joined with "。".

    Bug fixes: a leading piece with ``short_num != 0`` no longer crashes on
    an empty accumulator, and the initial id is taken from the data instead
    of a hard-coded "0" (which produced a spurious empty "0" entry).
    """
    text_list_sentence = []
    for i in range(len(text_list)):
        if text_list[i][2] != 0 and text_list_sentence:
            text_list_sentence[-1][0] += text_list[i][0]
        else:
            text_list_sentence.append([text_list[i][0], text_list[i][1]])

    return_list = {}
    sentence_one = []
    sentence_id = text_list_sentence[0][1] if text_list_sentence else "0"
    for piece in text_list_sentence:
        if piece[1] == sentence_id:
            sentence_one.append(piece[0])
        else:
            return_list[sentence_id] = "。".join(sentence_one)
            sentence_id = piece[1]
            sentence_one = [piece[0]]
    if sentence_one:
        return_list[sentence_id] = "。".join(sentence_one)
    return return_list


# def main(text:list):
#     text_list = paragraph_test(text)
#     batch_data = batch_data_process(text_list)
#     text_list = []
#     for i in batch_data:
#         text_list.extend(i)
#     return_list = predict_data_post_processing(text_list)
#     return return_list


def main(text: dict):
    """Full pipeline: split -> predict (sentence by sentence) -> reassemble.

    :param text: ``{paragraph_id: paragraph_text}``
    :return: ``{paragraph_id: predicted_text}``
    """
    text_list = paragraph_test(text)
    text_list_new = [one_predict(piece) for piece in text_list]
    return predict_data_post_processing(text_list_new)


@app.route('/droprepeat/', methods=['POST'])
def sentence():
    """Synchronous prediction endpoint."""
    print(request.remote_addr)
    texts = request.json["texts"]
    text_type = request.json["text_type"]
    print("原始语句" + str(texts))
    if isinstance(texts, dict):
        # NOTE(review): the original also checked ``texts is None`` inside
        # this branch — dead code (None is never a dict) — and branched
        # identically for 'focus' and 'chapter'; both collapsed here.
        assert text_type in ['focus', 'chapter']  # NOTE: assert -> HTTP 500 on bad type
        texts_list = main(texts)
        return_text = {"texts": texts_list, "probabilities": None, "status_code": True}
    else:
        # Fixed message: the check requires a dict, not a list.
        return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": False}
    return jsonify(return_text)


def classify():  # worker: pop queued queries, predict, store result by id
    """Background worker loop for the /predict queue.

    Bug fixes: the original spin-waited on ``llen`` + ``continue`` at 100%%
    CPU — ``blpop`` blocks until an item arrives; and any exception killed
    the thread forever — failures are now logged and reported per item.
    """
    while True:
        _, raw = redis_.blpop(db_key_query)          # blocks; returns (key, value)
        data_dict = json.loads(raw.decode('UTF-8'))  # replies are bytes (see pool note)
        query_id = data_dict['id']
        texts = data_dict['text']
        text_type = data_dict["text_type"]
        try:
            if isinstance(texts, dict):
                assert text_type in ['focus', 'chapter']
                texts_list = main(texts)
                return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
            else:
                return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": 401}
        except Exception:
            logging.exception("classify failed for query %s", query_id)
            return_text = {"texts": "处理失败", "probabilities": None, "status_code": 500}
        redis_.set(query_id, json.dumps(return_text, ensure_ascii=False))


@app.route("/predict", methods=["POST"])
def handle_query():
    """Asynchronous endpoint: enqueue the query and return its id at once."""
    print(request.remote_addr)
    texts = request.json["texts"]
    text_type = request.json["text_type"]
    id_ = str(uuid.uuid1())  # unique ticket for this query
    d = {'id': id_, 'text': texts, "text_type": text_type}
    redis_.rpush(db_key_query, json.dumps(d, ensure_ascii=False))  # enqueue
    result_text = d
    return jsonify(result_text)  # caller polls redis key ``id_`` for the result


t = Thread(target=classify)
t.start()

if __name__ == "__main__":
    fh = logging.FileHandler(mode='a', encoding='utf-8', filename='chitchat.log')
    logging.basicConfig(
        handlers=[fh],
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
    )
    app.run(host="0.0.0.0", port=14000, threaded=True, debug=False)