import os

from config.predict_t5_config import DropT5Config

config = DropT5Config()
# Select the GPU before any deep-learning framework is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = config.cuda_id

import json
import logging
import re
import time
import uuid
from threading import Thread

import redis
from flask import Flask, jsonify, request

from predict_t5 import GenerateModel, AutoTitle

logging.basicConfig(level=logging.DEBUG,  # log level printed to the console
                    filename='rewrite.log',
                    filemode='a',  # 'a' appends to the log file; 'w' would overwrite it on every run
                    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')

pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=6, password="zhicheng123*")
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

db_key_query = 'query'
db_key_querying = 'querying'
batch_size = 32

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

pattern = r"[。]"
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
fuhao_end_sentence = ["。", ",", "?", "!", "…"]

generatemodel = GenerateModel(config.config_path,
                              config.checkpoint_path,
                              config.spm_path,
                              config.keep_tokens_path,
                              config.maxlen,
                              config.savemodel_path)
encoder, decoder, model, tokenizer = generatemodel.device_setup()
autotitle = AutoTitle(encoder, decoder, model, tokenizer,
                      start_id=0, end_id=tokenizer._token_end_id, maxlen=120)


class log:
    """Minimal file logger; used as log.log(...) without instantiating."""

    def __init__(self):
        pass

    def log(*args, **kwargs):
        format = '%Y/%m/%d-%H:%M:%S'
        format_h = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(format, value)
        dt_log_file = time.strftime(format_h, value)
        log_file = 'log_file/access-%s' % dt_log_file + ".log"
        mode = 'w' if not os.path.exists(log_file) else 'a+'
        with open(log_file, mode, encoding='utf-8') as f:
            print(dt, *args, file=f, **kwargs)


def get_dialogs_index(line: str):
    """
    Find quoted dialogue and its character positions.
    :param line: input text
    :return: dialogs_text: dialogue snippets
             dialogs_index: character indices inside dialogue
             other_index: character indices outside dialogue
    """
    dialogs = re.finditer(RE_DIALOG, line)
    dialogs_text = re.findall(RE_DIALOG, line)
    dialogs_index = []
    for dialog in dialogs:
        dialogs_index.extend(range(dialog.start(), dialog.end()))
    other_index = [i for i in range(len(line)) if i not in dialogs_index]

    return dialogs_text, dialogs_index, other_index


def chulichangju_1(text, sentence_id, chulipangban_return_list, short_num):
    """Recursively split a sentence longer than 120 characters at the last
    punctuation mark (outside dialogue) before the 120-character boundary."""
    fuhao = [",", "?", "!", "…"]
    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    text_1 = text[:120]
    text_2 = text[120:]
    text_1_new = ""
    if text_2 == "":
        chulipangban_return_list.append([text_1, sentence_id, short_num])
        return chulipangban_return_list
    for i in range(len(text_1) - 1, -1, -1):
        if text_1[i] in fuhao:
            if i in dialogs_index:
                continue
            text_1_new = text_1[:i]
            text_1_new += text_1[i]
            chulipangban_return_list.append([text_1_new, sentence_id, short_num])
            if text_2 != "":
                if i + 1 != 120:
                    text_2 = text_1[i + 1:] + text_2
            break
    if text_1_new == "":
        # No split point found: keep the first 120 characters as one chunk.
        chulipangban_return_list.append([text_1, sentence_id, short_num])
    if text_2 != "":
        short_num += 1
        chulipangban_return_list = chulichangju_1(text_2, sentence_id, chulipangban_return_list, short_num)
    return chulipangban_return_list


def chulipangban_test_1(sentence_id, text):
    # Protect "。" inside quoted dialogue so it does not act as a sentence boundary.
    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    for dialogs_text_dan in dialogs_text:
        text_dan_list = text.split(dialogs_text_dan)
        if "。" in dialogs_text_dan:
            # Mask "。" inside this dialogue as "&" so it is not treated as a
            # sentence boundary; one_predict() restores it before generation.
            dialogs_text_dan = str(dialogs_text_dan).replace("。", "&")
            text = dialogs_text_dan.join(text_dan_list)

    sentence_list = text.split("。")

    sentence_batch_list = []
    for sentence in sentence_list:
        if len(sentence) < 120:
            sentence_batch_list.append([sentence, sentence_id, 0])
        else:
            # Sentence too long for the model: split it into <=120-character chunks.
            sentence_split_list = chulichangju_1(sentence, sentence_id, [], 0)
            for sentence_short in sentence_split_list:
                sentence_batch_list.append(sentence_short)
    return sentence_batch_list


def paragraph_test(texts: dict):
    """Split every paragraph ({id: text}) into model-sized sentence chunks."""
    text_new = []
    for i, text in texts.items():
        text_list = chulipangban_test_1(i, text)
        text_new.extend(text_list)
    return text_new


def batch_data_process(text_list):
    """Group sentence chunks into batches of roughly 500 characters."""
    sentence_batch_length = 0
    sentence_batch_one = []
    sentence_batch_list = []

    for sentence in text_list:
        sentence_batch_length += len(sentence[0])
        sentence_batch_one.append(sentence)
        if sentence_batch_length > 500:
            # Batch full: move the overflowing chunk into the next batch.
            sentence_batch_length = 0
            sentence_ = sentence_batch_one.pop(-1)
            sentence_batch_list.append(sentence_batch_one)
            sentence_batch_one = []
            sentence_batch_one.append(sentence_)
    sentence_batch_list.append(sentence_batch_one)
    return sentence_batch_list


def batch_predict(batch_data_list):
    """
    Predict one batch of data.
    @param batch_data_list:
    @return:
    """
    batch_data_list_new = []
    batch_data_text_list = []
    batch_data_sentence_id_list = []
    for i in batch_data_list:
        batch_data_text_list.append(i[0])
        batch_data_sentence_id_list.append(i[1:])
    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
    batch_pre_data_list = batch_data_text_list
    for text, sentence_id in zip(batch_pre_data_list, batch_data_sentence_id_list):
        batch_data_list_new.append([text] + sentence_id)
    return batch_data_list_new


def one_predict(data_text):
    """
    Predict a single chunk.
    @param data_text:
    @return:
    """
    if data_text[0] != "":
        # Restore the "。" that chulipangban_test_1 masked as "&".
        data_inputs = data_text[0].replace("&", "。")
        pre_data = autotitle.generate(data_inputs)
    else:
        pre_data = ""
    data_new = [pre_data] + data_text[1:]
    return data_new


def predict_data_post_processing(text_list):
    """Re-join split chunks into sentences, then sentences into paragraphs keyed by id."""
    text_list_sentence = []
    for i in range(len(text_list)):
        if text_list[i][2] != 0:
            # short_num != 0 marks a chunk that continues the previous sentence.
            text_list_sentence[-1][0] += text_list[i][0]
        else:
            text_list_sentence.append([text_list[i][0], text_list[i][1]])

    return_list = {}
    sentence_one = []
    sentence_id = text_list_sentence[0][1]
    for i in text_list_sentence:
        if i[1] == sentence_id:
            sentence_one.append(i[0])
        else:
            return_list[sentence_id] = "。".join(sentence_one)
            sentence_id = i[1]
            sentence_one = []
            sentence_one.append(i[0])
    if sentence_one != []:
        return_list[sentence_id] = "。".join(sentence_one)
    return return_list


def main(text: dict):
    text_list = paragraph_test(text)
    text_list_new = []
    for i in text_list:
        pre = one_predict(i)
        text_list_new.append(pre)
    return_list = predict_data_post_processing(text_list_new)
    return return_list


@app.route('/droprepeat/', methods=['POST'])
def sentence():
    print(request.remote_addr)
    texts = request.json["texts"]
request.json["texts"] text_type = request.json["text_type"] print("原始语句" + str(texts)) # question = question.strip('。、!??') if isinstance(texts, dict): texts_list = [] y_pred_label_list = [] position_list = [] # texts = texts.replace('\'', '\"') if texts is None: return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False} return jsonify(return_text) else: assert text_type in ['focus', 'chapter'] if text_type == 'focus': texts_list = main(texts) if text_type == 'chapter': texts_list = main(texts) return_text = {"texts": texts_list, "probabilities": None, "status_code": True} else: return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False} return jsonify(return_text) def classify(): # 调用模型,设置最大batch_size while True: if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 time.sleep(3) continue query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text data_dict_path = json.loads(query) path = data_dict_path['path'] # text_type = data_dict["text_type"] with open(path, encoding='utf8') as f1: # 加载文件的对象 data_dict = json.load(f1) query_id = data_dict['id'] texts = data_dict["text"] text_type = data_dict["text_type"] assert text_type in ['focus', 'chapter'] if text_type == 'focus': texts_list = main(texts) elif text_type == 'chapter': texts_list = main(texts) else: texts_list = [] return_text = {"texts": texts_list, "probabilities": None, "status_code": 200} load_result_path = "./new_data_logs/{}.json".format(query_id) print("query_id: ", query_id) print("load_result_path: ", load_result_path) with open(load_result_path, 'w', encoding='utf8') as f2: # ensure_ascii=False才能输入中文,否则是Unicode字符 # indent=2 JSON数据的缩进,美观 json.dump(return_text, f2, ensure_ascii=False, indent=4) debug_id_1 = 1 redis_.set(query_id, load_result_path, 86400) debug_id_2 = 2 redis_.srem(db_key_querying, query_id) debug_id_3 = 3 log.log('start at', 'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format( query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3)) @app.route("/predict", methods=["POST"]) def handle_query(): print(request.remote_addr) texts = request.json["texts"] text_type = request.json["text_type"] if texts is None: return_text = {"texts": "输入了空值", "probabilities": None, "status_code": 402} return jsonify(return_text) if isinstance(texts, dict): id_ = str(uuid.uuid1()) # 为query生成唯一标识 print("uuid: ", uuid) d = {'id': id_, 'text': texts, "text_type": text_type} # 绑定文本和query id load_request_path = './request_data_logs/{}.json'.format(id_) with open(load_request_path, 'w', encoding='utf8') as f2: # ensure_ascii=False才能输入中文,否则是Unicode字符 # indent=2 JSON数据的缩进,美观 json.dump(d, f2, ensure_ascii=False, indent=4) redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path})) # 加入redis redis_.sadd(db_key_querying, id_) return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200} print("ok") else: return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": 401} return jsonify(return_text) # 返回结果 t = Thread(target=classify) t.start() if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 filename='rewrite.log', filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 # a是追加模式,默认如果不写的话,就是追加模式 format= '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' # 日志格式 ) app.run(host="0.0.0.0", port=14000, threaded=True, debug=False)