You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
180 lines
6.3 KiB
180 lines
6.3 KiB
![]()
2 months ago
|
import os
|
||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"
|
||
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
|
import logging
|
||
|
from threading import Thread
|
||
|
import requests
|
||
|
import redis
|
||
|
import uuid
|
||
|
import time
|
||
|
import json
|
||
|
from flask import Flask, jsonify
|
||
|
from flask import request
|
||
|
|
||
|
# Flask application object shared by the route handlers in this file.
app = Flask(__name__)
# NOTE(review): JSON_AS_ASCII is only honoured by older Flask releases;
# confirm the installed version still reads it.
app.config["JSON_AS_ASCII"] = False

# Shared Redis connection pool.  The password can be overridden with the
# REDIS_PASSWORD environment variable; the previous hard-coded value is kept
# as the default so existing deployments keep working.
# NOTE(review): move the default secret out of source control.
pool = redis.ConnectionPool(
    host='localhost',
    port=63179,
    max_connections=50,
    db=17,
    password=os.environ.get("REDIS_PASSWORD", "zhicheng123*"),
)
# decode_responses is a connection-level option and has no effect when an
# existing connection_pool is supplied, so it is not passed here: replies are
# bytes, which is what the .decode() calls elsewhere in this file expect.
redis_ = redis.Redis(connection_pool=pool)
|
||
|
|
||
|
# Model configuration: the causal LM and its tokenizer are loaded once, at
# import time, from the same local checkpoint directory.
model_name = "/home/majiahui/project/models-llm/QwQ-32"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Redis key names.  In this file 'query' is used as a list (rpush / lpop /
# lrange) and 'querying' as a set (sadd); 'result' and batch_size are not
# referenced by the code visible here.
db_key_query, db_key_querying, db_key_result = 'query', 'querying', 'result'
batch_size = 15
|
||
|
|
||
|
|
||
|
def main(prompt):
    """Generate the model's answer for a single user prompt.

    The prompt is wrapped as a one-message chat, rendered with the
    tokenizer's chat template, and passed to ``model.generate``.

    Args:
        prompt: Text of the user message.

    Returns:
        The decoded completion with everything up to and including the first
        ``</think>`` marker removed.  When the completion contains no
        ``</think>`` marker the whole completion is returned instead of
        raising ``IndexError``.
    """
    messages = [
        {"role": "user", "content": prompt}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=32768
    )
    # Drop the leading prompt-length slice of every output sequence so that
    # only the newly generated tokens are decoded.
    completion_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = str(tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0])

    # partition() never raises: `marker` is empty when "</think>" is absent,
    # and `answer` keeps any later "</think>" text that belongs to the answer.
    _, marker, answer = response.partition("</think>")
    return answer if marker else response
|
||
|
|
||
|
|
||
|
def _process_query(raw_query):
    """Run one queued request through the model and publish its result.

    Args:
        raw_query: Queue entry popped from Redis (bytes or str) holding a JSON
            object whose 'path' points at the stored request file.
    """
    if isinstance(raw_query, bytes):
        raw_query = raw_query.decode('UTF-8')
    request_path = json.loads(raw_query)['path']

    with open(request_path, encoding='utf8') as f1:
        data_dict = json.load(f1)

    query_id = data_dict['id']
    input_text = data_dict["text"]

    output_text = main(input_text)

    return_text = {
        "input": input_text,
        "output": output_text
    }

    load_result_path = "./new_data_logs/{}.json".format(query_id)
    print("query_id: ", query_id)
    print("load_result_path: ", load_result_path)

    # Create the output directory on first use instead of failing the job.
    os.makedirs(os.path.dirname(load_result_path), exist_ok=True)
    with open(load_result_path, 'w', encoding='utf8') as f2:
        # ensure_ascii=False keeps non-ASCII text readable in the file.
        json.dump(return_text, f2, ensure_ascii=False, indent=4)
    # Publish the result file path under the query id for 24 hours.
    redis_.set(query_id, load_result_path, ex=86400)


def classify():  # worker loop: feeds queued requests to the model
    """Consume the Redis request queue forever, one request at a time.

    Pops entries from the ``db_key_query`` list; when the list is empty the
    loop sleeps for two seconds before polling again.  A failure while
    handling one entry is logged and the loop continues, so a single bad
    request no longer terminates the worker thread.

    NOTE(review): nothing in the code visible here removes finished ids from
    the ``db_key_querying`` set — confirm whether that is intended.
    """
    while True:
        try:
            # lpop returns None on an empty list, so no separate llen check
            # (and no window between the two calls) is needed.
            raw_query = redis_.lpop(db_key_query)
            if raw_query is None:
                time.sleep(2)
                continue
            _process_query(raw_query)
        except Exception:
            # Top-level worker boundary: record the traceback and keep going.
            logging.exception("classify worker failed to process a queued query")
            time.sleep(2)
|
||
|
|
||
|
|
||
|
@app.route("/predict", methods=["POST"])
def predict():
    """Accept a generation request and enqueue it for the background worker.

    Reads ``content`` from the JSON request body, stores it together with a
    freshly generated id in ``./request_data_logs/<id>.json``, pushes the id
    and file path onto the ``db_key_query`` list, and adds the id to the
    ``db_key_querying`` set.

    Returns:
        JSON response carrying the generated id under ``texts.id``.
    """
    print(request.remote_addr)
    submitted_text = request.json["content"]

    # Unique identifier for this request.
    query_id = str(uuid.uuid1())
    print("uuid: ", query_id)

    # Persist the request text next to its id.
    request_record = {'id': query_id, 'text': submitted_text}
    request_path = './request_data_logs/{}.json'.format(query_id)
    with open(request_path, 'w', encoding='utf8') as request_file:
        # ensure_ascii=False keeps non-ASCII text readable in the file.
        json.dump(request_record, request_file, ensure_ascii=False, indent=4)

    # Queue the job, then mark the id as known to the service.
    queue_entry = json.dumps({"id": query_id, "path": request_path})
    redis_.rpush(db_key_query, queue_entry)
    redis_.sadd(db_key_querying, query_id)

    response_body = {"texts": {'id': query_id, }, "probabilities": None, "status_code": 200}
    print("ok")
    return jsonify(response_body)
|
||
|
|
||
|
|
||
|
def _is_query_id(value):
    """Return True when *value* is a canonical hyphenated UUID string."""
    if not isinstance(value, str):
        return False
    try:
        return str(uuid.UUID(value)) == value
    except ValueError:
        return False


@app.route("/search", methods=["POST"])
def search():
    """Report the status or result of a previously submitted request.

    Reads ``id`` from the JSON request body and answers with a JSON object
    holding ``code``, ``text`` and ``probabilities``:

    * ``200`` (int)  – a result path is stored under the id; ``text`` is the
      ``output`` field of that result file.
    * ``"201"``      – the id is in the ``querying`` set and still in the queue.
    * ``"202"``      – the id is in the ``querying`` set but no longer queued.
    * ``"203"``      – anything else; the response is also written to
      ``./request_data_logs_203/<id>.json``.

    The id comes from the client, so it is validated as a canonical UUID
    string before it is used as a Redis key or inside a file path.  Ids that
    fail validation get ``"203"`` without touching Redis or the filesystem.

    The mixed int/str ``code`` values are kept as-is for client compatibility.
    """
    id_ = request.json['id']

    if not _is_query_id(id_):
        # Never use an unvalidated client string as a key or path component.
        return jsonify({'code': "203", 'text': "", 'probabilities': None})

    result = redis_.get(id_)
    if result is not None:
        # Replies are bytes unless the client was built with decode_responses.
        result_path = result.decode('UTF-8') if isinstance(result, bytes) else result
        with open(result_path, encoding='utf8') as f1:
            result_dict = json.load(f1)
        return jsonify({'code': 200, 'text': result_dict["output"], 'probabilities': None})

    # Membership test on the server instead of copying the whole set.
    querying_bool = bool(redis_.sismember(db_key_querying, id_))
    # Is the request still waiting in the queue?
    query_bool = any(
        json.loads(item)['id'] == id_
        for item in redis_.lrange(db_key_query, 0, -1)
    )

    if querying_bool and query_bool:
        code = "201"
    elif querying_bool:
        code = "202"
    else:
        code = "203"
    result_text = {'code': code, 'text': "", 'probabilities': None}

    if code == "203":
        load_request_path = './request_data_logs_203/{}.json'.format(id_)
        os.makedirs(os.path.dirname(load_request_path), exist_ok=True)
        with open(load_request_path, 'w', encoding='utf8') as f2:
            # ensure_ascii=False keeps non-ASCII text readable in the file.
            json.dump(result_text, f2, ensure_ascii=False, indent=4)

    return jsonify(result_text)
|
||
|
|
||
|
|
||
|
# Start the queue worker in the background.  classify() loops forever, so the
# thread is marked as a daemon: a non-daemon thread would keep the interpreter
# alive and prevent the process from shutting down.
t = Thread(target=classify, daemon=True)
t.start()
|
||
|
|
||
|
if __name__ == "__main__":
    # Logging setup: DEBUG and above are appended ('a') to rewrite.log; each
    # record carries timestamp, source path, line number, level and message.
    log_settings = dict(
        level=logging.DEBUG,
        filename='rewrite.log',
        filemode='a',
        format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
    )
    logging.basicConfig(**log_settings)
    # Serve on all interfaces, port 26000, one thread per request.
    app.run(host="0.0.0.0", port=26000, threaded=True, debug=False)
|