diff --git a/flask_predict_mistral_vllm.py b/flask_predict_mistral_vllm.py index 61e5760..09a9b6a 100644 --- a/flask_predict_mistral_vllm.py +++ b/flask_predict_mistral_vllm.py @@ -17,6 +17,7 @@ redis_ = redis.Redis(connection_pool=pool, decode_responses=True) db_key_query = 'query' db_key_result = 'result' +batch_size = 64 sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=1.1,stop="", max_tokens=4096) models_path = "/home/majiahui/model-llm/openbuddy-mistral-7b-v13.1" @@ -41,10 +42,14 @@ def classify(): # 调用模型,设置最大batch_size while True: if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 continue - else: + # else: + # query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text + # query_ids = json.loads(query)['id'] + # texts = json.loads(query)['texts'] # 拼接若干text 为batch + for i in range(min(redis_.llen(db_key_query), batch_size)): query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text - query_ids = json.loads(query)['id'] - texts = json.loads(query)['texts'] # 拼接若干text 为batch + query_ids.append(json.loads(query)['id']) + texts.append(json.loads(query)['text']) # 拼接若干text 为batch result = mistral_vllm_models(texts) # 调用模型 print(result) redis_.set(query_ids, json.dumps(result)) # 将模型结果送回队列