# -*- encoding:utf-8 -*-
"""AIGC-rewrite service.

``POST /predict`` stores the request on disk and queues its path in Redis; a
background thread (``classify``) pops queued requests, splits every text into
chunks, asks the model server on port 12001 to rewrite each chunk, and writes
the merged result to ``./new_data_logs/<id>.json`` (that path is published in
Redis under the request id for 86400 s).
"""
import os
from flask import Flask, jsonify
from flask import request
import requests
import redis
import uuid
import json
from threading import Thread
import time
import re
import logging
import concurrent.futures
import socket

logging.basicConfig(
    level=logging.DEBUG,  # minimum level written to the log file
    filename='rewrite.log',
    filemode='a',  # 'a' appends across restarts; 'w' would truncate the file
    format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
)

# NOTE(review): credentials are hard-coded in source; the environment variable
# overrides the literal so deployments can stop depending on it.
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=9,
                            password=os.environ.get("REDIS_PASSWORD", "zhicheng123*"))
# redis-py ignores decode_responses when an explicit connection_pool is given,
# so replies are bytes; classify() decodes what it pops itself.
redis_ = redis.Redis(connection_pool=pool)
db_key_query = 'query'  # Redis list: FIFO of queued request descriptors
db_key_querying = 'querying'  # Redis set: ids accepted but not yet finished
batch_size = 32
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

pattern = r"[。]"
# Quoted spans ("...", '...', “...”) treated as dialogue.
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
fuhao_end_sentence = ["。", ",", "?", "!", "…"]
# Heading patterns. "\\s" is spelled escaped so the string values are unchanged
# but no longer raise invalid-escape-sequence warnings.
pantten_biaoti_0 = '^[1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][、.]\\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
pantten_biaoti_1 = '^第[一二三四五六七八九]章\\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
pantten_biaoti_2 = '^[0-9.]+\\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
pantten_biaoti_3 = '^[((][1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][)_)][、.]{0,}?\\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'


def get_host_ip():
    """Return the local IPv4 address of the interface used for outbound traffic.

    ``connect()`` on a UDP socket transmits nothing; it only makes the OS pick
    a route, whose local address is then read back.

    :return: ip address as a string
    """
    # ``with`` closes the socket even if connect() fails (the old
    # try/finally referenced an unbound name when socket() itself raised).
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(('8.8.8.8', 80))
        return s.getsockname()[0]


_host_ip = get_host_ip()  # resolved once instead of once per URL
chatgpt_url_predict = "http://{}:12001/predict".format(_host_ip)
chatgpt_url_search = "http://{}:12001/search".format(_host_ip)


def smtp_f(name):
    """Send a fixed alert email whose subject is ``name``.

    Best effort: delivery failures (SMTP errors and network errors alike) are
    reported on stdout and never propagate to the caller.
    """
    import smtplib
    from email.mime.text import MIMEText
    from email.header import Header

    sender = '838878981@qq.com'  # sending mailbox
    receivers = ['838878981@qq.com']  # recipient mailboxes
    # NOTE(review): SMTP auth code is hard-coded; SMTP_AUTH_CODE overrides it.
    auth_code = os.environ.get("SMTP_AUTH_CODE", "jfqtutaiwrtdbcge")
    message = MIMEText('降重aigc出错,紧急', 'plain', 'utf-8')
    message['From'] = Header("Sender<%s>" % sender)
    message['To'] = Header("Receiver<%s>" % receivers[0])
    message['Subject'] = Header(name, 'utf-8')
    try:
        # The context manager closes the connection on success and failure.
        with smtplib.SMTP_SSL('smtp.qq.com', 465) as server:
            server.login(sender, auth_code)
            server.sendmail(sender, receivers, message.as_string())
        print("邮件发送成功")
    except (smtplib.SMTPException, OSError):
        # OSError covers connection refusals/timeouts, which are not
        # SMTPException and previously escaped (killing the worker thread).
        print("Error: 无法发送邮件")


class log:
    """Tiny append-only access log, one file per day under ``log_file/``."""

    def __init__(self):
        pass

    @staticmethod
    def log(*args, **kwargs):
        """Append a timestamped line built from ``args`` (as ``print`` would).

        Called as ``log.log(...)``; ``kwargs`` are forwarded to ``print``.
        """
        now = time.localtime(int(time.time()))
        stamp = time.strftime('%Y/%m/%d-%H:%M:%S', now)
        log_file = 'log_file/access-%s' % time.strftime('%Y-%m-%d', now) + ".log"
        # Previously crashed when the directory did not exist.
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        # Append mode also creates the file, so no exists() check is needed.
        with open(log_file, 'a', encoding='utf-8') as f:
            print(stamp, *args, file=f, **kwargs)


def dialog_line_parse(url, text):
    """POST ``text`` as JSON to the model at ``url`` and return its reply.

    :param url: model endpoint URL
    :param text: JSON-serialisable request body
    :return: decoded JSON on HTTP 200, otherwise an empty dict
    """
    # NOTE(review): a 100000 s (~28 h) timeout effectively never fires.
    response = requests.post(url, json=text, timeout=100000)
    if response.status_code == 200:
        return response.json()
    print("【{}】 Failed to get a proper response from remote "
          "server. Status Code: {}. Response: {}"
          "".format(url, response.status_code, response.text))
    print(text)
    return {}


def get_dialogs_index(line: str):
    """Locate quoted (dialogue) spans in ``line``.

    :param line: text to scan
    :return: ``(dialogs, dialogs_index, other_index)`` -- the matched quoted
        strings, every character position inside a match, and every position
        outside all matches (both position lists in ascending order)
    """
    dialogs_text = []
    dialogs_index = []
    for match in RE_DIALOG.finditer(line):
        dialogs_text.append(match.group())
        dialogs_index.extend(range(match.start(), match.end()))
    # Set lookup keeps this O(n); the list lookup made it O(n^2).
    inside = set(dialogs_index)
    other_index = [i for i in range(len(line)) if i not in inside]
    return dialogs_text, dialogs_index, other_index


def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
    """Split an over-long sentence into pieces of at most 120 characters.

    Each piece is cut after the last ``,?!…`` in its first 120 characters that
    is not inside a quotation; pieces are appended to
    ``chulipangban_return_list`` as ``[piece, snetence_id, short_num]`` with
    ``short_num`` incremented per piece.

    :return: ``chulipangban_return_list`` (mutated in place)
    """
    fuhao = [",", "?", "!", "…"]
    _, dialogs_index, _ = get_dialogs_index(text)
    in_dialog = set(dialogs_index)
    text_1 = text[:120]
    text_2 = text[120:]
    text_1_new = ""
    if text_2 == "":
        chulipangban_return_list.append([text_1, snetence_id, short_num])
        return chulipangban_return_list
    for i in range(len(text_1) - 1, -1, -1):
        if text_1[i] in fuhao:
            if i in in_dialog:
                continue  # never cut inside a quotation
            text_1_new = text_1[:i + 1]
            chulipangban_return_list.append([text_1_new, snetence_id, short_num])
            # Carry whatever followed the cut point over into the remainder.
            if i + 1 != 120:
                text_2 = text_1[i + 1:] + text_2
            break
    if text_1_new == "":
        # No usable punctuation: hard cut at 120 characters.
        chulipangban_return_list.append([text_1, snetence_id, short_num])
    if text_2 != "":
        short_num += 1
        chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num)
    return chulipangban_return_list


def _post_json(url, data):
    """POST ``data`` as JSON; return the decoded body on HTTP 200, else ``{}``."""
    response = requests.post(url, json=data, timeout=100000)
    if response.status_code == 200:
        return response.json()
    print("Failed to get a proper response from remote "
          "server. Status Code: {}. Response: {}"
          "".format(response.status_code, response.text))
    return {}


def request_api_chatgpt(prompt):
    """Submit one prompt to the model's ``/predict`` endpoint.

    :return: the endpoint's JSON reply (``{}`` on a non-200 status)
    """
    return _post_json(chatgpt_url_predict, {"texts": prompt})


def uuid_search(uuid):
    """Query the model's ``/search`` endpoint for the job with id ``uuid``.

    (The parameter name shadows the ``uuid`` module only inside this function;
    it is kept for keyword-argument compatibility.)

    :return: the endpoint's JSON reply (``{}`` on a non-200 status)
    """
    return _post_json(chatgpt_url_search, {"id": uuid})


def uuid_search_mp(results):
    """Poll ``/search`` every 3 s until every submitted job has finished.

    :param results: replies from ``request_api_chatgpt``; each must contain
        ``["texts"]["id"]``
    :return: generated texts, in the same order as ``results``
    :raises RuntimeError: if a submission reply carries no job id
    """
    task_ids = []
    for position, submitted in enumerate(results):
        try:
            task_ids.append(submitted["texts"]["id"])
        except (KeyError, TypeError) as err:
            raise RuntimeError(
                "submission {} returned no task id: {!r}".format(position, submitted)) from err

    # None marks "not finished yet", so a finished job whose text is "" is
    # accepted (the old "" sentinel made such jobs poll forever).
    results_list = [None] * len(results)
    pending = list(range(len(results)))
    while pending:
        still_pending = []
        for i in pending:  # only re-poll unfinished jobs
            result = uuid_search(task_ids[i])
            # .get(): a failed poll returns {} and is simply retried.
            if result.get("code") == 200:
                results_list[i] = result["text"]
            else:
                still_pending.append(i)
        pending = still_pending
        if pending:
            time.sleep(3)
    return results_list


def get_multiple_urls(urls):
    """Submit all prompts concurrently, then wait for every generation.

    :param urls: ``[index, prompt]`` pairs
    :return: ``[[index, prompt], generated_text]`` pairs, in input order
    """
    input_values = [item[1] for item in urls]
    with concurrent.futures.ThreadPoolExecutor(100) as executor:
        submissions = list(executor.map(request_api_chatgpt, input_values))
    # A single polling loop; no thread pool needed for one call.
    generated = uuid_search_mp(submissions)
    return [[item, text] for item, text in zip(urls, generated)]


def chulipangban_test_1(snetence_id, text):
    """Split ``text`` into model-sized chunks.

    Lines are split on ``\\n`` and empty lines dropped. A line of at most three
    ``。``-sentences is kept whole; longer lines are cut every three sentences.
    Each chunk is re-terminated with ``。``.

    :return: list of ``[chunk, snetence_id, chunk_index]``; ``chunk_index`` is
        0 for the first chunk of a line and counts up for its continuations
    """
    # (The former quote-handling pass split on each quotation and re-joined
    # with the same quotation, which is a no-op, so it was removed.)
    sentence_batch_list = []
    for sentence in text.split("\n"):
        if sentence == "":
            continue
        dan_sentence_list = [i for i in str(sentence).split("。") if i != ""]
        if len(dan_sentence_list) <= 3:
            sentence_batch_list.append([sentence, snetence_id, 0])
            continue
        shot_sen = 0
        start = 0
        for end in range(3, len(dan_sentence_list), 3):
            sentence_batch_list.append(["。".join(dan_sentence_list[start:end]) + "。", snetence_id, shot_sen])
            start = end
            shot_sen += 1
        sentence_batch_list.append(["。".join(dan_sentence_list[start:]) + "。", snetence_id, shot_sen])
    return sentence_batch_list


def paragraph_test(texts: dict):
    """Chunk every ``{id: text}`` entry with ``chulipangban_test_1``.

    :return: concatenated ``[chunk, id, chunk_index]`` lists, in dict order
    """
    text_new = []
    for i, text in texts.items():
        text_new.extend(chulipangban_test_1(i, text))
    return text_new


def batch_predict(batch_data_list):
    """Run one batch through the model (currently an identity pass-through).

    :param batch_data_list: items shaped ``[text, *meta]``
    :return: ``[prediction, *meta]`` for each item
    """
    texts = [item[0] for item in batch_data_list]
    metas = [item[1:] for item in batch_data_list]
    predictions = texts  # model call is disabled; predictions echo the inputs
    return [[text] + meta for text, meta in zip(predictions, metas)]


def is_chinese(char):
    """Return True if ``char`` is in the CJK Unified Ideographs block."""
    return '\u4e00' <= char <= '\u9fff'


def predict_data_post_processing(text_list):
    """Reassemble rewritten chunks into one string per source id.

    :param text_list: ``[text, id, chunk_index]`` items in production order;
        items with ``chunk_index != 0`` continue the preceding item
    :return: ``{id: merged_text}``; ``{}`` for empty input
    """
    print("text_list", text_list)
    if not text_list:
        return {}  # previously an IndexError

    # Glue continuation chunks onto the line they were cut from.
    text_list_sentence = []
    for item in text_list:
        if item[2] != 0 and text_list_sentence:
            text_list_sentence[-1][0] += item[0]
        else:
            text_list_sentence.append([item[0], item[1]])

    # Join consecutive lines that share an id.
    # NOTE(review): lines are joined with "", so the "\n" separators removed by
    # chulipangban_test_1 are not restored -- confirm this is intended.
    return_list = {}
    sentence_id = text_list_sentence[0][1]
    sentence_one = []
    for text, current_id in text_list_sentence:
        if current_id != sentence_id:
            return_list[sentence_id] = "".join(sentence_one)
            sentence_id = current_id
            sentence_one = []
        sentence_one.append(text)
    return_list[sentence_id] = "".join(sentence_one)
    return return_list


def post_sentence_ulit(sentence, text_info):
    """Post-process one generated sentence: strip surrounding newlines, then spaces.

    ``text_info`` (the chunk's ``[text, id, chunk_index, if_change]``) is
    currently unused.
    """
    return sentence.strip("\n").strip(" ")


def has_chinese(s):
    """Return True if ``s`` contains at least one CJK ideograph."""
    return bool(re.search('[\u4e00-\u9fa5]', s))


def pre_sentence_ulit(sentence):
    """Wrap ``sentence`` in the rewrite prompt.

    :return: ``(prompt, if_change)``; ``if_change`` is always True
    """
    sentence = str(sentence).strip()
    if_change = True
    text = f"User: 任务:降aigc率\n请用偏口语化改写句子,要求改写后的句子与原句差别较大,句子完成重新打乱重新描述,语义上可以有变动,让观点更明确,且内容更丰富,形容词增多,但是不能出现明显的逻辑错误和语法错误,不能有不合理的用词,用词必须符合汉语的常识习惯,但是句子格式和内容变化要大,且必须符合人类的书写习惯,语法上必须要规范,尽量多断句,尽量控制每个短句句子长度,能用同义词替换的部分尽量用同义词替换,在句子开头禁止出现连词,比如“首先”, “其次”等,尽量避免出现模糊表达和不确定性表达,比如“某大学”,“某地”,“某城市”等,改写后的句子长度不能少于原句的,需要改写的句子 “{sentence}”,直返会改写后的句子,不要返回其他内容\nAssistant:"
    return text, if_change


def main(texts: dict):
    """Rewrite every text in ``{id: text}`` through the model server.

    :return: ``{id: rewritten_text}`` (``{}`` when there is nothing to rewrite)
    """
    text_list = paragraph_test(texts)
    text_info = []
    input_data = []
    for index, item in enumerate(text_list):
        prompt, if_change = pre_sentence_ulit(item[0])
        input_data.append([index, prompt])
        text_info.append([item[0], item[1], item[2], if_change])

    generated_text_list = [""] * len(input_data)
    for item, result in get_multiple_urls(input_data):
        generated_text_list[item[0]] = result  # item[0] is the chunk index

    text_list_new = []
    for generated, info in zip(generated_text_list, text_info):
        generated = post_sentence_ulit(generated, info)
        text_list_new.append([generated] + info[1:3])  # [text, id, chunk_index]
    return predict_data_post_processing(text_list_new)


def _rewrite_for_type(texts, text_type):
    """Run ``main`` for a supported ``text_type``; return ``[]`` on any failure."""
    if text_type not in ('focus', 'chapter'):
        # Replaces an assert (stripped under -O, and fatal to the worker).
        print("unsupported text_type:", text_type)
        return []
    try:
        return main(texts)
    except Exception:  # worker boundary: report as a 400 result instead of dying
        logging.exception("rewrite failed")
        return []


def _handle_queued_query(raw):
    """Process one popped queue entry end to end and publish its result."""
    query = raw.decode('UTF-8') if isinstance(raw, bytes) else raw
    print("query", query)
    path = json.loads(query)['path']
    with open(path, encoding='utf8') as f1:
        data_dict = json.load(f1)
    query_id = data_dict['id']
    texts = data_dict["text"]
    text_type = data_dict["text_type"]

    texts_list = _rewrite_for_type(texts, text_type)
    # Truthiness, not ``!= []``: main() returns a dict, which never equals [],
    # so the failure branch used to be unreachable.
    if texts_list:
        return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
    else:
        smtp_f("drop_aigc")
        return_text = {"texts": texts_list, "probabilities": None, "status_code": 400}

    load_result_path = "./new_data_logs/{}.json".format(query_id)
    print("query_id: ", query_id)
    print("load_result_path: ", load_result_path)
    os.makedirs(os.path.dirname(load_result_path), exist_ok=True)
    with open(load_result_path, 'w', encoding='utf8') as f2:
        # ensure_ascii=False keeps Chinese readable; indent for pretty output.
        json.dump(return_text, f2, ensure_ascii=False, indent=4)
    debug_id_1 = 1
    redis_.set(query_id, load_result_path, ex=86400)  # result path kept 24 h
    debug_id_2 = 2
    redis_.srem(db_key_querying, query_id)
    debug_id_3 = 3
    log.log('start at',
            'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format(
                query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3))


def classify():
    """Worker loop: pop queued requests from Redis and process them forever.

    An exception while handling one request is logged and the loop continues;
    previously any error terminated this thread and stalled the whole service.
    """
    while True:
        # A single lpop avoids the llen/lpop race (lpop could return None).
        raw = redis_.lpop(db_key_query)
        if raw is None:
            time.sleep(3)
            continue
        try:
            _handle_queued_query(raw)
        except Exception:
            logging.exception("failed to process queued query: %r", raw)
            print("failed to process queued query:", raw)


@app.route("/predict", methods=["POST"])
def handle_query():
    """Accept ``{"texts": {id: text}, "text_type": ...}`` and queue it.

    :return: JSON with the generated request id (status_code 200), or an
        error: 402 when ``texts`` is missing/null, 401 when it is not a dict
    """
    print(request.remote_addr)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}  # non-JSON / non-object bodies fall through to the 402 reply
    texts = payload.get("texts")
    text_type = payload.get("text_type")
    if texts is None:
        return_text = {"texts": "输入了空值", "probabilities": None, "status_code": 402}
        return jsonify(return_text)
    if not isinstance(texts, dict):
        return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": 401}
        return jsonify(return_text)

    id_ = str(uuid.uuid1())  # unique id for this request
    print("uuid: ", id_)
    d = {'id': id_, 'text': texts, "text_type": text_type}
    load_request_path = './request_data_logs/{}.json'.format(id_)
    print(load_request_path)
    os.makedirs(os.path.dirname(load_request_path), exist_ok=True)
    with open(load_request_path, 'w', encoding='utf8') as f2:
        json.dump(d, f2, ensure_ascii=False, indent=4)
    redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path}))
    redis_.sadd(db_key_querying, id_)
    return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200}
    print("ok")
    return jsonify(return_text)


# Started at import so the worker also runs when a WSGI server imports the module.
t = Thread(target=classify)
t.start()

if __name__ == "__main__":
    # logging is already configured above; a second basicConfig was a no-op.
    app.run(host="0.0.0.0", port=14004, threaded=True, debug=False)