From 9fc3021bce0bde37d741c6e16287ef5852461ef2 Mon Sep 17 00:00:00 2001
From: "majiahui@haimaqingfan.com" <majiahui@haimaqingfan.com>
Date: Tue, 31 Oct 2023 14:50:11 +0800
Subject: [PATCH] Fix preprocessing bug
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 flask_drop_rewrite_request.py | 488 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 488 insertions(+)
 create mode 100644 flask_drop_rewrite_request.py

diff --git a/flask_drop_rewrite_request.py b/flask_drop_rewrite_request.py
new file mode 100644
index 0000000..938e50d
--- /dev/null
+++ b/flask_drop_rewrite_request.py
@@ -0,0 +1,488 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "2"
+
+import json
+import logging
+import re
+import time
+import uuid
+from threading import Thread
+
+import redis
+import requests
+from flask import Flask, jsonify, request
+# Only needed if the in-process vLLM path in main() is re-enabled:
+# from vllm import LLM, SamplingParams
+
+
+logging.basicConfig(level=logging.DEBUG,  # log level
+                    filename='rewrite.log',
+                    filemode='a',  # 'w' rewrites the log on every start, 'a' appends (the default)
+                    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s')
+
+# decode_responses is ignored when an explicit connection_pool is passed,
+# so replies are bytes and are decoded manually where needed below.
+pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=7, password="zhicheng123*")
+redis_ = redis.Redis(connection_pool=pool)
+
+db_key_query = 'query'
+db_key_querying = 'querying'
+db_key_queryset = 'queryset'
+batch_size = 32
+
+app = Flask(__name__)
+app.config["JSON_AS_ASCII"] = False
+
+pattern = r"[。]"
+RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
+fuhao_end_sentence = ["。", ",", "?", "!", "…"]
+# Heading patterns: numbered or chaptered titles must not be rewritten.
+pantten_biaoti_0 = r'^[1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][、.]\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
+pantten_biaoti_1 = r'^第[一二三四五六七八九]章\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
+pantten_biaoti_2 = r'^[0-9.]+\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
+pantten_biaoti_3 = r'^[((][1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][)_)][、.]{0,}?\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
+
+
+class log:
+    def __init__(self):
+        pass
+
+    # Called on the class itself (log.log(...)), hence no self parameter.
+    def log(*args, **kwargs):
+        format = '%Y/%m/%d-%H:%M:%S'
+        format_h = '%Y-%m-%d'
+        value = time.localtime(int(time.time()))
+        dt = time.strftime(format, value)
+        dt_log_file = time.strftime(format_h, value)
+        log_file = 'log_file/access-%s.log' % dt_log_file
+        os.makedirs('log_file', exist_ok=True)  # the directory may not exist yet
+        # 'a+' creates the file when missing, so no separate existence check is needed
+        with open(log_file, 'a+', encoding='utf-8') as f:
+            print(dt, *args, file=f, **kwargs)
+
+
+def dialog_line_parse(url, text):
+    """
+    Send data to the model service and return its result.
+    :param url: model service url
+    :param text: request payload
+    :return: decoded JSON response, or {} on failure
+    """
+    response = requests.post(
+        url,
+        json=text,
+        timeout=100000
+    )
+    if response.status_code == 200:
+        return response.json()
+    else:
+        print("【{}】 Failed to get a proper response from remote "
+              "server. Status Code: {}. Response: {}"
+              "".format(url, response.status_code, response.text))
+        print(text)
+        return {}
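+
+# Illustrative call (the URL and the "resilt" response key mirror how main()
+# uses this helper below; the input sentence is made up):
+#   res = dialog_line_parse("http://192.168.31.145:14010/predict",
+#                           {"texts": ["改写这句话"]})
+#   rewritten = res["resilt"]  # list of generated strings, parallel to "texts"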
+
+
+def get_dialogs_index(line: str):
+    """
+    Locate quoted dialog inside a line.
+    :param line: text
+    :return: dialogs_text: the quoted spans
+             dialogs_index: character indices inside quotes
+             other_index: character indices outside quotes
+    """
+    dialogs = re.finditer(RE_DIALOG, line)
+    dialogs_text = re.findall(RE_DIALOG, line)
+    dialogs_index = []
+    for dialog in dialogs:
+        all_ = [i for i in range(dialog.start(), dialog.end())]
+        dialogs_index.extend(all_)
+    other_index = [i for i in range(len(line)) if i not in dialogs_index]
+
+    return dialogs_text, dialogs_index, other_index
+
+
+def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
+    """Recursively split an over-long sentence at the last punctuation mark
+    that falls outside quoted dialog, 120 characters at a time."""
+    fuhao = [",", "?", "!", "…"]
+    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
+    text_1 = text[:120]
+    text_2 = text[120:]
+    text_1_new = ""
+    if text_2 == "":
+        chulipangban_return_list.append([text_1, snetence_id, short_num])
+        return chulipangban_return_list
+    for i in range(len(text_1) - 1, -1, -1):
+        if text_1[i] in fuhao:
+            if i in dialogs_index:  # never split inside quoted dialog
+                continue
+            text_1_new = text_1[:i + 1]
+            chulipangban_return_list.append([text_1_new, snetence_id, short_num])
+            if i + 1 != 120:  # push the tail of text_1 back onto the remainder
+                text_2 = text_1[i + 1:] + text_2
+            break
+    if text_1_new == "":  # no split point found: keep the hard 120-char cut
+        chulipangban_return_list.append([text_1, snetence_id, short_num])
+    short_num += 1
+    chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num)
+    return chulipangban_return_list
+
+
+def chulipangban_test_1(snetence_id, text):
+    # Quote handling: splitting on each quoted span and re-joining with the
+    # same span is a no-op, kept from the original logic.
+    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
+    for dialogs_text_dan in dialogs_text:
+        text_dan_list = text.split(dialogs_text_dan)
+        text = dialogs_text_dan.join(text_dan_list)
+
+    sentence_list = text.split("。")
+    sentence_batch_list = []
+
+    for sentence in sentence_list[:-1]:
+        if len(sentence) < 120:
+            sentence_batch_list.append([sentence + "。", snetence_id, 0])
+        else:
+            sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0)
+            for sentence_short in sentence_split_list[:-1]:
+                sentence_batch_list.append(sentence_short)
+            # reattach the sentence-final full stop to the last piece
+            sentence_split_list[-1][0] = sentence_split_list[-1][0] + "。"
+            sentence_batch_list.append(sentence_split_list[-1])
+
+    # trailing text without a final full stop
+    if sentence_list[-1] != "":
+        if len(sentence_list[-1]) < 120:
+            sentence_batch_list.append([sentence_list[-1], snetence_id, 0])
+        else:
+            sentence_split_list = chulichangju_1(sentence_list[-1], snetence_id, [], 0)
+            for sentence_short in sentence_split_list:
+                sentence_batch_list.append(sentence_short)
+
+    return sentence_batch_list
+
+
+def paragraph_test(texts: dict):
+    """Split every paragraph in {id: text} into model-sized pieces."""
+    text_new = []
+    for i, text in texts.items():
+        text_list = chulipangban_test_1(i, text)
+        text_new.extend(text_list)
+    return text_new
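+
+# Illustrative shape of the result (made-up input): each entry is
+# [piece, paragraph_id, short_num]; short_num > 0 marks the continuation
+# pieces of a single over-long sentence.
+#   paragraph_test({"0": "今天天气很好。", "1": "我们出去玩。"})
+#   -> [["今天天气很好。", "0", 0], ["我们出去玩。", "1", 0]]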
+
+
+def batch_data_process(text_list):
+    """Group the split pieces into batches of roughly 500 characters."""
+    sentence_batch_length = 0
+    sentence_batch_one = []
+    sentence_batch_list = []
+
+    for sentence in text_list:
+        sentence_batch_length += len(sentence[0])
+        sentence_batch_one.append(sentence)
+        if sentence_batch_length > 500:
+            # the piece that overflowed the budget starts the next batch,
+            # and its length is counted toward that batch
+            sentence_ = sentence_batch_one.pop(-1)
+            sentence_batch_list.append(sentence_batch_one)
+            sentence_batch_one = [sentence_]
+            sentence_batch_length = len(sentence_[0])
+    sentence_batch_list.append(sentence_batch_one)
+    return sentence_batch_list
+
+
+def batch_predict(batch_data_list):
+    """
+    Predict one batch (currently a pass-through placeholder).
+    """
+    batch_data_list_new = []
+    batch_data_text_list = []
+    batch_data_snetence_id_list = []
+    for i in batch_data_list:
+        batch_data_text_list.append(i[0])
+        batch_data_snetence_id_list.append(i[1:])
+    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
+    batch_pre_data_list = batch_data_text_list
+    for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list):
+        batch_data_list_new.append([text] + sentence_id)
+    return batch_data_list_new
+
+
+def predict_data_post_processing(text_list):
+    """Merge the continuation pieces of each long sentence (short_num != 0),
+    then join all pieces belonging to one paragraph id."""
+    text_list_sentence = []
+    for i in range(len(text_list)):
+        if text_list[i][2] != 0:
+            text_list_sentence[-1][0] += text_list[i][0]
+        else:
+            text_list_sentence.append([text_list[i][0], text_list[i][1]])
+
+    return_list = {}
+    sentence_one = []
+    sentence_id = text_list_sentence[0][1]
+    for i in text_list_sentence:
+        if i[1] == sentence_id:
+            sentence_one.append(i[0])
+        else:
+            return_list[sentence_id] = "".join(sentence_one)
+            sentence_id = i[1]
+            sentence_one = [i[0]]
+    if sentence_one != []:
+        return_list[sentence_id] = "".join(sentence_one)
+    return return_list
+
+
+def post_sentence_ulit(sentence, text_info):
+    """
+    Post-processing: strip the "改写后:" label from the model output, or fall
+    back to the original sentence when it was sent with the no-change prompt.
+    """
+    if_change = text_info[3]
+    if if_change:
+        if "改写后:" in sentence:
+            sentence_lable_index = sentence.index("改写后:")
+            sentence = sentence[sentence_lable_index + 4:]
+        if sentence != "" and sentence[-1] == "\n":
+            sentence = sentence[:-1]
+    else:
+        sentence = text_info[0]
+    return sentence
+
+
+def pre_sentence_ulit(sentence):
+    """
+    Pre-processing: build the prompt for one piece. Headings and very short
+    pieces get a "do not change" prompt and if_change=False.
+    """
+    sentence = str(sentence).strip()
+
+    # Heading-like pieces must not be rewritten, so check the patterns
+    # before choosing the prompt.
+    result_biaoti_list_0 = re.findall(pantten_biaoti_0, sentence)
+    result_biaoti_list_1 = re.findall(pantten_biaoti_1, sentence)
+    result_biaoti_list_2 = re.findall(pantten_biaoti_2, sentence)
+    result_biaoti_list_3 = re.findall(pantten_biaoti_3, sentence)
+    is_biaoti = (result_biaoti_list_0 + result_biaoti_list_1
+                 + result_biaoti_list_2 + result_biaoti_list_3) != []
+
+    if len(sentence) > 7 and not is_biaoti:
+        if_change = True
+        text = "You are a helpful assistant.\n\nUser:改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{}\nAssistant:".format(sentence)
+    else:
+        if_change = False
+        text = "You are a helpful assistant.\n\nUser:下面词不做任何变化:\n{}\nAssistant:".format(sentence)
+    return text, if_change
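+
+# Illustrative behaviour (made-up inputs): an ordinary sentence longer than 7
+# characters gets the rewrite prompt with if_change=True; a heading such as
+# "1、实验结果分析" or a piece of 7 characters or fewer gets the keep-unchanged
+# prompt with if_change=False, and post_sentence_ulit() then restores the
+# original text verbatim.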
+
+
+def main(texts: dict):
+    text_list = paragraph_test(texts)
+
+    text_info = []
+    text_sentence = []
+    text_list_new = []
+
+    # Build one prompt per piece and remember
+    # [original, paragraph_id, short_num, if_change] for post-processing.
+    for i in text_list:
+        text, if_change = pre_sentence_ulit(i[0])
+        text_sentence.append(text)
+        text_info.append([i[0], i[1], i[2], if_change])
+
+    # In-process vLLM alternative (needs `from vllm import LLM, SamplingParams`
+    # plus initialised `llm` / `sampling_params` objects):
+    # outputs = llm.generate(text_sentence, sampling_params)
+    # generated_text_list = [""] * len(text_sentence)
+    # for output in outputs:
+    #     generated_text_list[int(output.request_id)] = output.outputs[0].text
+
+    # Remote prediction; "resilt" is the key that service returns.
+    generated_text_list = dialog_line_parse(
+        "http://192.168.31.145:14010/predict",
+        {
+            "texts": text_sentence
+        }
+    )["resilt"]
+
+    for i in range(len(generated_text_list)):
+        generated_text_list[i] = post_sentence_ulit(generated_text_list[i], text_info[i])
+
+    # restore [text, paragraph_id, short_num] triples for reassembly
+    for i, j in zip(generated_text_list, text_info):
+        text_list_new.append([i] + j[1:3])
+
+    return_list = predict_data_post_processing(text_list_new)
+    return return_list
+
+
+def classify():  # queue worker: pull queued requests and run the model
+    while True:
+        if redis_.llen(db_key_query) == 0:  # nothing queued, keep polling
+            time.sleep(3)
+            continue
+        query = redis_.lpop(db_key_query).decode('UTF-8')
+        data_dict_path = json.loads(query)
+        path = data_dict_path['path']
+
+        with open(path, encoding='utf8') as f1:
+            data_dict = json.load(f1)
+
+        query_id = data_dict['id']
+        texts = data_dict["text"]
+        text_type = data_dict["text_type"]
+
+        assert text_type in ['focus', 'chapter']
+        texts_list = main(texts)  # both text types take the same path
+
+        return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
+        os.makedirs("./new_data_logs", exist_ok=True)  # may not exist yet
+        load_result_path = "./new_data_logs/{}.json".format(query_id)
+
+        print("query_id: ", query_id)
+        print("load_result_path: ", load_result_path)
+
+        with open(load_result_path, 'w', encoding='utf8') as f2:
+            # ensure_ascii=False keeps Chinese readable instead of \u escapes
+            json.dump(return_text, f2, ensure_ascii=False, indent=4)
+        redis_.set(query_id, load_result_path, 86400)  # result path, 24 h TTL
+        redis_.srem(db_key_querying, query_id)
+        log.log('query_id:{},load_result_path:{},return_text:{}'.format(
+            query_id, load_result_path, return_text))
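+
+# How a caller consumes the queue (sketch, matching the keys used above):
+# POST /predict below writes the payload to ./request_data_logs/<id>.json and
+# pushes {"id", "path"} onto the 'query' list; once classify() has finished,
+# GET <id> yields the result file path:
+#   path = redis_.get(id_)            # bytes, or None while still running
+#   if path:
+#       with open(path.decode('UTF-8'), encoding='utf8') as f:
+#           result = json.load(f)["texts"]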
+
+
+@app.route("/predict", methods=["POST"])
+def handle_query():
+    print(request.remote_addr)
+    texts = request.json["texts"]
+    text_type = request.json["text_type"]
+    if texts is None:
+        return_text = {"texts": "received empty input", "probabilities": None, "status_code": 402}
+        return jsonify(return_text)
+    if isinstance(texts, dict):
+        id_ = str(uuid.uuid1())  # unique id for this query
+        print("uuid: ", id_)
+        d = {'id': id_, 'text': texts, "text_type": text_type}  # bind text to query id
+
+        os.makedirs('./request_data_logs', exist_ok=True)  # may not exist yet
+        load_request_path = './request_data_logs/{}.json'.format(id_)
+        with open(load_request_path, 'w', encoding='utf8') as f2:
+            # ensure_ascii=False keeps Chinese readable instead of \u escapes
+            json.dump(d, f2, ensure_ascii=False, indent=4)
+        redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path}))  # enqueue
+        redis_.sadd(db_key_querying, id_)
+        redis_.sadd(db_key_queryset, id_)
+        return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200}
+        print("ok")
+    else:
+        return_text = {"texts": "input should be a dict", "probabilities": None, "status_code": 401}
+    return jsonify(return_text)
+
+
+t = Thread(target=classify)
+t.start()
+
+if __name__ == "__main__":
+    app.run(host="0.0.0.0", port=14004, threaded=True, debug=False)
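+
+# Example request (the worker answers asynchronously; poll redis with the
+# returned id as sketched above classify()):
+#   curl -X POST http://127.0.0.1:14004/predict \
+#        -H 'Content-Type: application/json' \
+#        -d '{"texts": {"0": "今天天气很好。"}, "text_type": "focus"}'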