
14 changed files with 621 additions and 48 deletions
@@ -0,0 +1,250 @@
# -*- coding: utf-8 -*-

"""
@Time    : 2023/3/3 14:22
@Author  :
@FileName:
@Software:
@Describe:
"""
import os
# os.environ["TF_KERAS"] = "1"
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from flask import Flask, jsonify
from predict_t5 import autotitle
import re
import json
from tqdm import tqdm


db_key_query = 'query'
db_key_result = 'result'
batch_size = 32

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

pattern = r"[。]"
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
fuhao_end_sentence = ["。", ",", "?", "!", "…"]

config = {
    "batch_size": 1000
}


def get_dialogs_index(line: str):
    """
    Get the quoted dialogs in a line and their character indices.
    :param line: the input text
    :return: dialogs_text: the dialog contents
             dialogs_index: character positions inside dialogs
             other_index: character positions outside dialogs
    """
    dialogs = re.finditer(RE_DIALOG, line)
    dialogs_text = re.findall(RE_DIALOG, line)
    dialogs_index = []
    for dialog in dialogs:
        all_ = [i for i in range(dialog.start(), dialog.end())]
        dialogs_index.extend(all_)
    other_index = [i for i in range(len(line)) if i not in dialogs_index]

    return dialogs_text, dialogs_index, other_index
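
# Illustrative sketch (not part of the original diff): expected behavior of
# get_dialogs_index on a line containing one quoted dialog.
#     line = '他说:“今天。下雨。”然后走了'
#     dialogs_text, dialogs_index, other_index = get_dialogs_index(line)
#     # dialogs_text  -> ['“今天。下雨。”']
#     # dialogs_index -> the character positions covered by the quote
#     # other_index   -> every remaining position in the line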


def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
    """Recursively split a sentence longer than 120 characters at the last
    punctuation mark (outside dialog quotes) before the 120-character boundary."""
    fuhao = [",", "?", "!", "…"]
    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    text_1 = text[:120]
    text_2 = text[120:]
    text_1_new = ""
    if text_2 == "":
        chulipangban_return_list.append([text_1, snetence_id, short_num])
        return chulipangban_return_list
    for i in range(len(text_1) - 1, -1, -1):
        if text_1[i] in fuhao:
            if i in dialogs_index:  # never split inside a dialog
                continue
            text_1_new = text_1[:i]
            text_1_new += text_1[i]
            chulipangban_return_list.append([text_1_new, snetence_id, short_num])
            if text_2 != "":
                if i + 1 != 120:
                    text_2 = text_1[i + 1:] + text_2  # carry the tail over to the remainder
            break
    if text_1_new == "":  # no usable punctuation found: hard cut at 120
        chulipangban_return_list.append([text_1, snetence_id, short_num])
    if text_2 != "":
        short_num += 1
        chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num)
    return chulipangban_return_list
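
# Illustrative sketch (not part of the original diff): splitting a sentence
# longer than 120 characters.
#     long_sentence = "甲" * 100 + "," + "乙" * 50   # 151 characters, one comma
#     chulichangju_1(long_sentence, 0, [], 0)
#     # -> [["甲" * 100 + ",", 0, 0], ["乙" * 50, 0, 1]]
#     # short_num marks continuation chunks; predict_data_post_processing
#     # later re-joins every chunk whose short_num != 0 onto the previous one.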


def chulipangban_test_1(snetence_id, text):
    # Quote handling: replace "。" inside dialogs with "&" so the sentence
    # split below never breaks quoted speech apart ("&" is restored in
    # one_predict before the model is called).
    dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
    for dialogs_text_dan in dialogs_text:
        text_dan_list = text.split(dialogs_text_dan)
        if "。" in dialogs_text_dan:
            dialogs_text_dan = str(dialogs_text_dan).replace("。", "&")
        text = dialogs_text_dan.join(text_dan_list)

    sentence_list = text.split("。")
    sentence_batch_list = []
    for sentence in sentence_list:
        if len(sentence) < 120:
            sentence_batch_list.append([sentence, snetence_id, 0])
        else:
            sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0)
            for sentence_short in sentence_split_list:
                sentence_batch_list.append(sentence_short)
    return sentence_batch_list


def paragraph_test_(text: list, text_new: list):
    # Earlier list-based variant, kept for reference; paragraph_test below is
    # the entry point used by main.
    for i in range(len(text)):
        sentence_list = chulipangban_test_1(i, text[i])
        text_new.append("。".join(sentence[0] for sentence in sentence_list))

    return text_new


def paragraph_test(texts: dict):
    text_new = []
    for i, text in texts.items():
        try:
            text_list = chulipangban_test_1(i, text)
        except Exception:
            print(i, text)
            continue
        text_new.extend(text_list)

    return text_new


def batch_data_process(text_list):
    """Group the [text, id, short_num] triples into batches of roughly
    500 characters."""
    sentence_batch_length = 0
    sentence_batch_one = []
    sentence_batch_list = []

    for sentence in text_list:
        sentence_batch_length += len(sentence[0])
        sentence_batch_one.append(sentence)
        if sentence_batch_length > 500:
            # the item that crossed the boundary opens the next batch
            sentence_ = sentence_batch_one.pop(-1)
            sentence_batch_list.append(sentence_batch_one)
            sentence_batch_one = [sentence_]
            sentence_batch_length = len(sentence_[0])
    sentence_batch_list.append(sentence_batch_one)
    return sentence_batch_list
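
# Illustrative sketch (not part of the original diff): batching keeps adding
# triples until a batch passes 500 characters, then starts the next batch
# with the item that crossed the boundary.
#     items = [["a" * 200, 0, 0], ["b" * 200, 0, 0], ["c" * 200, 1, 0]]
#     batch_data_process(items)
#     # -> [[["a" * 200, 0, 0], ["b" * 200, 0, 0]], [["c" * 200, 1, 0]]]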


def batch_predict(batch_data_list):
    """
    Predict one batch of data.
    @param batch_data_list: list of [text, id, short_num] triples
    @return: the same triples with the text replaced by the prediction
    """
    batch_data_list_new = []
    batch_data_text_list = []
    batch_data_snetence_id_list = []
    for i in batch_data_list:
        batch_data_text_list.append(i[0])
        batch_data_snetence_id_list.append(i[1:])
    # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
    batch_pre_data_list = batch_data_text_list  # placeholder for the batched model call
    for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list):
        batch_data_list_new.append([text] + sentence_id)

    return batch_data_list_new


def one_predict(data_text):
    """
    Predict a single item.
    @param data_text: a [text, id, short_num] triple
    @return: the triple with the text replaced by the prediction
    """
    if data_text[0] != "":
        data_inputs = data_text[0].replace("&", "。")  # restore the "。" protected inside dialogs
        pre_data = autotitle.generate(data_inputs)
    else:
        pre_data = ""
    data_new = [pre_data] + data_text[1:]
    return data_new


def predict_data_post_processing(text_list):
    """Re-join split long sentences (short_num != 0), then group the sentences
    back into paragraphs keyed by their original id."""
    text_list_sentence = []

    for i in range(len(text_list)):
        if text_list[i][2] != 0:
            text_list_sentence[-1][0] += text_list[i][0]
        else:
            text_list_sentence.append([text_list[i][0], text_list[i][1]])

    return_list = {}
    sentence_one = []
    sentence_id = "0"
    for i in text_list_sentence:
        if i[1] == sentence_id:
            sentence_one.append(i[0])
        else:
            return_list[sentence_id] = "。".join(sentence_one)
            sentence_id = i[1]
            sentence_one = [i[0]]
    if sentence_one != []:
        return_list[sentence_id] = "。".join(sentence_one)
    return return_list
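
# Illustrative sketch (not part of the original diff): post-processing first
# re-joins split chunks, then regroups sentences into paragraphs by id.
#     preds = [["s1", "0", 0], ["s1b", "0", 1], ["s2", "0", 0], ["t1", "1", 0]]
#     predict_data_post_processing(preds)
#     # -> {"0": "s1s1b。s2", "1": "t1"}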


# def main(text: list):
#     text_list = paragraph_test(text)
#     batch_data = batch_data_process(text_list)
#     text_list = []
#     for i in batch_data:
#         text_list.extend(i)
#     return_list = predict_data_post_processing(text_list)
#     return return_list


def main(text: dict):
    text_list = paragraph_test(text)
    text_list_new = []
    for i in tqdm(text_list):
        pre = one_predict(i)
        text_list_new.append(pre)
    return_list = predict_data_post_processing(text_list_new)
    return return_list


if __name__ == '__main__':
    filename = './data/yy_data.json'
    with open(filename, encoding="utf-8") as file_obj:
        yy_data = json.load(file_obj)
    rels = main(yy_data)
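A minimal driver sketch (assumption: `predict_t5.autotitle` is importable), mirroring the `__main__` block above; `main` takes a dict of paragraph id to paragraph text and returns the rewritten paragraphs keyed by the same ids:

    sample = {"0": "第一句。第二句很长……", "1": "另一段。"}
    rewritten = main(sample)
    # rewritten -> {"0": "...", "1": "..."}, keyed by the original ids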
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-

"""
@Time    : 2023/3/3 14:55
@Author  :
@FileName:
@Software:
@Describe:
"""
import pandas as pd
import json


yy_data_1 = "../data/论文_yy_小说.xlsx"
yy_data_2 = "../data/论文_yy_小说_1.xlsx"
df_1 = pd.read_excel(yy_data_1).values.tolist()
df_2 = pd.read_excel(yy_data_2).values.tolist()
df = df_1 + df_2

# Build {"0": text, "1": text, ...} from the first column of each row.
return_data = {}
for i in range(len(df)):
    return_data[str(i)] = df[i][0]

filename = 'yy_data.json'
with open(filename, 'w', encoding="utf-8") as file_obj:
    json.dump(return_data, file_obj, ensure_ascii=False)  # keep the Chinese text readable
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-

"""
@Time    : 2023/3/2 19:31
@Author  :
@FileName:
@Software:
@Describe:
"""
import flask
import redis
import uuid
import json
from threading import Thread
import time

app = flask.Flask(__name__)
pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50)
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

db_key_query = 'query'
db_key_result = 'result'
batch_size = 32


def classify():  # model worker loop
    while True:
        if redis_.llen(db_key_query) == 0:  # busy-wait until the queue has an element
            continue
        query = redis_.lpop(db_key_query).decode('UTF-8')  # take one query off the queue
        data_dict = json.loads(query)
        query_id = data_dict['id']
        text = data_dict['text']
        result = text + "1111111"  # placeholder for the model call
        time.sleep(5)
        redis_.set(query_id, json.dumps(result))  # write the result back under the query id


@app.route("/predict", methods=["POST"])
def handle_query():
    text = flask.request.json['text']  # the text of the user query, e.g. "I love you"
    id_ = str(uuid.uuid1())  # unique identifier for this query
    print(id_)
    d = {'id': id_, 'text': text}  # bind the text to the query id
    redis_.rpush(db_key_query, json.dumps(d))  # enqueue in redis
    # Non-blocking: return the id at once; the blocking variant that polls
    # redis until the result appears lives in the batch version of this service.
    result_text = {'id': id_, 'text': text}
    return flask.jsonify(result_text)


if __name__ == "__main__":
    t = Thread(target=classify)
    t.start()
    app.run(debug=False, host='127.0.0.1', port=9000)
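A minimal client sketch for the endpoint above (assumption: the `requests` package is installed); the service returns the query id immediately, and the result is fetched later through the /search lookup service:

    import requests

    resp = requests.post("http://127.0.0.1:9000/predict", json={"text": "I love you"})
    query_id = resp.json()["id"]  # keep this id to poll for the result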
@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-

"""
@Time    : 2023/3/2 16:40
@Author  :
@FileName:
@Software:
@Describe:
"""
@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-

"""
@Time    : 2023/3/6 17:58
@Author  :
@FileName:
@Software:
@Describe:
"""
# number of parallel workers
workers = 1
# bind address and port; change as needed
bind = '0.0.0.0:14001'

worker_class = "gevent"
# daemonize: keep the service running after the terminal closes
daemon = True
# timeout of 120s (gunicorn default is 30s); adjust as needed
timeout = 120
# access and error log paths
accesslog = './check_logs/access.log'
errorlog = './check_logs/error.log'
# access_log_format = '%(h)s - %(t)s - %(u)s - %(s)s %(H)s'
# errorlog = '-'  # log to stdout instead
@@ -0,0 +1,15 @@
# number of parallel workers
workers = 1
# bind address and port; change as needed
bind = '0.0.0.0:14000'

worker_class = "gevent"
# daemonize: keep the service running after the terminal closes
daemon = True
# timeout of 120s (gunicorn default is 30s); adjust as needed
timeout = 120
# access and error log paths
accesslog = './logs/access.log'
errorlog = './logs/error.log'
# access_log_format = '%(h)s - %(t)s - %(u)s - %(s)s %(H)s'
# errorlog = '-'  # log to stdout instead
@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-

"""
@Time    : 2023/3/2 19:31
@Author  :
@FileName:
@Software:
@Describe:
"""
import flask
import redis
import uuid
import json
from threading import Thread
import time

app = flask.Flask(__name__)
pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50)
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

db_key_query = 'query'
db_key_result = 'result'
batch_size = 32


def classify(batch_size):  # model worker; drains up to batch_size queries per loop
    while True:
        texts = []
        query_ids = []
        if redis_.llen(db_key_query) == 0:  # busy-wait until the queue has an element
            continue
        for i in range(min(redis_.llen(db_key_query), batch_size)):
            query = redis_.lpop(db_key_query).decode('UTF-8')  # take one query off the queue
            query_ids.append(json.loads(query)['id'])
            texts.append(json.loads(query)['text'])  # collect the texts into one batch
        result = texts  # placeholder for the batched model call
        time.sleep(20)
        for (id_, res) in zip(query_ids, result):
            # res['score'] = str(res['score'])  # applies when the model returns dicts with a score
            redis_.set(id_, json.dumps(res))  # write each result back under its query id


@app.route("/predict", methods=["POST"])
def handle_query():
    text = flask.request.json['text']  # the text of the user query, e.g. "I love you"
    id_ = str(uuid.uuid1())  # unique identifier for this query
    d = {'id': id_, 'text': text}  # bind the text to the query id
    redis_.rpush(db_key_query, json.dumps(d))  # enqueue in redis
    while True:
        result = redis_.get(id_)  # poll for this query's model result
        if result is not None:
            redis_.delete(id_)
            result_text = {'code': "200", 'data': result.decode('UTF-8')}
            break
    return flask.jsonify(result_text)


if __name__ == "__main__":
    t = Thread(target=classify, args=(batch_size,))
    t.start()
    app.run(debug=False, host='127.0.0.1', port=9000)
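A minimal client sketch for this blocking variant (assumption: the `requests` package is installed); here /predict only returns once the worker thread has written the result back to redis, so a single call suffices:

    import requests

    resp = requests.post("http://127.0.0.1:9000/predict", json={"text": "I love you"})
    print(resp.json())  # {'code': '200', 'data': ...} once the batch has been processed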
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-

"""
@Time    : 2023/3/2 19:31
@Author  :
@FileName:
@Software:
@Describe:
"""
import flask
import redis
import json

app = flask.Flask(__name__)
pool = redis.ConnectionPool(host='localhost', port=6379, max_connections=50, db=1)  # note: this service reads redis db 1
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)


@app.route("/search", methods=["POST"])
def handle_query():
    id_ = flask.request.json['id']  # the query id issued by /predict
    result = redis_.get(id_)  # fetch the model result for that id
    if result is not None:
        redis_.delete(id_)
        result_dict = json.loads(result.decode('UTF-8'))
        # the worker is expected to store {"texts": ..., "probabilities": ..., "status_code": ...}
        texts = result_dict["texts"]
        probabilities = result_dict["probabilities"]
        status_code = result_dict["status_code"]
        result_text = {'code': status_code, 'text': texts, 'probabilities': probabilities}
    else:
        result_text = {'code': "201", 'text': "", 'probabilities': None}  # not ready yet
    return flask.jsonify(result_text)


if __name__ == "__main__":
    app.run(debug=False, host='0.0.0.0', port=14001)
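A minimal polling sketch against this lookup service as it is wired here (assumptions: the `requests` package, and `query_id` as returned by the /predict endpoint above); code "201" means the result is not ready yet, so the client retries:

    import requests
    import time

    while True:
        resp = requests.post("http://127.0.0.1:14001/search", json={"id": query_id}).json()
        if resp["code"] != "201":
            break
        time.sleep(1)
    print(resp["text"], resp["probabilities"])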
@@ -0,0 +1 @@
gunicorn flask_predict_no_batch_t5:app -c gunicorn_config.py
@@ -0,0 +1 @@
gunicorn redis_check_uuid:app -c gunicorn_check_uuid_config.py