"""Flask front end for a queued LLM inference service.

``/predict`` stores the request text on disk, pushes a pointer to it onto a
Redis list and returns a generated id.  A background worker thread
(``classify``) pops queued requests, runs the model and stores the path of the
result file in Redis under that id.  ``/search`` lets the client poll for the
result by id.
"""
import os

# Restrict the process to these GPUs; set before the ML libraries are imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
from threading import Thread
import requests
import redis
import uuid
import time
import json
from flask import Flask, jsonify
from flask import request
import deepspeed

logger = logging.getLogger(__name__)

app = Flask(__name__)
# Keep non-ASCII characters readable in JSON responses.
app.config["JSON_AS_ASCII"] = False

# NOTE(review): connection credentials are hard-coded; consider moving them to
# configuration outside the source tree.
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=50, db=17, password="zhicheng123*")
# NOTE(review): the code below treats Redis replies as bytes and decodes them
# explicitly; decode_responses here does not appear to take effect when an
# explicit connection pool is supplied -- confirm against the redis-py docs.
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

# model config
model_name = "/home/majiahui/project/models-llm/QwQ-32"
ds_config = {
    "dtype": "fp16",
    "tensor_parallel": {
        "tp_size": 4
    }
}
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = deepspeed.init_inference(model, config=ds_config)

# Redis keys: pending-request list, set of ids that have been submitted, and
# a result key name (db_key_result and batch_size are not used in this file).
db_key_query = 'query'
db_key_querying = 'querying'
db_key_result = 'result'
batch_size = 15


def _write_json(path, payload):
    """Write *payload* to *path* as indented UTF-8 JSON.

    The parent directory is created when missing so a fresh deployment does
    not fail on the first write.  ``ensure_ascii=False`` keeps non-ASCII text
    (e.g. Chinese) readable instead of emitting ``\\u`` escapes.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf8') as fh:
        json.dump(payload, fh, ensure_ascii=False, indent=4)


def _is_valid_query_id(candidate):
    """Return True if *candidate* is a string in canonical UUID form.

    Ids handed out by ``predict`` are ``str(uuid.uuid1())``; anything else
    cannot belong to a real request and must not reach Redis or the
    filesystem.
    """
    if not isinstance(candidate, str):
        return False
    try:
        return str(uuid.UUID(candidate)) == candidate
    except ValueError:
        return False


def main(prompt):
    """Run the model on *prompt* and return the decoded first output sequence.

    Args:
        prompt: Text passed to the tokenizer.

    Returns:
        The string produced by ``tokenizer.decode`` for ``output[0]``.
    """
    # NOTE(review): no generation arguments are passed, so output length and
    # sampling come from the model's own generation defaults -- confirm that
    # is intended.  The decoded text presumably still contains the prompt and
    # special tokens; verify against callers.
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    output = model.generate(**inputs)
    response = tokenizer.decode(output[0])
    return response


def _process_query(raw_query):
    """Handle one queued request: load it, run the model, publish the result.

    Args:
        raw_query: Bytes popped from the Redis queue; a JSON object holding
            the request ``id`` and the ``path`` of the stored request file.

    Raises:
        Exception: Whatever the file, model or Redis calls raise.  Before
            re-raising, the id is removed from the submitted-ids set so that
            pollers stop seeing the request as "in progress".
    """
    queue_item = json.loads(raw_query.decode('UTF-8'))
    queued_id = queue_item.get('id')
    try:
        with open(queue_item['path'], encoding='utf8') as f1:
            data_dict = json.load(f1)
        query_id = data_dict['id']
        input_text = data_dict["text"]
        output_text = main(input_text)
        return_text = {
            "input": input_text,
            "output": output_text
        }
        load_result_path = "./new_data_logs/{}.json".format(query_id)
        print("query_id: ", query_id)
        print("load_result_path: ", load_result_path)
        _write_json(load_result_path, return_text)
        # The result pointer expires after 24 hours.
        redis_.set(query_id, load_result_path, ex=86400)
    except Exception:
        # No result will ever be published for this id; drop it from the
        # submitted set so /search reports 203 instead of 202 forever.
        if queued_id is not None:
            redis_.srem(db_key_querying, queued_id)
        raise


def classify():
    """Worker loop: pop queued requests one at a time and process them.

    Runs forever.  When the queue is empty it sleeps two seconds before
    polling again.  Any failure while handling a request is logged and the
    loop continues, so a single bad request cannot stop the service.
    """
    while True:
        try:
            # A single LPOP both checks for and takes the next item.
            raw_query = redis_.lpop(db_key_query)
            if raw_query is None:
                time.sleep(2)
                continue
            _process_query(raw_query)
        except Exception:
            # Top-level boundary of the worker thread: log and keep running.
            logger.exception("worker failed while handling a queued query")
            # Back off briefly so a persistent failure (e.g. Redis being
            # unreachable) does not turn into a busy loop.
            time.sleep(2)


@app.route("/predict", methods=["POST"])
def predict():
    """Accept a generation request and enqueue it for the worker.

    Expects a JSON body with a ``content`` field.  Returns a JSON object
    whose ``texts.id`` is the id to poll with ``/search``.
    """
    print(request.remote_addr)
    content = request.json["content"]
    id_ = str(uuid.uuid1())  # unique identifier for this request
    print("uuid: ", id_)
    d = {'id': id_, 'text': content}  # bind the text to the request id
    load_request_path = './request_data_logs/{}.json'.format(id_)
    _write_json(load_request_path, d)
    # Register the id before it becomes visible to the worker so a poll (or a
    # fast worker failure) can never observe a queued id that is not tracked.
    redis_.sadd(db_key_querying, id_)
    redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path}))
    return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200}
    print("ok")
    return jsonify(return_text)


@app.route("/search", methods=["POST"])
def search():
    """Report the status or result of a previously submitted request.

    Expects a JSON body with an ``id`` field.  The response ``code`` is:

    * ``200`` -- a result exists; ``text`` holds the stored output.
    * ``"201"`` -- the id is tracked and still waiting in the queue.
    * ``"202"`` -- the id is tracked and no longer in the queue.
    * ``"203"`` -- the id is not tracked (or is malformed); for well-formed
      ids the response is also written to ``./request_data_logs_203/``.
    """
    id_ = request.json['id']
    if not _is_valid_query_id(id_):
        # The id comes from the client and is used as a Redis key and inside
        # a file path below; refuse anything that is not a canonical UUID.
        logger.warning("rejecting malformed query id: %r", id_)
        return jsonify({'code': "203", 'text': "", 'probabilities': None})

    result = redis_.get(id_)  # path of the result file, if the worker is done
    if result is not None:
        result_path = result.decode('UTF-8')
        with open(result_path, encoding='utf8') as f1:
            result_dict = json.load(f1)
        texts = result_dict["output"]
        return jsonify({'code': 200, 'text': texts, 'probabilities': None})

    # Membership is checked server-side instead of downloading the whole set.
    is_tracked = bool(redis_.sismember(db_key_querying, id_))
    is_queued = any(
        json.loads(item)['id'] == id_
        for item in redis_.lrange(db_key_query, 0, -1)
    )

    if is_tracked and is_queued:
        code = "201"
    elif is_tracked:
        code = "202"
    else:
        code = "203"
    result_text = {'code': code, 'text': "", 'probabilities': None}
    if code == "203":
        _write_json('./request_data_logs_203/{}.json'.format(id_), result_text)
    return jsonify(result_text)


# The worker starts at import time, so it also runs when the module is loaded
# by an external WSGI server rather than executed directly.
t = Thread(target=classify)
t.start()

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG,  # minimum level that is recorded
                        filename='rewrite.log',
                        filemode='a',  # 'a' appends to the log; 'w' would overwrite it on each start
                        format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
                        )
    app.run(host="0.0.0.0", port=26000, threaded=True, debug=False)