rag知识库问答
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

303 lines
12 KiB

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"
import argparse
from typing import List, Tuple
from threading import Thread
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from transformers import AutoTokenizer
# from vllm.utils import FlexibleArgumentParser
from flask import Flask, jsonify
from flask import request
import redis
import time
import json
import uuid
# HTTP API service.
app = Flask(__name__)
# Emit non-ASCII (e.g. Chinese) characters as-is in JSON responses.
app.config["JSON_AS_ASCII"] = False
# NOTE(review): the redis password is hard-coded; the REDIS_PASSWORD environment
# variable overrides the fallback so deployments can stop relying on the literal.
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=50, db=17,
                            password=os.environ.get("REDIS_PASSWORD", "zhicheng123*"))
# redis-py ignores per-connection kwargs such as ``decode_responses`` when an
# explicit ``connection_pool`` is supplied, so this client returns ``bytes``.
# The rest of this file relies on that and calls ``.decode(...)`` on values.
redis_ = redis.Redis(connection_pool=pool)
# Redis key names.
db_key_query = 'query'  # list: pending request pointers {"id", "path"}, consumed FIFO
db_key_querying = 'querying'  # set: ids accepted but not yet answered
db_key_queryset = 'queryset'  # set: every id ever accepted
db_key_result = 'result'  # not referenced by the visible code
db_key_error = 'error'  # list: error reports pushed by /search
# Maximum number of requests the worker sends to the engine in one batch.
batch_size = 2
tokenizer = AutoTokenizer.from_pretrained("/home/majiahui/project/models-llm/QwQ-32")
def smtp_f(name):
    """Send an alert e-mail reporting a model-service failure.

    :param name: used as the e-mail subject; callers pass a service tag such as
        ``"vllm-main"``.

    Best-effort: any SMTP, socket or SSL failure is printed and swallowed. The
    callers invoke this from their own error paths, so it must never raise.
    """
    import smtplib
    from contextlib import closing
    from email.mime.text import MIMEText
    from email.header import Header
    sender = '838878981@qq.com'  # sending mailbox
    receivers = ['838878981@qq.com']  # recipient mailboxes
    # NOTE(review): the SMTP authorization code is hard-coded in source.
    auth_code = "jfqtutaiwrtdbcge"
    message = MIMEText('基础大模型出现错误,紧急', 'plain', 'utf-8')
    message['From'] = Header("Sender<%s>" % sender)
    message['To'] = Header("Receiver<%s>" % receivers[0])
    message['Subject'] = Header(name, 'utf-8')
    try:
        # ``closing`` guarantees the connection is closed even if login/sendmail
        # fails. The timeout stops an unreachable server from blocking the caller.
        with closing(smtplib.SMTP_SSL('smtp.qq.com', 465, timeout=10)) as server:
            server.login(sender, auth_code)
            server.sendmail(sender, receivers, message.as_string())
            print("邮件发送成功")
    except OSError:
        # OSError covers smtplib.SMTPException as well as DNS/socket/SSL errors.
        print("Error: 无法发送邮件")
class log:
    """Tiny file logger, used as ``log.log(...)``, writing one file per local day."""

    def __init__(self):
        pass

    @staticmethod
    def log(*args, **kwargs):
        """Append one timestamped line to ``log_file/access-YYYY-MM-DD.log``.

        ``args``/``kwargs`` are forwarded to :func:`print` (so ``sep``/``end``
        work). The line is prefixed with the local time as
        ``YYYY/MM/DD-HH:MM:SS``. The ``log_file`` directory is created if it
        does not exist yet.
        """
        now = time.localtime(int(time.time()))
        stamp = time.strftime('%Y/%m/%d-%H:%M:%S', now)
        log_dir = 'log_file'
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, 'access-%s.log' % time.strftime('%Y-%m-%d', now))
        # Append mode also creates a missing file, so no separate "new file" branch.
        with open(log_path, 'a', encoding='utf-8') as f:
            print(stamp, *args, file=f, **kwargs)
def initialize_engine(model_dir="/home/majiahui/project/models-llm/QwQ-32",
                      max_num_seqs=2,
                      tensor_parallel_size=4,
                      max_model_len=8192) -> LLMEngine:
    """Build the vLLM engine used by the background worker.

    No command-line parsing is done; the defaults below are the service's
    configuration.

    :param model_dir: local path of the model to load.
    :param max_num_seqs: maximum number of sequences the engine runs at once.
    :param tensor_parallel_size: number of GPUs to shard the model across (this
        module exposes four devices via ``CUDA_VISIBLE_DEVICES``).
    :param max_model_len: maximum model context length in tokens.
    :return: a ready ``LLMEngine``.
    """
    args = EngineArgs(model_dir)
    args.max_num_seqs = max_num_seqs
    args.tensor_parallel_size = tensor_parallel_size
    args.max_model_len = max_model_len
    return LLMEngine.from_engine_args(args)


# Loaded once at import time; used by ``main`` for every batch.
engine = initialize_engine()
def create_test_prompts(prompt_texts, query_ids, sampling_params_list) -> List[Tuple[str,str, SamplingParams]]:
    """Pair each prompt with its request id and sampling parameters.

    The three sequences are combined position by position into
    ``(prompt, query_id, sampling_params)`` tuples. As with :func:`zip`, items
    beyond the length of the shortest sequence are dropped.
    """
    return list(zip(prompt_texts, query_ids, sampling_params_list))
def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, str, SamplingParams]]):
    """Submit queued prompts to ``engine`` and step it until all are finished.

    One prompt is taken from the front of ``test_prompts`` per engine step, and
    the list is emptied in place. Stepping continues while prompts remain or the
    engine still reports unfinished requests. Returns the finished
    ``RequestOutput`` objects in the order they completed.
    """
    finished = []
    while True:
        pending = bool(test_prompts)
        if not pending and not engine.has_unfinished_requests():
            break
        if pending:
            prompt, request_id, params = test_prompts.pop(0)
            engine.add_request(str(request_id), prompt, params)
        finished.extend(out for out in engine.step() if out.finished)
    return finished
def main(prompt_texts, query_ids, sampling_params_list):
    """Run one batch through the module-level ``engine`` and return its finished outputs."""
    return process_requests(
        engine,
        create_test_prompts(prompt_texts, query_ids, sampling_params_list),
    )
def _parse_queued_query(raw_query):
    """Turn one raw ``db_key_query`` entry into ``(query_id, prompt, sampling_params)``.

    The entry is the JSON pointer ``{"id", "path"}`` pushed by ``/predict``. The
    full request (``id``, ``text``, ``top_p``, ``temperature``, ...) is read from
    the JSON file at ``path``. Raises on a malformed entry, a missing file or
    key, or (presumably) values that ``SamplingParams`` rejects.
    """
    pointer = json.loads(raw_query.decode('UTF-8'))
    with open(pointer['path'], encoding='utf8') as f1:
        data_dict = json.load(f1)
    query_id = data_dict['id']
    print("query_id", query_id)
    # data_dict["model"] is not used: this worker serves only the engine that
    # was loaded at import time.
    messages = [{"role": "user", "content": data_dict["text"]}]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    sampling_params = SamplingParams(
        temperature=data_dict["temperature"],
        top_p=data_dict["top_p"],
        stop="<|im_end|>",
        presence_penalty=1.1,
        max_tokens=8192,
    )
    return query_id, prompt, sampling_params


def _publish_result(id_, return_text):
    """Store ``return_text`` for request ``id_`` so ``/search`` can return it.

    Writes ``./new_data_logs/<id>.json``, points the redis key ``id_`` at that
    file for 86400 s, and removes ``id_`` from the in-progress set.
    """
    load_result_path = "./new_data_logs/{}.json".format(id_)
    os.makedirs(os.path.dirname(load_result_path), exist_ok=True)
    with open(load_result_path, 'w', encoding='utf8') as f2:
        # ensure_ascii=False keeps Chinese text readable instead of \u escapes.
        json.dump(return_text, f2, ensure_ascii=False, indent=4)
    redis_.set(id_, load_result_path, 86400)  # third positional arg is ``ex`` (seconds)
    redis_.srem(db_key_querying, id_)
    log.log('start at',
            'query_id:{},load_result_path:{},return_text:{}'.format(
                id_, load_result_path, return_text))


def _fail_request(id_):
    """Publish a status-400 result so clients polling ``id_`` stop waiting."""
    _publish_result(id_, {"texts": "", "probabilities": None, "status_code": 400})


def _reject_query(raw_query, exc):
    """Log a queue entry that could not become an engine request and fail it."""
    print("bad query", raw_query, repr(exc))
    log.log('bad query', 'raw_query:{},error:{!r}'.format(raw_query, exc))
    try:
        id_ = json.loads(raw_query)['id']
    except (ValueError, KeyError, TypeError):
        return  # no usable id, so there is nobody to report back to
    _fail_request(id_)


def _classify_step(batch_size):
    """Pop up to ``batch_size`` queued requests, run them and publish the results."""
    if redis_.llen(db_key_query) == 0:  # nothing queued: back off briefly
        time.sleep(2)
        return
    texts = []
    query_ids = []
    sampling_params_list = []
    while True:
        raw_query = redis_.lpop(db_key_query)
        if raw_query is None:
            break
        try:
            query_id, prompt, sampling_params = _parse_queued_query(raw_query)
        except Exception as exc:  # one bad request must not stop the worker
            _reject_query(raw_query, exc)
            continue
        query_ids.append(query_id)
        texts.append(prompt)
        sampling_params_list.append(sampling_params)
        if len(texts) == batch_size:
            break
    if not texts:
        return
    print("texts", len(texts))
    print("query_ids", len(query_ids))
    print("sampling_params_list", len(sampling_params_list))
    try:
        outputs = main(texts, query_ids, sampling_params_list)
    except Exception as exc:
        # Engine failure: answer every request in this batch so clients stop
        # polling, and alert the operator.
        print("engine error", repr(exc))
        log.log('engine error', 'query_ids:{},error:{!r}'.format(query_ids, exc))
        smtp_f("vllm-main-classify")
        for query_id in query_ids:
            _fail_request(str(query_id))
        return
    print("预测完成")
    print("outputs", len(outputs))
    for output in outputs:
        print(output.request_id)
        return_text = {"texts": output.outputs[0].text, "probabilities": None, "status_code": 200}
        _publish_result(output.request_id, return_text)


def classify(batch_size):
    """Background worker: forever pull queued requests from redis and answer them.

    :param batch_size: maximum number of requests sent to the engine at once.

    Never returns. An error in one iteration (e.g. redis or disk trouble) is
    logged and retried after a pause instead of killing the thread. This module
    starts only this one consumer of the request queue.
    """
    while True:
        try:
            _classify_step(batch_size)
        except Exception as exc:
            print("classify error", repr(exc))
            log.log('classify error', repr(exc))
            time.sleep(2)
@app.route("/predict", methods=["POST"])
def predict():
    """Queue a generation request and return its id immediately.

    Expects JSON with ``content`` (the user message), ``model``, ``top_p`` and
    ``temperature``. The request is saved to ``./request_data_logs/<id>.json``
    and a pointer to it is pushed onto the redis queue for the ``classify``
    worker. Clients poll ``/search`` with the returned id.

    Returns ``{"texts": {"id": <id>}, "probabilities": None, "status_code": 200}``,
    or the same with ``status_code`` 400 (and an alert e-mail) if queueing failed.
    """
    content = request.json["content"]
    model = request.json["model"]
    top_p = request.json["top_p"]
    temperature = request.json["temperature"]
    id_ = str(uuid.uuid1())  # unique id for this query
    print("uuid: ", id_)
    d = {'id': id_, 'text': content, 'model': model, 'top_p': top_p, 'temperature': temperature}
    print(d)
    try:
        load_request_path = './request_data_logs/{}.json'.format(id_)
        os.makedirs(os.path.dirname(load_request_path), exist_ok=True)
        with open(load_request_path, 'w', encoding='utf8') as f2:
            # ensure_ascii=False keeps Chinese text readable instead of \u escapes.
            json.dump(d, f2, ensure_ascii=False, indent=4)
        redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path}))
        redis_.sadd(db_key_querying, id_)
        redis_.sadd(db_key_queryset, id_)
        return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200}
    except Exception as exc:
        print("predict error", repr(exc))
        log.log('predict error', 'query_id:{},error:{!r}'.format(id_, exc))
        return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 400}
        smtp_f("vllm-main-paper")
    return jsonify(return_text)
@app.route("/search", methods=["POST"])
def search():
    """Report the status or result of a request queued by ``/predict``.

    Expects JSON ``{"id": <id>}`` and returns ``{"code", "text", "probabilities"}``:

    * once a result was published: its stored ``status_code``/``texts``. A
      stored 400 is also pushed onto the ``error`` list.
    * ``"201"``: the id is still waiting in the queue.
    * ``"202"``: the id has left the queue and is being generated.
    * ``"203"``: the id is unknown or its result expired. A snapshot is written
      under ``./request_data_logs_203/``.
    * ``"400"``: the lookup itself failed, and an alert e-mail is sent.
    """
    id_ = request.json['id']
    result = redis_.get(id_)  # path of the result file, if one was published
    try:
        if result is not None:
            with open(result.decode('UTF-8'), encoding='utf8') as f1:
                result_dict = json.load(f1)
            code = result_dict["status_code"]
            texts = result_dict["texts"]
            probabilities = result_dict["probabilities"]
            if str(code) == "400":
                # Record the failure but still return a valid JSON response.
                redis_.rpush(db_key_error, json.dumps({"id": id_}))
            result_text = {'code': code, 'text': texts, 'probabilities': probabilities}
        else:
            querying_bool = bool(redis_.sismember(db_key_querying, id_))
            query_bool = any(json.loads(raw)['id'] == id_
                             for raw in redis_.lrange(db_key_query, 0, -1))
            if querying_bool and query_bool:
                result_text = {'code': "201", 'text': "", 'probabilities': None}
            elif querying_bool:
                result_text = {'code': "202", 'text': "", 'probabilities': None}
            else:
                result_text = {'code': "203", 'text': "", 'probabilities': None}
                load_request_path = './request_data_logs_203/{}.json'.format(id_)
                os.makedirs(os.path.dirname(load_request_path), exist_ok=True)
                with open(load_request_path, 'w', encoding='utf8') as f2:
                    json.dump(result_text, f2, ensure_ascii=False, indent=4)
    except Exception as exc:
        print("search error", repr(exc))
        log.log('search error', 'query_id:{},error:{!r}'.format(id_, exc))
        smtp_f("vllm-main")
        result_text = {'code': "400", 'text': "", 'probabilities': None}
    return jsonify(result_text)
# Start the single background worker that drains the redis queue. It is a
# non-daemon thread running ``classify``'s endless loop beside the Flask server.
t = Thread(target=classify, kwargs={"batch_size": batch_size})
t.start()

if __name__ == "__main__":
    app.run(host='0.0.0.0', port=26000, debug=False)