Browse Source

第一次提交

master
majiahui@haimaqingfan.com 2 years ago
commit
b19375c6ff
  1. 0
      README.md
  2. 83
      articles_directory_predict.py
  3. 414
      flask_predict_batch_mistral.py
  4. 67
      flask_predict_mistral_vllm.py
  5. 24
      main.py
  6. 87
      redis_check_uuid_mistral.py
  7. 1
      run_app_nohub_flask_predict_batch_mistral.sh
  8. 1
      run_app_nohub_search_redis.sh
  9. 122
      vllm_predict_batch.py

0
README.md

83
articles_directory_predict.py

@ -0,0 +1,83 @@
from flask import Flask, jsonify
from flask import request
from transformers import pipeline
import redis
import uuid
import json
from threading import Thread
from vllm import LLM, SamplingParams
import time
import threading
import time
import concurrent.futures
import requests
import socket
def get_host_ip():
"""
查询本机ip地址
:return: ip
"""
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(('8.8.8.8', 80))
ip = s.getsockname()[0]
finally:
s.close()
return ip
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
def dialog_line_parse(url, text):
"""
将数据输入模型进行分析并输出结果
:param url: 模型url
:param text: 进入模型的数据
:return: 模型返回结果
"""
response = requests.post(
url,
json=text,
timeout=1000
)
if response.status_code == 200:
return response.json()
else:
# logger.error(
# "【{}】 Failed to get a proper response from remote "
# "server. Status Code: {}. Response: {}"
# "".format(url, response.status_code, response.text)
# )
print("{}】 Failed to get a proper response from remote "
"server. Status Code: {}. Response: {}"
"".format(url, response.status_code, response.text))
print(text)
return []
@app.route("/articles_directory", methods=["POST"])
def articles_directory():
text = request.json["texts"] # 获取用户query中的文本 例如"I love you"
nums = request.json["nums"]
nums = int(nums)
url = "http://{}:18001/predict".format(str(get_host_ip()))
input_data = []
for i in range(nums):
input_data.append([url, {"texts": "You are a helpful assistant.\n\nUser:{}\nAssistant:".format(text)}])
with concurrent.futures.ThreadPoolExecutor() as executor:
# 使用submit方法将任务提交给线程池,并获取Future对象
futures = [executor.submit(dialog_line_parse, i[0], i[1]) for i in input_data]
# 使用as_completed获取已完成的任务,并获取返回值
results = [future.result() for future in concurrent.futures.as_completed(futures)]
return jsonify(results) # 返回结果
if __name__ == "__main__":
app.run(debug=False, host='0.0.0.0', port=18000)

414
flask_predict_batch_mistral.py

@ -0,0 +1,414 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from flask import Flask, jsonify
from flask import request
import requests
import redis
import uuid
import json
from threading import Thread
import time
import re
import logging
from vllm import LLM, SamplingParams
logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别
filename='rewrite.log',
filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志
# a是追加模式,默认如果不写的话,就是追加模式
format=
'%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
# 日志格式
)
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=7, password="zhicheng123*")
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
db_key_query = 'query'
db_key_querying = 'querying'
db_key_queryset = 'queryset'
batch_size = 32
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
import logging
pattern = r"[。]"
RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”")
fuhao_end_sentence = ["", "", "", "", ""]
# 加载模型
sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=1.1,stop="</s>", max_tokens=4096)
models_path = "/home/majiahui/model-llm/openbuddy-mistral-7b-v13.1"
llm = LLM(model=models_path, tokenizer_mode="slow")
class log:
def __init__(self):
pass
def log(*args, **kwargs):
format = '%Y/%m/%d-%H:%M:%S'
format_h = '%Y-%m-%d'
value = time.localtime(int(time.time()))
dt = time.strftime(format, value)
dt_log_file = time.strftime(format_h, value)
log_file = 'log_file/access-%s' % dt_log_file + ".log"
if not os.path.exists(log_file):
with open(os.path.join(log_file), 'w', encoding='utf-8') as f:
print(dt, *args, file=f, **kwargs)
else:
with open(os.path.join(log_file), 'a+', encoding='utf-8') as f:
print(dt, *args, file=f, **kwargs)
def get_dialogs_index(line: str):
"""
获取对话及其索引
:param line 文本
:return dialogs 对话内容
dialogs_index: 对话位置索引
other_index: 其他内容位置索引
"""
dialogs = re.finditer(RE_DIALOG, line)
dialogs_text = re.findall(RE_DIALOG, line)
dialogs_index = []
for dialog in dialogs:
all_ = [i for i in range(dialog.start(), dialog.end())]
dialogs_index.extend(all_)
other_index = [i for i in range(len(line)) if i not in dialogs_index]
return dialogs_text, dialogs_index, other_index
def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num):
fuhao = ["", "", "", ""]
dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
text_1 = text[:120]
text_2 = text[120:]
text_1_new = ""
if text_2 == "":
chulipangban_return_list.append([text_1, snetence_id, short_num])
return chulipangban_return_list
for i in range(len(text_1) - 1, -1, -1):
if text_1[i] in fuhao:
if i in dialogs_index:
continue
text_1_new = text_1[:i]
text_1_new += text_1[i]
chulipangban_return_list.append([text_1_new, snetence_id, short_num])
if text_2 != "":
if i + 1 != 120:
text_2 = text_1[i + 1:] + text_2
break
# else:
# chulipangban_return_list.append(text_1)
if text_1_new == "":
chulipangban_return_list.append([text_1, snetence_id, short_num])
if text_2 != "":
short_num += 1
chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num)
return chulipangban_return_list
def chulipangban_test_1(snetence_id, text):
# 引号处理
dialogs_text, dialogs_index, other_index = get_dialogs_index(text)
for dialogs_text_dan in dialogs_text:
text_dan_list = text.split(dialogs_text_dan)
text = dialogs_text_dan.join(text_dan_list)
# text_new_str = "".join(text_new)
sentence_list = text.split("")
# sentence_list_new = []
# for i in sentence_list:
# if i != "":
# sentence_list_new.append(i)
# sentence_list = sentence_list_new
sentence_batch_list = []
sentence_batch_one = []
sentence_batch_length = 0
return_list = []
for sentence in sentence_list[:-1]:
if len(sentence) < 120:
sentence_batch_length += len(sentence)
sentence_batch_list.append([sentence + "", snetence_id, 0])
# sentence_pre = autotitle.gen_synonyms_short(sentence)
# return_list.append(sentence_pre)
else:
sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0)
for sentence_short in sentence_split_list[:-1]:
sentence_batch_list.append(sentence_short)
sentence_batch_list.append(sentence_split_list[-1] + "")
if sentence_list[:-1] != "":
if len(sentence_list[-1]) < 120:
sentence_batch_length += len(sentence_list[-1])
sentence_batch_list.append([sentence_list[-1], snetence_id, 0])
# sentence_pre = autotitle.gen_synonyms_short(sentence)
# return_list.append(sentence_pre)
else:
sentence_split_list = chulichangju_1(sentence_list[-1], snetence_id, [], 0)
for sentence_short in sentence_split_list:
sentence_batch_list.append(sentence_short)
return sentence_batch_list
def paragraph_test(texts: dict):
text_new = []
for i, text in texts.items():
text_list = chulipangban_test_1(i, text)
text_new.extend(text_list)
# text_new_str = "".join(text_new)
return text_new
def batch_data_process(text_list):
sentence_batch_length = 0
sentence_batch_one = []
sentence_batch_list = []
for sentence in text_list:
sentence_batch_length += len(sentence[0])
sentence_batch_one.append(sentence)
if sentence_batch_length > 500:
sentence_batch_length = 0
sentence_ = sentence_batch_one.pop(-1)
sentence_batch_list.append(sentence_batch_one)
sentence_batch_one = []
sentence_batch_one.append(sentence_)
sentence_batch_list.append(sentence_batch_one)
return sentence_batch_list
def batch_predict(batch_data_list):
'''
一个bacth数据预测
@param data_text:
@return:
'''
batch_data_list_new = []
batch_data_text_list = []
batch_data_snetence_id_list = []
for i in batch_data_list:
batch_data_text_list.append(i[0])
batch_data_snetence_id_list.append(i[1:])
# batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list)
batch_pre_data_list = batch_data_text_list
for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list):
batch_data_list_new.append([text] + sentence_id)
return batch_data_list_new
def predict_data_post_processing(text_list):
text_list_sentence = []
# text_list_sentence.append([text_list[0][0], text_list[0][1]])
for i in range(len(text_list)):
if text_list[i][2] != 0:
text_list_sentence[-1][0] += text_list[i][0]
else:
text_list_sentence.append([text_list[i][0], text_list[i][1]])
return_list = {}
sentence_one = []
sentence_id = text_list_sentence[0][1]
for i in text_list_sentence:
if i[1] == sentence_id:
sentence_one.append(i[0])
else:
return_list[sentence_id] = "".join(sentence_one)
sentence_id = i[1]
sentence_one = []
sentence_one.append(i[0])
if sentence_one != []:
return_list[sentence_id] = "".join(sentence_one)
return return_list
# def main(text:list):
# # text_list = paragraph_test(text)
# # batch_data = batch_data_process(text_list)
# # text_list = []
# # for i in batch_data:
# # text_list.extend(i)
# # return_list = predict_data_post_processing(text_list)
# # return return_list
def pre_sentence_ulit(sentence):
if "改写后:" in sentence:
sentence_lable_index = sentence.index("改写后:")
sentence = sentence[sentence_lable_index + 4:]
return sentence
def main(texts: dict):
text_list = paragraph_test(texts)
text_info = []
text_sentence = []
text_list_new = []
# for i in text_list:
# pre = one_predict(i)
# text_list_new.append(pre)
# vllm预测
for i in text_list:
if len(i[0]) > 7:
text = "You are a helpful assistant.\n\nUser:改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{}\nAssistant:".format(i[0])
else:
text = "You are a helpful assistant.\n\nUser:下面词不做任何变化:\n{}\nAssistant:".format(i[0])
text_sentence.append(text)
text_info.append([i[1], i[2]])
outputs = llm.generate(text_sentence, sampling_params) # 调用模型
generated_text_list = [""] * len(text_sentence)
# generated_text_list = ["" if len(i[0]) > 5 else i[0] for i in text_list]
for i, output in enumerate(outputs):
index = output.request_id
generated_text = output.outputs[0].text
generated_text = pre_sentence_ulit(generated_text)
generated_text_list[int(index)] = generated_text
for i in range(len(text_list)):
if len(text_list[i][0]) > 7:
continue
else:
generated_text_list[i] = text_list[i][0]
for i, j in zip(generated_text_list, text_info):
text_list_new.append([i] + j)
return_list = predict_data_post_processing(text_list_new)
return return_list
# @app.route('/droprepeat/', methods=['POST'])
# def sentence():
# print(request.remote_addr)
# texts = request.json["texts"]
# text_type = request.json["text_type"]
# print("原始语句" + str(texts))
# # question = question.strip('。、!??')
#
# if isinstance(texts, dict):
# texts_list = []
# y_pred_label_list = []
# position_list = []
#
# # texts = texts.replace('\'', '\"')
# if texts is None:
# return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False}
# return jsonify(return_text)
# else:
# assert text_type in ['focus', 'chapter']
# if text_type == 'focus':
# texts_list = main(texts)
# if text_type == 'chapter':
# texts_list = main(texts)
# return_text = {"texts": texts_list, "probabilities": None, "status_code": True}
# else:
# return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False}
# return jsonify(return_text)
def classify(): # 调用模型,设置最大batch_size
while True:
if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取
time.sleep(3)
continue
query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text
data_dict_path = json.loads(query)
path = data_dict_path['path']
# text_type = data_dict["text_type"]
with open(path, encoding='utf8') as f1:
# 加载文件的对象
data_dict = json.load(f1)
query_id = data_dict['id']
texts = data_dict["text"]
text_type = data_dict["text_type"]
assert text_type in ['focus', 'chapter']
if text_type == 'focus':
texts_list = main(texts)
elif text_type == 'chapter':
texts_list = main(texts)
else:
texts_list = []
return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
load_result_path = "./new_data_logs/{}.json".format(query_id)
print("query_id: ", query_id)
print("load_result_path: ", load_result_path)
with open(load_result_path, 'w', encoding='utf8') as f2:
# ensure_ascii=False才能输入中文,否则是Unicode字符
# indent=2 JSON数据的缩进,美观
json.dump(return_text, f2, ensure_ascii=False, indent=4)
debug_id_1 = 1
redis_.set(query_id, load_result_path, 86400)
debug_id_2 = 2
redis_.srem(db_key_querying, query_id)
debug_id_3 = 3
log.log('start at',
'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format(
query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3))
@app.route("/predict", methods=["POST"])
def handle_query():
print(request.remote_addr)
texts = request.json["texts"]
text_type = request.json["text_type"]
if texts is None:
return_text = {"texts": "输入了空值", "probabilities": None, "status_code": 402}
return jsonify(return_text)
if isinstance(texts, dict):
id_ = str(uuid.uuid1()) # 为query生成唯一标识
print("uuid: ", uuid)
d = {'id': id_, 'text': texts, "text_type": text_type} # 绑定文本和query id
load_request_path = './request_data_logs/{}.json'.format(id_)
with open(load_request_path, 'w', encoding='utf8') as f2:
# ensure_ascii=False才能输入中文,否则是Unicode字符
# indent=2 JSON数据的缩进,美观
json.dump(d, f2, ensure_ascii=False, indent=4)
redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path})) # 加入redis
redis_.sadd(db_key_querying, id_)
redis_.sadd(db_key_queryset, id_)
return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200}
print("ok")
else:
return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": 401}
return jsonify(return_text) # 返回结果
t = Thread(target=classify)
t.start()
if __name__ == "__main__":
logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别
filename='rewrite.log',
filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志
# a是追加模式,默认如果不写的话,就是追加模式
format=
'%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
# 日志格式
)
app.run(host="0.0.0.0", port=14002, threaded=True, debug=False)

67
flask_predict_mistral_vllm.py

@ -0,0 +1,67 @@
import flask
from transformers import pipeline
import redis
import uuid
import json
from threading import Thread
import time
app = flask.Flask(__name__)
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=5, password="zhicheng123*")
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
db_key_query = 'query'
db_key_result = 'result'
sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=1.1,stop="</s>", max_tokens=4096)
models_path = "/home/majiahui/model-llm/openbuddy-mistral-7b-v13.1"
llm = LLM(model=models_path, tokenizer_mode="slow")
def mistral_vllm_models(texts):
outputs = llm.generate(texts, sampling_params) # 调用模型
generated_text_list = [""] * len(texts)
# generated_text_list = ["" if len(i[0]) > 5 else i[0] for i in text_list]
for i, output in enumerate(outputs):
index = output.request_id
generated_text = output.outputs[0].text
generated_text_list[int(index)] = generated_text
def classify(batch_size): # 调用模型,设置最大batch_size
while True:
texts = []
query_ids = []
if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取
continue
for i in range(min(redis_.llen(db_key_query), batch_size)):
query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text
query_ids.append(json.loads(query)['id'])
texts.append(json.loads(query)['text']) # 拼接若干text 为batch
result = mistral_vllm_models(texts) # 调用模型
for (id_, res) in zip(query_ids, result):
res['score'] = str(res['score'])
redis_.set(id_, json.dumps(res)) # 将模型结果送回队列
@app.route("/predict", methods=["POST"])
def handle_query():
text = flask.request.form['text'] # 获取用户query中的文本 例如"I love you"
id_ = str(uuid.uuid1()) # 为query生成唯一标识
d = {'id': id_, 'text': text} # 绑定文本和query id
redis_.rpush(db_key_query, json.dumps(d)) # 加入redis
while True:
result = redis_.get(id_) # 获取该query的模型结果
if result is not None:
redis_.delete(id_)
result_text = {'code': "200", 'data': result.decode('UTF-8')}
break
return flask.jsonify(result_text) # 返回结果
if __name__ == "__main__":
t = Thread(target=classify)
t.start()
app.run(debug=False, host='127.0.0.1', port=9000)

24
main.py

@ -0,0 +1,24 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
from vllm import LLM, SamplingParams
# Sample prompts.
prompts = [
"You are a helpful assistant.\n\nUser:张亮的爸爸叫张明,张明的爸爸有三个孩子,大儿子叫张大,二儿子叫张昊,三儿子叫什么?\nAssistant:",
"You are a helpful assistant.\n\nUser:你好\nAssistant:",
"You are a helpful assistant.\n\nUser:1+1等于几\nAssistant:",
"You are a helpful assistant.\n\nUser:你是谁\nAssistant:",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0, top_p=1, presence_penalty=0.9, max_tokens=1024)
# Create an LLM.
llm = LLM(model="/home/majiahui/project/models-llm/openbuddy-mistral-7b-v13.1", trust_remote_code=True)
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

87
redis_check_uuid_mistral.py

@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
"""
@Time : 2023/3/2 19:31
@Author :
@FileName:
@Software:
@Describe:
"""
#
# import redis
#
# redis_pool = redis.ConnectionPool(host='127.0.0.1', port=6379, password='', db=0)
# redis_conn = redis.Redis(connection_pool=redis_pool)
#
#
# name_dict = {
# 'name_4' : 'Zarten_4',
# 'name_5' : 'Zarten_5'
# }
# redis_conn.mset(name_dict)
import flask
import redis
import uuid
import json
from threading import Thread
import time
app = flask.Flask(__name__)
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=7, password="zhicheng123*")
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
db_key_query = 'query'
db_key_querying = 'querying'
@app.route("/search", methods=["POST"])
def handle_query():
id_ = flask.request.json['id'] # 获取用户query中的文本 例如"I love you"
result = redis_.get(id_) # 获取该query的模型结果
if result is not None:
# redis_.delete(id_)
result_path = result.decode('UTF-8')
with open(result_path, encoding='utf8') as f1:
# 加载文件的对象
result_dict = json.load(f1)
texts = result_dict["texts"]
probabilities = result_dict["probabilities"]
result_text = {'code': 200, 'text': texts, 'probabilities': probabilities}
else:
querying_list = list(redis_.smembers("querying"))
querying_set = set()
for i in querying_list:
querying_set.add(i.decode())
querying_bool = False
if id_ in querying_set:
querying_bool = True
query_list_json = redis_.lrange(db_key_query, 0, -1)
query_set_ids = set()
for i in query_list_json:
data_dict = json.loads(i)
query_id = data_dict['id']
query_set_ids.add(query_id)
query_bool = False
if id_ in query_set_ids:
query_bool = True
if querying_bool == True and query_bool == True:
result_text = {'code': "201", 'text': "", 'probabilities': None}
elif querying_bool == True and query_bool == False:
result_text = {'code': "202", 'text': "", 'probabilities': None}
else:
result_text = {'code': "203", 'text': "", 'probabilities': None}
load_request_path = './request_data_logs_203/{}.json'.format(id_)
with open(load_request_path, 'w', encoding='utf8') as f2:
# ensure_ascii=False才能输入中文,否则是Unicode字符
# indent=2 JSON数据的缩进,美观
json.dump(result_text, f2, ensure_ascii=False, indent=4)
return flask.jsonify(result_text) # 返回结果
if __name__ == "__main__":
app.run(debug=False, host='0.0.0.0', port=14003)

1
run_app_nohub_flask_predict_batch_mistral.sh

@ -0,0 +1 @@
nohup python flask_predict_batch_mistral.py > myout.flask_predict_batch_mistral.logs 2>&1 &

1
run_app_nohub_search_redis.sh

@ -0,0 +1 @@
nohup python redis_check_uuid_mistral.py > myout.redis_check_uuid_mistral.logs 2>&1 &

122
vllm_predict_batch.py

@ -0,0 +1,122 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
from flask import Flask, jsonify
from flask import request
from transformers import pipeline
import redis
import uuid
import json
from threading import Thread
from vllm import LLM, SamplingParams
import time
import threading
import time
import concurrent.futures
import requests
import socket
def get_host_ip():
"""
查询本机ip地址
:return: ip
"""
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(('8.8.8.8', 80))
ip = s.getsockname()[0]
finally:
s.close()
return ip
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=50,db=11, password="zhicheng123*")
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
db_key_query = 'query'
db_key_query_articles_directory = 'query_articles_directory'
db_key_result = 'result'
batch_size = 32
sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=0.9,stop="</s>", max_tokens=4096)
models_path = "/home/majiahui/project/models-llm/openbuddy-mistral-7b-v13.1"
llm = LLM(model=models_path, tokenizer_mode="slow")
def dialog_line_parse(url, text):
"""
将数据输入模型进行分析并输出结果
:param url: 模型url
:param text: 进入模型的数据
:return: 模型返回结果
"""
response = requests.post(
url,
json=text,
timeout=1000
)
if response.status_code == 200:
return response.json()
else:
# logger.error(
# "【{}】 Failed to get a proper response from remote "
# "server. Status Code: {}. Response: {}"
# "".format(url, response.status_code, response.text)
# )
print("{}】 Failed to get a proper response from remote "
"server. Status Code: {}. Response: {}"
"".format(url, response.status_code, response.text))
print(text)
return []
def classify(batch_size): # 调用模型,设置最大batch_size
while True:
texts = []
query_ids = []
if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取
time.sleep(2)
continue
for i in range(min(redis_.llen(db_key_query), batch_size)):
query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text
query_ids.append(json.loads(query)['id'])
texts.append(json.loads(query)['text']) # 拼接若干text 为batch
outputs = llm.generate(texts, sampling_params) # 调用模型
generated_text_list = [""] * len(texts)
print("outputs", outputs)
for i, output in enumerate(outputs):
index = output.request_id
generated_text = output.outputs[0].text
generated_text_list[int(index)] = generated_text
for (id_, output) in zip(query_ids, generated_text_list):
res = output
redis_.set(id_, json.dumps(res)) # 将模型结果送回队列
@app.route("/predict", methods=["POST"])
def handle_query():
text = request.json["texts"] # 获取用户query中的文本 例如"I love you"
id_ = str(uuid.uuid1()) # 为query生成唯一标识
d = {'id': id_, 'text': text} # 绑定文本和query id
redis_.rpush(db_key_query, json.dumps(d)) # 加入redis
while True:
result = redis_.get(id_) # 获取该query的模型结果
if result is not None:
redis_.delete(id_)
result_text = {'code': "200", 'data': json.loads(result)}
break
time.sleep(1)
return jsonify(result_text) # 返回结果
t = Thread(target=classify, args=(batch_size,))
t.start()
if __name__ == "__main__":
app.run(debug=False, host='0.0.0.0', port=18001)
Loading…
Cancel
Save