Browse Source

完成t5单条预测版本,预测单条为json格式,每句话设置id

master
majiahui@haimaqingfan.com 2 years ago
parent
commit
0b74eea4df
  1. 21
      flask_predict_no_batch_t5.py

21
flask_predict_no_batch_t5.py

@ -74,7 +74,7 @@ def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
return chulipangban_return_list return chulipangban_return_list
def chulipangban_test_1(text, snetence_id): def chulipangban_test_1(snetence_id, text):
# 引号处理 # 引号处理
dialogs_text, dialogs_index, other_index = get_dialogs_index(text) dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
@ -120,11 +120,11 @@ def paragraph_test_(text:list, text_new:list):
# text_new_str = "".join(text_new) # text_new_str = "".join(text_new)
return text_new return text_new
def paragraph_test(text:list): def paragraph_test(texts:dict):
text_new = [] text_new = []
for i in range(len(text)): for i, text in texts.items():
text_list = chulipangban_test_1(text[i], i) text_list = chulipangban_test_1(i, text)
text_new.extend(text_list) text_new.extend(text_list)
# text_new_str = "".join(text_new) # text_new_str = "".join(text_new)
@ -186,25 +186,26 @@ def one_predict(data_text):
def predict_data_post_processing(text_list): def predict_data_post_processing(text_list):
text_list_sentence = [] text_list_sentence = []
# text_list_sentence.append([text_list[0][0], text_list[0][1]]) # text_list_sentence.append([text_list[0][0], text_list[0][1]])
for i in range(len(text_list)): for i in range(len(text_list)):
if text_list[i][2] != 0: if text_list[i][2] != 0:
text_list_sentence[-1][0] += text_list[i][0] text_list_sentence[-1][0] += text_list[i][0]
else: else:
text_list_sentence.append([text_list[i][0], text_list[i][1]]) text_list_sentence.append([text_list[i][0], text_list[i][1]])
return_list = [] return_list = {}
sentence_one = [] sentence_one = []
sentence_id = 0 sentence_id = "0"
for i in text_list_sentence: for i in text_list_sentence:
if i[1] == sentence_id: if i[1] == sentence_id:
sentence_one.append(i[0]) sentence_one.append(i[0])
else: else:
return_list[sentence_id] = "".join(sentence_one)
sentence_id = i[1] sentence_id = i[1]
return_list.append("".join(sentence_one))
sentence_one = [] sentence_one = []
sentence_one.append(i[0]) sentence_one.append(i[0])
if sentence_one != []: if sentence_one != []:
return_list.append("".join(sentence_one)) return_list[sentence_id] = "".join(sentence_one)
return return_list return return_list
@ -217,7 +218,7 @@ def predict_data_post_processing(text_list):
# # return_list = predict_data_post_processing(text_list) # # return_list = predict_data_post_processing(text_list)
# # return return_list # # return return_list
def main(text: list): def main(text: dict):
text_list = paragraph_test(text) text_list = paragraph_test(text)
text_list_new = [] text_list_new = []
for i in text_list: for i in text_list:
@ -235,7 +236,7 @@ def sentence():
# question = question.strip('。、!??') # question = question.strip('。、!??')
if isinstance(texts, list): if isinstance(texts, dict):
texts_list = [] texts_list = []
y_pred_label_list = [] y_pred_label_list = []
position_list = [] position_list = []

Loading…
Cancel
Save