From b3f5dd09f038802d32c5255f2b9d56e3417cc20b Mon Sep 17 00:00:00 2001 From: "majiahui@haimaqingfan.com" Date: Wed, 7 Aug 2024 16:12:27 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- flask_drop_rewrite_request.py | 75 +++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 39 deletions(-) diff --git a/flask_drop_rewrite_request.py b/flask_drop_rewrite_request.py index 05c0f75..96e184f 100644 --- a/flask_drop_rewrite_request.py +++ b/flask_drop_rewrite_request.py @@ -137,7 +137,6 @@ def dialog_line_parse(url, text): print("【{}】 Failed to get a proper response from remote " "server. Status Code: {}. Response: {}" "".format(url, response.status_code, response.text)) - print(text) return {} @@ -270,17 +269,23 @@ def uuid_search_mp(results): result = uuid_search(uuid) if result["code"] == 200: - results_list[i] = result["text"] + if result["text"] != "": + results_list[i] = result["text"] + else: + results_list[i] = "Empty character" time.sleep(3) return results_list -def get_multiple_urls(urls): +def get_multiple_urls(text_info): input_values = [] + input_index = [] - for i in urls: - input_values.append(i[1]) + for i in range(len(text_info)): + if text_info[i][3] == True: + input_values.append(text_info[i][4]) + input_index.append(i) with concurrent.futures.ThreadPoolExecutor(100) as executor: # 使用map方法并发地调用worker_function results_1 = list(executor.map(request_api_chatgpt, input_values)) @@ -289,10 +294,19 @@ def get_multiple_urls(urls): # 使用map方法并发地调用worker_function results = list(executor.map(uuid_search_mp, [results_1])) - return_list = [] - for i,j in zip(urls, results[0]): - return_list.append([i, j]) - return return_list + # return_list = [] + # for i,j in zip(urls, results[0]): + # return_list.append([i, j]) + return_dict = {} + for i, j in zip(input_index, results[0]): + return_dict[i] = j + + for i in range(len(text_info)): + if i in return_dict: + text_info[i].append(return_dict[i]) + else: + text_info[i].append(text_info[i][0]) + return text_info def chulipangban_test_1(snetence_id, text): @@ -319,7 +333,7 @@ def chulipangban_test_1(snetence_id, text): sentence_batch_list = [] if is_chinese == False: - __long_machine_en = StateMachine(long_cuter_en(max_len=20, min_len=3)) + __long_machine_en = StateMachine(long_cuter_en(max_len=25, min_len=3)) m_input = EnSequence(text) __long_machine_en.run(m_input) for v in m_input.sentence_list(): @@ -364,7 +378,6 @@ def chulipangban_test_1(snetence_id, text): def paragraph_test(texts: dict): text_new = [] for i, text in texts.items(): - print("text", text) text_list = chulipangban_test_1(i, text) text_new.extend(text_list) @@ -451,7 +464,7 @@ def predict_data_post_processing(text_list): # # text_list.extend(i) # # return_list = predict_data_post_processing(text_list) # # return return_list -def post_sentence_ulit(sentence, text_info): +def post_sentence_ulit(text_info): ''' 后处理 :param sentence: @@ -464,7 +477,7 @@ def post_sentence_ulit(sentence, text_info): if_change = text_info[3] if if_change == True: - sentence = sentence.strip() + sentence = text_info[-1].strip() if "改写后:" in sentence: sentence_lable_index = sentence.index("改写后:") sentence = sentence[sentence_lable_index + 4:] @@ -480,14 +493,13 @@ def post_sentence_ulit(sentence, text_info): # sentence = sentence[:-1] + text_info[0][-1] else: sentence = text_info[0] - return sentence + return text_info[:4] + [sentence] def has_chinese(s): return bool(re.search('[\u4e00-\u9fa5]', s)) def english_ulit(sentence): - print("sentence", sentence) sentence = str(sentence).strip() if_change = True @@ -520,7 +532,7 @@ def chinese_ulit(sentence): text = f"User: 任务:改写句子\n改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{sentence}\nAssistant:" else: - text = f"下面词不做任何变化:\n{sentence}" + text = f"User:下面词不做任何变化:\n{sentence}\nAssistant:" if_change = False return text, if_change @@ -571,10 +583,9 @@ def main(texts: dict): # vllm预测 for i in text_list: + print("sen", i[0]) text, if_change = pre_sentence_ulit(i[0]) - text_sentence.append(text) - text_info.append([i[0], i[1], i[2], if_change]) - + text_info.append([i[0], i[1], i[2], if_change, text]) # outputs = llm.generate(text_sentence, sampling_params) # 调用模型 # @@ -622,28 +633,14 @@ def main(texts: dict): # [4, 'http://114.116.25.228:12000/predict', {'texts': '任务:改写句子\n改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n一是新时代“枫桥经验”对'}] # ] - input_data = [] - for i in range(len(text_sentence)): - # input_data.append([i, chatgpt_url, {"texts": text_sentence[i]}]) - input_data.append([i, text_sentence[i]]) + text_info = get_multiple_urls(text_info) - results = get_multiple_urls(input_data) - generated_text_list = [""] * len(input_data) - for url, result in results: - # print(f"Result for {url}: {result}") - generated_text_list[url[0]] = result - - - for i in range(len(generated_text_list)): - # if len(text_list[i][0]) > 7: - # generated_text_list[i] = post_sentence_ulit(generated_text_list[i]) - # else: - # generated_text_list[i] = text_list[i][0] - generated_text_list[i] = post_sentence_ulit(generated_text_list[i], text_info[i]) + for i in range(len(text_info)): + text_info[i] = post_sentence_ulit(text_info[i]) - for i, j in zip(generated_text_list, text_info): - text_list_new.append([i] + j[1:3]) + for i in range(len(text_info)): + text_list_new.append([text_info[i][-1]] + text_info[i][1:3]) return_list = predict_data_post_processing(text_list_new) return return_list @@ -741,7 +738,7 @@ def handle_query(): return jsonify(return_text) if isinstance(texts, dict): id_ = str(uuid.uuid1()) # 为query生成唯一标识 - print("uuid: ", uuid) + print("uuid: ", id_) d = {'id': id_, 'text': texts, "text_type": text_type} # 绑定文本和query id load_request_path = './request_data_logs/{}.json'.format(id_)