rag知识库问答
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

169 lines
5.9 KiB

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging
from threading import Thread
import requests
import redis
import uuid
import time
import json
from flask import Flask, jsonify
from flask import request
import deepspeed
# Flask application that exposes the /predict and /search endpoints below.
app = Flask(__name__)
# Intended to let JSON responses carry non-ASCII (e.g. Chinese) text unescaped.
# NOTE(review): newer Flask releases may ignore this config key -- confirm
# against the installed Flask version.
app.config["JSON_AS_ASCII"] = False
# Shared Redis connection pool (local instance, db 17, at most 50 connections).
# NOTE(review): the password is hard-coded in source; consider loading it from
# the environment or a secrets store.
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=50,db=17, password="zhicheng123*")
# NOTE(review): decode_responses appears to have no effect when an explicit
# connection_pool is supplied; the rest of this file calls .decode() on values
# read from Redis, which is consistent with bytes being returned -- confirm.
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
# model config
model_name = "/home/majiahui/project/models-llm/QwQ-32"
# DeepSpeed inference settings: fp16 weights with 4-way tensor parallelism
# (CUDA_VISIBLE_DEVICES at the top of the file exposes four GPUs).
ds_config = {
"dtype": "fp16",
"tensor_parallel": {
"tp_size": 4
}
}
# Load the model and tokenizer from the local path, then wrap the model with
# the DeepSpeed inference engine. This runs at import time.
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = deepspeed.init_inference(model, config=ds_config)
# Redis key of the list used as the request queue (rpush in predict, lpop in classify).
db_key_query = 'query'
# Redis key of the set of request ids registered by predict.
db_key_querying = 'querying'
# NOTE(review): not referenced anywhere in the visible part of this file.
db_key_result = 'result'
# NOTE(review): not referenced anywhere in the visible part of this file;
# classify() handles one request at a time.
batch_size = 15
def main(prompt):
    """Generate a model response for a single prompt.

    The prompt is tokenized into PyTorch tensors, moved to CUDA, passed to
    ``model.generate`` with the model's default generation settings, and the
    first returned sequence is decoded back to text.

    Args:
        prompt: The text to feed to the model.

    Returns:
        The decoded text of the first generated sequence, exactly as
        ``tokenizer.decode`` returns it when called with no extra options.
    """
    # Example prompt (translated): "Who are the director and the actors of
    # the TV series 'A Lifelong Journey'?"
    encoded = tokenizer(prompt, return_tensors="pt")
    encoded = encoded.to("cuda")
    generated = model.generate(**encoded)
    first_sequence = generated[0]
    return tokenizer.decode(first_sequence)
def _process_query(raw_query):
    """Run one queued request through the model and publish its result.

    Args:
        raw_query: The raw bytes popped from the Redis queue; a JSON object
            with a ``path`` key pointing at the request file written by
            ``predict``.

    Raises:
        Exception: Anything raised while decoding the queue entry, reading
            the request file, running the model, writing the result file or
            talking to Redis is propagated to the caller.
    """
    request_ref = json.loads(raw_query.decode('UTF-8'))
    with open(request_ref['path'], encoding='utf8') as f1:
        data_dict = json.load(f1)
    query_id = data_dict['id']
    input_text = data_dict["text"]
    output_text = main(input_text)
    return_text = {
        "input": input_text,
        "output": output_text,
    }

    load_result_path = "./new_data_logs/{}.json".format(query_id)
    # The result directory is not guaranteed to exist on a fresh deployment.
    os.makedirs(os.path.dirname(load_result_path), exist_ok=True)
    print("query_id: ", query_id)
    print("load_result_path: ", load_result_path)
    with open(load_result_path, 'w', encoding='utf8') as f2:
        # ensure_ascii=False keeps non-ASCII (e.g. Chinese) text readable
        # instead of writing \u escapes; indent makes the file human-friendly.
        json.dump(return_text, f2, ensure_ascii=False, indent=4)

    # Publish the result location first (expires after 24 hours) ...
    redis_.set(query_id, load_result_path, 86400)
    # ... and only then drop the id from the in-flight set, so a concurrent
    # /search never sees the request as neither finished nor in flight.
    # Without this removal the set grows forever and an expired result would
    # be reported as "still processing".
    redis_.srem(db_key_querying, query_id)


def classify():
    """Consume queued requests forever, one at a time.

    Pops entries from the Redis list ``db_key_query``; when the queue is
    empty it sleeps for two seconds before polling again. Each entry is
    handled by ``_process_query``. A failure while handling one entry is
    logged and the loop carries on, so a single bad request (or a transient
    Redis error) cannot kill the worker thread and stall every later request.

    This function never returns.
    """
    logger = logging.getLogger(__name__)
    while True:
        raw_query = None
        try:
            # LPOP returns None on an empty list, so no separate LLEN is needed.
            raw_query = redis_.lpop(db_key_query)
            if raw_query is None:
                time.sleep(2)
                continue
            _process_query(raw_query)
        except Exception:
            # Top-level boundary of the worker thread: record and keep going.
            logger.exception("failed to process queued query: %r", raw_query)
            # Back off briefly so a persistent failure does not spin the loop.
            time.sleep(2)
@app.route("/predict", methods=["POST"])
def predict():
    """Accept a generation request and enqueue it for the background worker.

    Expects a JSON body with a ``content`` field. The request is stored as a
    JSON file under ``./request_data_logs``, registered in the Redis set
    ``db_key_querying`` and appended to the Redis list ``db_key_query``.

    Returns:
        A JSON response whose ``texts.id`` is the generated request id, to be
        used with ``/search``. When the body is not a JSON object or has no
        ``content`` field, a JSON response with ``status_code`` 400 is
        returned with HTTP status 400 instead of an unhandled server error.
    """
    print(request.remote_addr)
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict) or "content" not in payload:
        bad_request = {"texts": None, "probabilities": None, "status_code": 400}
        return jsonify(bad_request), 400
    content = payload["content"]

    id_ = str(uuid.uuid1())  # unique identifier for this request
    print("uuid: ", id_)
    d = {'id': id_, 'text': content}  # bind the text to the request id

    load_request_path = './request_data_logs/{}.json'.format(id_)
    # The request directory is not guaranteed to exist on a fresh deployment.
    os.makedirs(os.path.dirname(load_request_path), exist_ok=True)
    with open(load_request_path, 'w', encoding='utf8') as f2:
        # ensure_ascii=False keeps non-ASCII (e.g. Chinese) text readable
        # instead of writing \u escapes; indent makes the file human-friendly.
        json.dump(d, f2, ensure_ascii=False, indent=4)

    # Register the id before enqueueing it, so the worker can never pop the
    # request before it has been marked as known.
    redis_.sadd(db_key_querying, id_)
    redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path}))

    return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200}
    print("ok")
    return jsonify(return_text)
def _is_request_id(value):
    """Return True when *value* is a canonical UUID string.

    Request ids are produced by ``str(uuid.uuid1())``, so anything that does
    not round-trip through ``uuid.UUID`` unchanged cannot be a real id.
    """
    if not isinstance(value, str):
        return False
    try:
        return str(uuid.UUID(value)) == value
    except ValueError:
        return False


@app.route("/search", methods=["POST"])
def search():
    """Report the status or result of a previously submitted request.

    Expects a JSON body with an ``id`` field holding the id returned by
    ``/predict``.

    Returns:
        A JSON response with a ``code`` field:

        * ``200`` (int): finished; ``text`` holds the model output.
        * ``"201"``: the id is registered and still waiting in the queue.
        * ``"202"``: the id is registered but no longer in the queue.
        * ``"203"``: the id is unknown or malformed.

        The mixed int/str code types are kept as-is because existing clients
        may depend on them.
    """
    payload = request.get_json(silent=True)
    id_ = payload.get('id') if isinstance(payload, dict) else None

    # The id is client-controlled and is used below both as a Redis key and
    # inside a filesystem path. Reject anything that is not a canonical UUID
    # so it cannot address other Redis keys (e.g. the queue list itself) or
    # escape the log directory via path separators.
    if not _is_request_id(id_):
        return jsonify({'code': "203", 'text': "", 'probabilities': None})

    result = redis_.get(id_)  # path of the result file, if the request is done
    if result is not None:
        result_path = result.decode('UTF-8')
        with open(result_path, encoding='utf8') as f1:
            result_dict = json.load(f1)
        texts = result_dict["output"]
        return jsonify({'code': 200, 'text': texts, 'probabilities': None})

    # Single membership test instead of fetching and decoding the whole set.
    if redis_.sismember(db_key_querying, id_):
        queued_ids = {
            json.loads(item)['id']
            for item in redis_.lrange(db_key_query, 0, -1)
        }
        code = "201" if id_ in queued_ids else "202"
        return jsonify({'code': code, 'text': "", 'probabilities': None})

    # Unknown id: answer 203 and keep a record of the miss on disk.
    result_text = {'code': "203", 'text': "", 'probabilities': None}
    load_request_path = './request_data_logs_203/{}.json'.format(id_)
    os.makedirs(os.path.dirname(load_request_path), exist_ok=True)
    with open(load_request_path, 'w', encoding='utf8') as f2:
        # ensure_ascii=False keeps non-ASCII text readable; indent is cosmetic.
        json.dump(result_text, f2, ensure_ascii=False, indent=4)
    return jsonify(result_text)
# Start the background queue consumer as soon as the module is loaded.
t = Thread(target=classify)
t.start()

if __name__ == "__main__":
    # Log line layout: timestamp, source path and line, level, message.
    log_format = '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s'
    logging.basicConfig(
        level=logging.DEBUG,     # minimum level that gets recorded
        filename='rewrite.log',
        filemode='a',            # 'a' appends; 'w' would overwrite the log on every start
        format=log_format,
    )
    app.run(host="0.0.0.0", port=26000, threaded=True, debug=False)