import os os.environ["CUDA_VISIBLE_DEVICES"] = "1" import argparse from typing import List, Tuple from threading import Thread from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams # from vllm.utils import FlexibleArgumentParser from flask import Flask, jsonify from flask import request import redis import time import json # http接口服务 # app = FastAPI() app = Flask(__name__) app.config["JSON_AS_ASCII"] = False pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=50,db=3, password="zhicheng123*") redis_ = redis.Redis(connection_pool=pool, decode_responses=True) db_key_query = 'query' db_key_querying = 'querying' db_key_result = 'result' batch_size = 15 class log: def __init__(self): pass def log(*args, **kwargs): format = '%Y/%m/%d-%H:%M:%S' format_h = '%Y-%m-%d' value = time.localtime(int(time.time())) dt = time.strftime(format, value) dt_log_file = time.strftime(format_h, value) log_file = 'log_file/access-%s' % dt_log_file + ".log" if not os.path.exists(log_file): with open(os.path.join(log_file), 'w', encoding='utf-8') as f: print(dt, *args, file=f, **kwargs) else: with open(os.path.join(log_file), 'a+', encoding='utf-8') as f: print(dt, *args, file=f, **kwargs) def initialize_engine() -> LLMEngine: """Initialize the LLMEngine from the command line arguments.""" # model_dir = "/home/majiahui/project/models-llm/Qwen-0_5B-Chat" # model_dir = "/home/majiahui/project/models-llm/openbuddy-qwen2.5llamaify-7b_train_11_prompt_mistral_gpt_xiaobiaot_real_paper" model_dir = "/home/majiahui/project/models-llm/openbuddy-llama3.1-8b_train_11_prompt_mistral_gpt_xiaobiaot_real_paper_1" args = EngineArgs(model_dir) args.max_num_seqs = 16 # batch最大20条样本 args.gpu_memory_utilization = 0.8 args.max_model_len=8192 # 加载模型 return LLMEngine.from_engine_args(args) engine = initialize_engine() def create_test_prompts(prompt_texts, query_ids, sampling_params_list) -> List[Tuple[str,str, SamplingParams]]: """Create a list of test prompts with their sampling parameters.""" return_list = [] for i,j,k in zip(prompt_texts, query_ids, sampling_params_list): return_list.append((i, j, k)) return return_list def process_requests(engine: LLMEngine, test_prompts: List[Tuple[str, str, SamplingParams]]): """Continuously process a list of prompts and handle the outputs.""" return_list = [] while test_prompts or engine.has_unfinished_requests(): if test_prompts: prompt, query_id, sampling_params = test_prompts.pop(0) engine.add_request(str(query_id), prompt, sampling_params) request_outputs: List[RequestOutput] = engine.step() for request_output in request_outputs: if request_output.finished: return_list.append(request_output) return return_list def main(prompt_texts, query_ids, sampling_params_list): """Main function that sets up and runs the prompt processing.""" test_prompts = create_test_prompts(prompt_texts, query_ids, sampling_params_list) return process_requests(engine, test_prompts) # chat对话接口 # @app.route("/predict/", methods=["POST"]) # def chat(): # # request = request.json() # # query = request.get('query', None) # # history = request.get('history', []) # # system = request.get('system', 'You are a helpful assistant.') # # stream = request.get("stream", False) # # user_stop_words = request.get("user_stop_words", # # []) # list[str],用户自定义停止句,例如:['Observation: ', 'Action: ']定义了2个停止句,遇到任何一个都会停止 # # query = request.json['query'] # # # # 构造prompt # # prompt_text, prompt_tokens = _build_prompt(generation_config, tokenizer, query, history=history, system=system) # # prompt_text = f"<|im_start|>user\n{query}\n<|im_end|>\n<|im_start|>assistant\n" # # # return_output = main(prompt_text, sampling_params) # return_info = { # "request_id": return_output.request_id, # "text": return_output.outputs[0].text # } # # return jsonify(return_info) def classify(batch_size): # 调用模型,设置最大batch_size while True: texts = [] query_ids = [] sampling_params_list = [] if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 time.sleep(2) continue # for i in range(min(redis_.llen(db_key_query), batch_size)): while True: query = redis_.lpop(db_key_query) # 获取query的text if query == None: break query = query.decode('UTF-8') data_dict_path = json.loads(query) path = data_dict_path['path'] with open(path, encoding='utf8') as f1: # 加载文件的对象 data_dict = json.load(f1) # query_ids.append(json.loads(query)['id']) # texts.append(json.loads(query)['text']) # 拼接若干text 为batch query_id = data_dict['id'] print("query_id", query_id) text = data_dict["text"] model = data_dict["model"] top_p = data_dict["top_p"] temperature = data_dict["temperature"] presence_penalty = 0.8 max_tokens = 8192 query_ids.append(query_id) texts.append(text) # sampling_params = SamplingParams(temperature=0.3, top_p=0.5, stop="<|end|>", presence_penalty=1.1, max_tokens=8192) sampling_params_list.append(SamplingParams( temperature=temperature, top_p=top_p, stop="<|end|>", presence_penalty=presence_penalty, max_tokens=max_tokens )) if len(texts) == batch_size: break print("texts", len(texts)) print("query_ids", len(query_ids)) print("sampling_params_list", len(sampling_params_list)) outputs = main(texts, query_ids, sampling_params_list) print("预测完成") generated_text_dict = {} print("outputs", len(outputs)) for i, output in enumerate(outputs): index = output.request_id print(index) generated_text = output.outputs[0].text generated_text_dict[index] = generated_text print(generated_text_dict) for id_, output in generated_text_dict.items(): return_text = {"texts": output, "probabilities": None, "status_code": 200} load_result_path = "./new_data_logs/{}.json".format(id_) with open(load_result_path, 'w', encoding='utf8') as f2: # ensure_ascii=False才能输入中文,否则是Unicode字符 # indent=2 JSON数据的缩进,美观 json.dump(return_text, f2, ensure_ascii=False, indent=4) redis_.set(id_, load_result_path, 86400) # redis_.set(id_, load_result_path, 30) redis_.srem(db_key_querying, id_) log.log('start at', 'query_id:{},load_result_path:{},return_text:{}'.format( id_, load_result_path, return_text)) if __name__ == '__main__': t = Thread(target=classify, args=(batch_size,)) t.start()