Browse Source

第一个增强版上线

master
majiahui@haimaqingfan.com 2 years ago
parent
commit
f90fd82462
  1. 14
      flask_drop_rewrite_request.py
  2. 3
      flask_predict_batch_mistral.py
  3. 30
      flask_predict_mistral_vllm.py
  4. 5
      redis_check_uuid_mistral.py

14
flask_drop_rewrite_request.py

@ -302,7 +302,7 @@ def pre_sentence_ulit(sentence):
'''
sentence = str(sentence).strip()
if_change = True
if len(sentence) > 7:
if len(sentence) > 9:
text = "You are a helpful assistant.\n\nUser:改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{}\nAssistant:".format(sentence)
else:
text = "You are a helpful assistant.\n\nUser:下面词不做任何变化:\n{}\nAssistant:".format(sentence)
@ -322,6 +322,8 @@ def pre_sentence_ulit(sentence):
def main(texts: dict):
if texts == {"1": "0"}:
9/0
text_list = paragraph_test(texts)
text_info = []
@ -421,11 +423,17 @@ def classify(): # 调用模型,设置最大batch_size
if text_type == 'focus':
texts_list = main(texts)
elif text_type == 'chapter':
texts_list = main(texts)
try:
texts_list = main(texts)
except:
texts_list = []
else:
texts_list = []
if texts_list != []:
return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
else:
return_text = {"texts": texts_list, "probabilities": None, "status_code": 400}
return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
load_result_path = "./new_data_logs/{}.json".format(query_id)
print("query_id: ", query_id)

3
flask_predict_batch_mistral.py

@ -279,13 +279,12 @@ def main(texts: dict):
for i, output in enumerate(outputs):
index = output.request_id
generated_text = output.outputs[0].text
generated_text = pre_sentence_ulit(generated_text)
generated_text_list[int(index)] = generated_text
for i in range(len(text_list)):
if len(text_list[i][0]) > 7:
continue
generated_text_list[i] = pre_sentence_ulit(generated_text_list[i])
else:
generated_text_list[i] = text_list[i][0]

30
flask_predict_mistral_vllm.py

@ -1,3 +1,5 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import flask
from transformers import pipeline
import redis
@ -5,6 +7,9 @@ import uuid
import json
from threading import Thread
import time
import requests
from flask import request
from vllm import LLM, SamplingParams
app = flask.Flask(__name__)
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=5, password="zhicheng123*")
@ -29,34 +34,33 @@ def mistral_vllm_models(texts):
generated_text = output.outputs[0].text
generated_text_list[int(index)] = generated_text
return generated_text_list
def classify(batch_size): # 调用模型,设置最大batch_size
def classify(): # 调用模型,设置最大batch_size
while True:
texts = []
query_ids = []
if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取
continue
for i in range(min(redis_.llen(db_key_query), batch_size)):
else:
query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text
query_ids.append(json.loads(query)['id'])
texts.append(json.loads(query)['text']) # 拼接若干text 为batch
query_ids = json.loads(query)['id']
texts = json.loads(query)['texts'] # 拼接若干text 为batch
result = mistral_vllm_models(texts) # 调用模型
for (id_, res) in zip(query_ids, result):
res['score'] = str(res['score'])
redis_.set(id_, json.dumps(res)) # 将模型结果送回队列
print(result)
redis_.set(query_ids, json.dumps(result)) # 将模型结果送回队列
@app.route("/predict", methods=["POST"])
def handle_query():
text = flask.request.form['text'] # 获取用户query中的文本 例如"I love you"
texts = request.json["texts"] # 获取用户query中的文本 例如"I love you"
id_ = str(uuid.uuid1()) # 为query生成唯一标识
d = {'id': id_, 'text': text} # 绑定文本和query id
d = {'id': id_, 'texts': texts} # 绑定文本和query id
redis_.rpush(db_key_query, json.dumps(d)) # 加入redis
while True:
result = redis_.get(id_) # 获取该query的模型结果
if result is not None:
redis_.delete(id_)
result_text = {'code': "200", 'data': result.decode('UTF-8')}
result_text = {'code': "200", 'resilt': json.loads(result.decode('UTF-8'))}
break
return flask.jsonify(result_text) # 返回结果
@ -64,4 +68,4 @@ def handle_query():
if __name__ == "__main__":
t = Thread(target=classify)
t.start()
app.run(debug=False, host='127.0.0.1', port=9000)
app.run(debug=False, host='0.0.0.0', port=14010)

5
redis_check_uuid_mistral.py

@ -44,9 +44,10 @@ def handle_query():
with open(result_path, encoding='utf8') as f1:
# 加载文件的对象
result_dict = json.load(f1)
code = result_dict["status_code"]
texts = result_dict["texts"]
probabilities = result_dict["probabilities"]
result_text = {'code': 200, 'text': texts, 'probabilities': probabilities}
result_text = {'code': code, 'text': texts, 'probabilities': probabilities}
else:
querying_list = list(redis_.smembers("querying"))
querying_set = set()
@ -84,4 +85,4 @@ def handle_query():
if __name__ == "__main__":
app.run(debug=False, host='0.0.0.0', port=14003)
app.run(debug=False, host='0.0.0.0', port=14005)

Loading…
Cancel
Save