Browse Source

第一个增强版上线

master
majiahui@haimaqingfan.com 2 years ago
parent
commit
f90fd82462
  1. 14
      flask_drop_rewrite_request.py
  2. 3
      flask_predict_batch_mistral.py
  3. 30
      flask_predict_mistral_vllm.py
  4. 5
      redis_check_uuid_mistral.py

14
flask_drop_rewrite_request.py

@ -302,7 +302,7 @@ def pre_sentence_ulit(sentence):
''' '''
sentence = str(sentence).strip() sentence = str(sentence).strip()
if_change = True if_change = True
if len(sentence) > 7: if len(sentence) > 9:
text = "You are a helpful assistant.\n\nUser:改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{}\nAssistant:".format(sentence) text = "You are a helpful assistant.\n\nUser:改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{}\nAssistant:".format(sentence)
else: else:
text = "You are a helpful assistant.\n\nUser:下面词不做任何变化:\n{}\nAssistant:".format(sentence) text = "You are a helpful assistant.\n\nUser:下面词不做任何变化:\n{}\nAssistant:".format(sentence)
@ -322,6 +322,8 @@ def pre_sentence_ulit(sentence):
def main(texts: dict): def main(texts: dict):
if texts == {"1": "0"}:
9/0
text_list = paragraph_test(texts) text_list = paragraph_test(texts)
text_info = [] text_info = []
@ -421,11 +423,17 @@ def classify(): # 调用模型,设置最大batch_size
if text_type == 'focus': if text_type == 'focus':
texts_list = main(texts) texts_list = main(texts)
elif text_type == 'chapter': elif text_type == 'chapter':
texts_list = main(texts) try:
texts_list = main(texts)
except:
texts_list = []
else: else:
texts_list = [] texts_list = []
if texts_list != []:
return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
else:
return_text = {"texts": texts_list, "probabilities": None, "status_code": 400}
return_text = {"texts": texts_list, "probabilities": None, "status_code": 200}
load_result_path = "./new_data_logs/{}.json".format(query_id) load_result_path = "./new_data_logs/{}.json".format(query_id)
print("query_id: ", query_id) print("query_id: ", query_id)

3
flask_predict_batch_mistral.py

@ -279,13 +279,12 @@ def main(texts: dict):
for i, output in enumerate(outputs): for i, output in enumerate(outputs):
index = output.request_id index = output.request_id
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
generated_text = pre_sentence_ulit(generated_text)
generated_text_list[int(index)] = generated_text generated_text_list[int(index)] = generated_text
for i in range(len(text_list)): for i in range(len(text_list)):
if len(text_list[i][0]) > 7: if len(text_list[i][0]) > 7:
continue generated_text_list[i] = pre_sentence_ulit(generated_text_list[i])
else: else:
generated_text_list[i] = text_list[i][0] generated_text_list[i] = text_list[i][0]

30
flask_predict_mistral_vllm.py

@ -1,3 +1,5 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import flask import flask
from transformers import pipeline from transformers import pipeline
import redis import redis
@ -5,6 +7,9 @@ import uuid
import json import json
from threading import Thread from threading import Thread
import time import time
import requests
from flask import request
from vllm import LLM, SamplingParams
app = flask.Flask(__name__) app = flask.Flask(__name__)
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=5, password="zhicheng123*") pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=5, password="zhicheng123*")
@ -29,34 +34,33 @@ def mistral_vllm_models(texts):
generated_text = output.outputs[0].text generated_text = output.outputs[0].text
generated_text_list[int(index)] = generated_text generated_text_list[int(index)] = generated_text
return generated_text_list
def classify(batch_size): # 调用模型,设置最大batch_size
def classify(): # 调用模型,设置最大batch_size
while True: while True:
texts = []
query_ids = []
if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取
continue continue
for i in range(min(redis_.llen(db_key_query), batch_size)): else:
query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text
query_ids.append(json.loads(query)['id']) query_ids = json.loads(query)['id']
texts.append(json.loads(query)['text']) # 拼接若干text 为batch texts = json.loads(query)['texts'] # 拼接若干text 为batch
result = mistral_vllm_models(texts) # 调用模型 result = mistral_vllm_models(texts) # 调用模型
for (id_, res) in zip(query_ids, result): print(result)
res['score'] = str(res['score']) redis_.set(query_ids, json.dumps(result)) # 将模型结果送回队列
redis_.set(id_, json.dumps(res)) # 将模型结果送回队列
@app.route("/predict", methods=["POST"]) @app.route("/predict", methods=["POST"])
def handle_query(): def handle_query():
text = flask.request.form['text'] # 获取用户query中的文本 例如"I love you" texts = request.json["texts"] # 获取用户query中的文本 例如"I love you"
id_ = str(uuid.uuid1()) # 为query生成唯一标识 id_ = str(uuid.uuid1()) # 为query生成唯一标识
d = {'id': id_, 'text': text} # 绑定文本和query id d = {'id': id_, 'texts': texts} # 绑定文本和query id
redis_.rpush(db_key_query, json.dumps(d)) # 加入redis redis_.rpush(db_key_query, json.dumps(d)) # 加入redis
while True: while True:
result = redis_.get(id_) # 获取该query的模型结果 result = redis_.get(id_) # 获取该query的模型结果
if result is not None: if result is not None:
redis_.delete(id_) redis_.delete(id_)
result_text = {'code': "200", 'data': result.decode('UTF-8')} result_text = {'code': "200", 'resilt': json.loads(result.decode('UTF-8'))}
break break
return flask.jsonify(result_text) # 返回结果 return flask.jsonify(result_text) # 返回结果
@ -64,4 +68,4 @@ def handle_query():
if __name__ == "__main__": if __name__ == "__main__":
t = Thread(target=classify) t = Thread(target=classify)
t.start() t.start()
app.run(debug=False, host='127.0.0.1', port=9000) app.run(debug=False, host='0.0.0.0', port=14010)

5
redis_check_uuid_mistral.py

@ -44,9 +44,10 @@ def handle_query():
with open(result_path, encoding='utf8') as f1: with open(result_path, encoding='utf8') as f1:
# 加载文件的对象 # 加载文件的对象
result_dict = json.load(f1) result_dict = json.load(f1)
code = result_dict["status_code"]
texts = result_dict["texts"] texts = result_dict["texts"]
probabilities = result_dict["probabilities"] probabilities = result_dict["probabilities"]
result_text = {'code': 200, 'text': texts, 'probabilities': probabilities} result_text = {'code': code, 'text': texts, 'probabilities': probabilities}
else: else:
querying_list = list(redis_.smembers("querying")) querying_list = list(redis_.smembers("querying"))
querying_set = set() querying_set = set()
@ -84,4 +85,4 @@ def handle_query():
if __name__ == "__main__": if __name__ == "__main__":
app.run(debug=False, host='0.0.0.0', port=14003) app.run(debug=False, host='0.0.0.0', port=14005)

Loading…
Cancel
Save