From f90fd82462640471fee072940a135db50fde1aba Mon Sep 17 00:00:00 2001 From: "majiahui@haimaqingfan.com" Date: Tue, 31 Oct 2023 15:16:43 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AC=AC=E4=B8=80=E4=B8=AA=E5=A2=9E=E5=BC=BA?= =?UTF-8?q?=E7=89=88=E4=B8=8A=E7=BA=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- flask_drop_rewrite_request.py | 14 +++++++++++--- flask_predict_batch_mistral.py | 3 +-- flask_predict_mistral_vllm.py | 30 +++++++++++++++++------------- redis_check_uuid_mistral.py | 5 +++-- 4 files changed, 32 insertions(+), 20 deletions(-) diff --git a/flask_drop_rewrite_request.py b/flask_drop_rewrite_request.py index 938e50d..823a0eb 100644 --- a/flask_drop_rewrite_request.py +++ b/flask_drop_rewrite_request.py @@ -302,7 +302,7 @@ def pre_sentence_ulit(sentence): ''' sentence = str(sentence).strip() if_change = True - if len(sentence) > 7: + if len(sentence) > 9: text = "You are a helpful assistant.\n\nUser:改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{}\nAssistant:".format(sentence) else: text = "You are a helpful assistant.\n\nUser:下面词不做任何变化:\n{}\nAssistant:".format(sentence) @@ -322,6 +322,8 @@ def pre_sentence_ulit(sentence): def main(texts: dict): + if texts == {"1": "0"}: + 9/0 text_list = paragraph_test(texts) text_info = [] @@ -421,11 +423,17 @@ def classify(): # 调用模型,设置最大batch_size if text_type == 'focus': texts_list = main(texts) elif text_type == 'chapter': - texts_list = main(texts) + try: + texts_list = main(texts) + except: + texts_list = [] else: texts_list = [] + if texts_list != []: + return_text = {"texts": texts_list, "probabilities": None, "status_code": 200} + else: + return_text = {"texts": texts_list, "probabilities": None, "status_code": 400} - return_text = {"texts": texts_list, "probabilities": None, "status_code": 200} load_result_path = "./new_data_logs/{}.json".format(query_id) print("query_id: ", query_id) diff --git a/flask_predict_batch_mistral.py b/flask_predict_batch_mistral.py index b60e29a..e23e26c 100644 --- a/flask_predict_batch_mistral.py +++ b/flask_predict_batch_mistral.py @@ -279,13 +279,12 @@ def main(texts: dict): for i, output in enumerate(outputs): index = output.request_id generated_text = output.outputs[0].text - generated_text = pre_sentence_ulit(generated_text) generated_text_list[int(index)] = generated_text for i in range(len(text_list)): if len(text_list[i][0]) > 7: - continue + generated_text_list[i] = pre_sentence_ulit(generated_text_list[i]) else: generated_text_list[i] = text_list[i][0] diff --git a/flask_predict_mistral_vllm.py b/flask_predict_mistral_vllm.py index 2367ae0..61e5760 100644 --- a/flask_predict_mistral_vllm.py +++ b/flask_predict_mistral_vllm.py @@ -1,3 +1,5 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "2" import flask from transformers import pipeline import redis @@ -5,6 +7,9 @@ import uuid import json from threading import Thread import time +import requests +from flask import request +from vllm import LLM, SamplingParams app = flask.Flask(__name__) pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=5, password="zhicheng123*") @@ -29,34 +34,33 @@ def mistral_vllm_models(texts): generated_text = output.outputs[0].text generated_text_list[int(index)] = generated_text + return generated_text_list -def classify(batch_size): # 调用模型,设置最大batch_size + +def classify(): # 调用模型,设置最大batch_size while True: - texts = [] - query_ids = [] if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 continue - for i in range(min(redis_.llen(db_key_query), batch_size)): + else: query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text - query_ids.append(json.loads(query)['id']) - texts.append(json.loads(query)['text']) # 拼接若干text 为batch + query_ids = json.loads(query)['id'] + texts = json.loads(query)['texts'] # 拼接若干text 为batch result = mistral_vllm_models(texts) # 调用模型 - for (id_, res) in zip(query_ids, result): - res['score'] = str(res['score']) - redis_.set(id_, json.dumps(res)) # 将模型结果送回队列 + print(result) + redis_.set(query_ids, json.dumps(result)) # 将模型结果送回队列 @app.route("/predict", methods=["POST"]) def handle_query(): - text = flask.request.form['text'] # 获取用户query中的文本 例如"I love you" + texts = request.json["texts"] # 获取用户query中的文本 例如"I love you" id_ = str(uuid.uuid1()) # 为query生成唯一标识 - d = {'id': id_, 'text': text} # 绑定文本和query id + d = {'id': id_, 'texts': texts} # 绑定文本和query id redis_.rpush(db_key_query, json.dumps(d)) # 加入redis while True: result = redis_.get(id_) # 获取该query的模型结果 if result is not None: redis_.delete(id_) - result_text = {'code': "200", 'data': result.decode('UTF-8')} + result_text = {'code': "200", 'resilt': json.loads(result.decode('UTF-8'))} break return flask.jsonify(result_text) # 返回结果 @@ -64,4 +68,4 @@ def handle_query(): if __name__ == "__main__": t = Thread(target=classify) t.start() - app.run(debug=False, host='127.0.0.1', port=9000) \ No newline at end of file + app.run(debug=False, host='0.0.0.0', port=14010) \ No newline at end of file diff --git a/redis_check_uuid_mistral.py b/redis_check_uuid_mistral.py index 863ad97..0574e86 100644 --- a/redis_check_uuid_mistral.py +++ b/redis_check_uuid_mistral.py @@ -44,9 +44,10 @@ def handle_query(): with open(result_path, encoding='utf8') as f1: # 加载文件的对象 result_dict = json.load(f1) + code = result_dict["status_code"] texts = result_dict["texts"] probabilities = result_dict["probabilities"] - result_text = {'code': 200, 'text': texts, 'probabilities': probabilities} + result_text = {'code': code, 'text': texts, 'probabilities': probabilities} else: querying_list = list(redis_.smembers("querying")) querying_set = set() @@ -84,4 +85,4 @@ def handle_query(): if __name__ == "__main__": - app.run(debug=False, host='0.0.0.0', port=14003) + app.run(debug=False, host='0.0.0.0', port=14005)