
Optimize code (优化代码)

master
majiahui@haimaqingfan.com · 10 months ago
parent commit b3f5dd09f0
flask_drop_rewrite_request.py · 73 lines changed

@@ -137,7 +137,6 @@ def dialog_line_parse(url, text):
         print("{}】 Failed to get a proper response from remote "
               "server. Status Code: {}. Response: {}"
               "".format(url, response.status_code, response.text))
-        print(text)
         return {}
@@ -270,17 +269,23 @@ def uuid_search_mp(results):
             result = uuid_search(uuid)
             if result["code"] == 200:
-                results_list[i] = result["text"]
+                if result["text"] != "":
+                    results_list[i] = result["text"]
+                else:
+                    results_list[i] = "Empty character"
             time.sleep(3)
     return results_list


-def get_multiple_urls(urls):
+def get_multiple_urls(text_info):
     input_values = []
-    for i in urls:
-        input_values.append(i[1])
+    input_index = []
+    for i in range(len(text_info)):
+        if text_info[i][3] == True:
+            input_values.append(text_info[i][4])
+            input_index.append(i)
     with concurrent.futures.ThreadPoolExecutor(100) as executor:
         # call worker_function concurrently via executor.map
         results_1 = list(executor.map(request_api_chatgpt, input_values))
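For orientation: this refactor changes get_multiple_urls from taking prepared (index, prompt) pairs to taking the whole text_info table and filtering it itself, so only rows whose if_change flag (index 3) is set are fanned out to the model endpoint. A minimal, self-contained sketch of the assumed row layout follows; the field names and sample values are illustrative, not taken from the file:

    # Assumed shape of a text_info row at this point:
    #   [original_sentence, meta_1, meta_2, if_change, prompt]
    # (meta_1/meta_2 are positional metadata whose meaning this diff does not show.)
    text_info = [
        ["这是一句需要改写的话。", 0, 0, True, "User: 任务:改写句子...\nAssistant:"],
        ["短标题", 0, 1, False, "User:下面词不做任何变化:...\nAssistant:"],
    ]

    input_values = [row[4] for row in text_info if row[3]]          # prompts to submit
    input_index = [i for i, row in enumerate(text_info) if row[3]]  # their row positions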
@@ -289,10 +294,19 @@ def get_multiple_urls(urls):
         # call worker_function concurrently via executor.map
         results = list(executor.map(uuid_search_mp, [results_1]))

-    return_list = []
-    for i, j in zip(urls, results[0]):
-        return_list.append([i, j])
-    return return_list
+    # return_list = []
+    # for i, j in zip(urls, results[0]):
+    #     return_list.append([i, j])
+    return_dict = {}
+    for i, j in zip(input_index, results[0]):
+        return_dict[i] = j
+    for i in range(len(text_info)):
+        if i in return_dict:
+            text_info[i].append(return_dict[i])
+        else:
+            text_info[i].append(text_info[i][0])
+    return text_info


def chulipangban_test_1(snetence_id, text):
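The merge at the end of this hunk is the key invariant: every text_info row leaves get_multiple_urls with one extra trailing field, holding either the model output (for submitted rows) or the original sentence as a fallback. A standalone sketch of that pattern, with illustrative data:

    text_info = [["句子A", 0, 0, True, "prompt A"],
                 ["句子B", 0, 1, False, "prompt B"],
                 ["句子C", 0, 0, True, "prompt C"]]
    results = ["rewritten A", "rewritten C"]   # model outputs, in submission order
    input_index = [0, 2]                       # rows that were actually submitted

    return_dict = dict(zip(input_index, results))
    for i, row in enumerate(text_info):
        # Submitted rows receive the model output; skipped rows fall back to
        # row[0] (the original sentence), so row[-1] is always safe to read.
        row.append(return_dict.get(i, row[0]))

dict.get here is equivalent to the explicit "if i in return_dict" branch in the commit.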
@@ -319,7 +333,7 @@ def chulipangban_test_1(snetence_id, text):
     sentence_batch_list = []

     if is_chinese == False:
-        __long_machine_en = StateMachine(long_cuter_en(max_len=20, min_len=3))
+        __long_machine_en = StateMachine(long_cuter_en(max_len=25, min_len=3))
         m_input = EnSequence(text)
         __long_machine_en.run(m_input)
         for v in m_input.sentence_list():
@@ -364,7 +378,6 @@ def chulipangban_test_1(snetence_id, text):

 def paragraph_test(texts: dict):
     text_new = []
     for i, text in texts.items():
-        print("text", text)
         text_list = chulipangban_test_1(i, text)
         text_new.extend(text_list)
@@ -451,7 +464,7 @@ def predict_data_post_processing(text_list):
 # #         text_list.extend(i)
 # #     return_list = predict_data_post_processing(text_list)
 # #     return return_list

-def post_sentence_ulit(sentence, text_info):
+def post_sentence_ulit(text_info):
     '''
     Post-processing
     :param sentence:
@@ -464,7 +477,7 @@ def post_sentence_ulit(sentence, text_info):
     if_change = text_info[3]

     if if_change == True:
-        sentence = sentence.strip()
+        sentence = text_info[-1].strip()
         if "改写后:" in sentence:
             sentence_lable_index = sentence.index("改写后:")
             sentence = sentence[sentence_lable_index + 4:]
@@ -480,14 +493,13 @@ def post_sentence_ulit(sentence, text_info):
     #     sentence = sentence[:-1] + text_info[0][-1]
     else:
         sentence = text_info[0]
-    return sentence
+    return text_info[:4] + [sentence]


 def has_chinese(s):
     return bool(re.search('[\u4e00-\u9fa5]', s))


 def english_ulit(sentence):
-    print("sentence", sentence)
     sentence = str(sentence).strip()
     if_change = True
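post_sentence_ulit now reads the candidate rewrite from the row's trailing slot instead of a separate argument, strips a leading "改写后:" ("after rewriting:") label if the model echoed one, and returns the first four fields plus the cleaned sentence, so the row shape stays stable. The core of that cleanup in isolation, as a self-contained sketch with made-up data:

    text_info = ["原句", 0, 0, True, "改写后:这里是改写后的句子"]
    sentence = text_info[-1].strip()
    if "改写后:" in sentence:
        # len("改写后:") == 4, so +4 skips the label and its full-width colon
        sentence = sentence[sentence.index("改写后:") + 4:]
    row_out = text_info[:4] + [sentence]   # shape: [orig, meta, meta, flag, final_text]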
@@ -520,7 +532,7 @@ def chinese_ulit(sentence):
         text = f"User: 任务:改写句子\n改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{sentence}\nAssistant:"
     else:
-        text = f"下面词不做任何变化:\n{sentence}"
+        text = f"User:下面词不做任何变化:\n{sentence}\nAssistant:"
         if_change = False

     return text, if_change
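With this change both branches of chinese_ulit emit the same chat framing, so the serving side always receives a "User: ...\nAssistant:" template whether or not the sentence is rewritten. A reduced sketch of the branch logic; the function name is illustrative, and the length-7 threshold is an assumption inferred from commented-out code elsewhere in this commit, not from this hunk:

    def build_chinese_prompt(sentence: str):
        # Assumption: short fragments (e.g. headings) are passed through unchanged.
        if len(sentence) > 7:
            text = (f"User: 任务:改写句子\n改写下面这句话,要求意思接近但是改动幅度比较大,"
                    f"字数只能多不能少:\n{sentence}\nAssistant:")
            return text, True
        return f"User:下面词不做任何变化:\n{sentence}\nAssistant:", False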
@@ -571,10 +583,9 @@ def main(texts: dict):
     # vllm prediction
     for i in text_list:
-        print("sen", i[0])
         text, if_change = pre_sentence_ulit(i[0])
-        text_sentence.append(text)
-        text_info.append([i[0], i[1], i[2], if_change])
+        text_info.append([i[0], i[1], i[2], if_change, text])

     # outputs = llm.generate(text_sentence, sampling_params)  # call the model
     #
@@ -622,28 +633,14 @@ def main(texts: dict):
     #     [4, 'http://114.116.25.228:12000/predict', {'texts': '任务:改写句子\n改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n一是新时代“枫桥经验”对'}]
     # ]
-    input_data = []
-    for i in range(len(text_sentence)):
-        # input_data.append([i, chatgpt_url, {"texts": text_sentence[i]}])
-        input_data.append([i, text_sentence[i]])
-
-    results = get_multiple_urls(input_data)
-    generated_text_list = [""] * len(input_data)
-    for url, result in results:
-        # print(f"Result for {url}: {result}")
-        generated_text_list[url[0]] = result
+    text_info = get_multiple_urls(text_info)

-    for i in range(len(generated_text_list)):
-        # if len(text_list[i][0]) > 7:
-        #     generated_text_list[i] = post_sentence_ulit(generated_text_list[i])
-        # else:
-        #     generated_text_list[i] = text_list[i][0]
-        generated_text_list[i] = post_sentence_ulit(generated_text_list[i], text_info[i])
+    for i in range(len(text_info)):
+        text_info[i] = post_sentence_ulit(text_info[i])

-    for i, j in zip(generated_text_list, text_info):
-        text_list_new.append([i] + j[1:3])
+    for i in range(len(text_info)):
+        text_list_new.append([text_info[i][-1]] + text_info[i][1:3])

     return_list = predict_data_post_processing(text_list_new)
     return return_list
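Taken together, main now threads one text_info table through every stage instead of keeping parallel lists (text_sentence, generated_text_list) in sync by index. The resulting flow, condensed into a sketch that uses the commit's own function names but assumed data shapes:

    text_info = []
    for row in text_list:                        # row: [sentence, meta_1, meta_2]
        prompt, if_change = pre_sentence_ulit(row[0])
        text_info.append([row[0], row[1], row[2], if_change, prompt])

    text_info = get_multiple_urls(text_info)     # appends model output or fallback as row[-1]
    text_info = [post_sentence_ulit(r) for r in text_info]  # cleans row[-1], keeps rows aligned

    text_list_new = [[r[-1]] + r[1:3] for r in text_info]   # final text + the two metadata fields
    return_list = predict_data_post_processing(text_list_new)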
@@ -741,7 +738,7 @@ def handle_query():
         return jsonify(return_text)
     if isinstance(texts, dict):
         id_ = str(uuid.uuid1())  # generate a unique id for the query
-        print("uuid: ", uuid)
+        print("uuid: ", id_)
         d = {'id': id_, 'text': texts, "text_type": text_type}  # bind the text to the query id
         load_request_path = './request_data_logs/{}.json'.format(id_)
