普通大模型,未ppo
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

206 lines
7.7 KiB

3 months ago
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import argparse
from typing import List, Tuple
from threading import Thread
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
# from vllm.utils import FlexibleArgumentParser
from flask import Flask, jsonify
from flask import request
import redis
import time
import json
# Flask app for the HTTP interface; JSON_AS_ASCII=False makes jsonify emit
# non-ASCII (e.g. Chinese) text unescaped.
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False

# Redis connection settings. Environment variables override the previously
# hard-coded values so a deployment does not need to edit the source.
# NOTE(review): the password default is a secret committed to source control;
# prefer supplying REDIS_PASSWORD and removing the default.
pool = redis.ConnectionPool(
    host=os.environ.get("REDIS_HOST", 'localhost'),
    port=int(os.environ.get("REDIS_PORT", 63179)),
    max_connections=50,
    db=int(os.environ.get("REDIS_DB", 3)),
    password=os.environ.get("REDIS_PASSWORD", "zhicheng123*"),
)
# decode_responses is intentionally not passed here: redis.Redis ignores its
# connection kwargs when a connection_pool is supplied, so replies are bytes
# (the worker loop decodes them itself).
redis_ = redis.Redis(connection_pool=pool)

db_key_query = 'query'        # Redis list of pending queries (read with LLEN/LPOP)
db_key_querying = 'querying'  # Redis set; an id is SREM'd once its result is written
db_key_result = 'result'      # NOTE(review): not referenced by the worker loop — confirm still needed
batch_size = 15               # maximum number of queries grouped into one generation batch
class log:
    """Minimal file logger that appends timestamped lines to a per-day file.

    Intended to be used directly on the class as ``log.log(...)``.
    """

    def __init__(self):
        pass

    @staticmethod
    def log(*args, **kwargs):
        """Append one timestamped line to ``log_file/access-YYYY-MM-DD.log``.

        The line is a ``YYYY/MM/DD-HH:MM:SS`` local-time stamp followed by
        ``args``; ``args`` and ``kwargs`` (e.g. ``sep``) are forwarded to
        :func:`print`. The ``log_file`` directory, relative to the current
        working directory, is created if it does not exist yet.
        """
        now = time.localtime(int(time.time()))
        stamp = time.strftime('%Y/%m/%d-%H:%M:%S', now)
        day = time.strftime('%Y-%m-%d', now)
        log_dir = 'log_file'
        # Without this, a missing directory raises FileNotFoundError and kills
        # whichever thread is logging.
        os.makedirs(log_dir, exist_ok=True)
        log_path = os.path.join(log_dir, 'access-%s.log' % day)
        # Append mode creates the file when missing, so no exists() check.
        with open(log_path, 'a', encoding='utf-8') as f:
            print(stamp, *args, file=f, **kwargs)
DEFAULT_MODEL_DIR = "/home/majiahui/project/models-llm/qwen2_5_7B_train_11_prompt_4_gpt_xiaobiaot_real_paper_1"


def initialize_engine(model_dir: str = DEFAULT_MODEL_DIR,
                      max_num_seqs: int = 16,
                      gpu_memory_utilization: float = 0.8,
                      max_model_len: int = 8192) -> LLMEngine:
    """Build the vLLM ``LLMEngine`` used by this worker.

    Settings are taken from the arguments (not the command line); the
    defaults are the values this service previously hard-coded.

    Args:
        model_dir: Local path of the model weights to load.
        max_num_seqs: Maximum number of sequences the engine runs per iteration.
        gpu_memory_utilization: Fraction of GPU memory vLLM may use.
        max_model_len: Model context length (prompt plus generated tokens).

    Returns:
        The constructed ``LLMEngine`` (this loads the model).
    """
    args = EngineArgs(
        model_dir,
        max_num_seqs=max_num_seqs,
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=max_model_len,
    )
    return LLMEngine.from_engine_args(args)


# Load the model once at import time; main() passes this shared engine to
# process_requests for every batch.
engine = initialize_engine()
def create_test_prompts(prompt_texts, query_ids, sampling_params_list) -> List[Tuple[str, str, SamplingParams]]:
    """Pair each prompt with its query id and sampling parameters.

    Args:
        prompt_texts: Prompt strings, one per request.
        query_ids: Request identifiers, aligned with ``prompt_texts``.
        sampling_params_list: ``SamplingParams`` objects, aligned likewise.

    Returns:
        A new list of ``(prompt, query_id, sampling_params)`` tuples. As with
        :func:`zip`, pairing stops at the shortest input, so callers must
        pass equal-length sequences to avoid silently dropping requests.
    """
    return list(zip(prompt_texts, query_ids, sampling_params_list))
def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, str, SamplingParams]]):
    """Feed prompts into *engine* and step it until every request finishes.

    On each engine step at most one pending prompt is submitted, taken from
    the front of ``test_prompts`` (the caller's list is consumed). Stepping
    continues until no prompts remain and the engine reports no unfinished
    requests. Each request id is ``str(query_id)``.

    Returns:
        The finished ``RequestOutput`` objects, in the order the engine
        reported them as finished.
    """
    finished = []
    while True:
        pending = bool(test_prompts)
        if not pending and not engine.has_unfinished_requests():
            break
        if pending:
            prompt, query_id, sampling_params = test_prompts.pop(0)
            engine.add_request(str(query_id), prompt, sampling_params)
        finished.extend(out for out in engine.step() if out.finished)
    return finished
def main(prompt_texts, query_ids, sampling_params_list):
    """Run one batch of prompts through the module-level engine.

    The three sequences are paired element-wise and handed to
    ``process_requests``; its list of finished outputs is returned.
    """
    batch = create_test_prompts(prompt_texts, query_ids, sampling_params_list)
    finished_outputs = process_requests(engine, batch)
    return finished_outputs
# chat对话接口
# @app.route("/predict/", methods=["POST"])
# def chat():
# # request = request.json()
# # query = request.get('query', None)
# # history = request.get('history', [])
# # system = request.get('system', 'You are a helpful assistant.')
# # stream = request.get("stream", False)
# # user_stop_words = request.get("user_stop_words",
# # []) # list[str],用户自定义停止句,例如:['Observation: ', 'Action: ']定义了2个停止句,遇到任何一个都会停止
#
# query = request.json['query']
#
#
# # 构造prompt
# # prompt_text, prompt_tokens = _build_prompt(generation_config, tokenizer, query, history=history, system=system)
#
# prompt_text = f"<|im_start|>user\n{query}\n<|im_end|>\n<|im_start|>assistant\n"
#
#
# return_output = main(prompt_text, sampling_params)
# return_info = {
# "request_id": return_output.request_id,
# "text": return_output.outputs[0].text
# }
#
# return jsonify(return_info)
def _load_query(raw):
    """Decode one queue entry into ``(query_id, text, sampling_params)``.

    The entry is JSON ``{"path": ...}``; the file at ``path`` holds the
    request with ``id``, ``text``, ``top_p`` and ``temperature``. The
    request's ``model`` field is not read: this worker serves one engine.
    Raises OSError/ValueError/KeyError/TypeError on a malformed entry.
    """
    # Works whether the Redis client returns bytes or already-decoded str.
    if isinstance(raw, bytes):
        raw = raw.decode('UTF-8')
    path = json.loads(raw)['path']
    with open(path, encoding='utf8') as f:
        data_dict = json.load(f)
    query_id = data_dict['id']
    print("query_id", query_id)
    # Per-request temperature/top_p; stop string, presence penalty and
    # token cap are fixed for every request.
    sampling_params = SamplingParams(
        temperature=data_dict["temperature"],
        top_p=data_dict["top_p"],
        stop="<|end|>",
        presence_penalty=1.1,
        max_tokens=8192,
    )
    return query_id, data_dict["text"], sampling_params


def _write_result(id_, generated_text):
    """Write one result JSON file and publish its path in Redis under ``id_``."""
    return_text = {"texts": generated_text, "probabilities": None, "status_code": 200}
    load_result_path = "./new_data_logs/{}.json".format(id_)
    with open(load_result_path, 'w', encoding='utf8') as f:
        # ensure_ascii=False keeps Chinese text readable instead of \u escapes.
        json.dump(return_text, f, ensure_ascii=False, indent=4)
    # The result pointer expires after one day (86400 s).
    redis_.set(id_, load_result_path, ex=86400)
    # Remove the id from the 'querying' set (presumably in-flight queries).
    redis_.srem(db_key_querying, id_)
    log.log('start at',
            'query_id:{},load_result_path:{},return_text:{}'.format(
                id_, load_result_path, return_text))


def classify(batch_size):  # Run the model; batch_size caps queries per batch.
    """Worker loop: pull queued queries from Redis, generate, publish results.

    Runs forever. Each iteration pops up to ``batch_size`` entries from the
    ``db_key_query`` list, generates them as one batch via ``main``, then
    writes every finished output with ``_write_result``. Sleeps 2 s when the
    queue is empty. A malformed entry is logged and skipped instead of
    raising and killing the worker thread.
    """
    # Result files are written here; create it up front so open() can't fail.
    os.makedirs("./new_data_logs", exist_ok=True)
    while True:
        if redis_.llen(db_key_query) == 0:  # nothing queued: wait and poll again
            time.sleep(2)
            continue

        texts, query_ids, sampling_params_list = [], [], []
        while True:
            raw = redis_.lpop(db_key_query)
            if raw is None:
                break
            try:
                query_id, text, sampling_params = _load_query(raw)
            except (OSError, ValueError, KeyError, TypeError) as exc:
                print("skip bad query", raw, repr(exc))
                log.log('skip bad query', raw, repr(exc))
                continue
            query_ids.append(query_id)
            texts.append(text)
            sampling_params_list.append(sampling_params)
            if len(texts) == batch_size:
                break

        if not texts:  # every popped entry was malformed, or another consumer won the race
            continue

        print("texts", len(texts))
        print("query_ids", len(query_ids))
        print("sampling_params_list", len(sampling_params_list))
        outputs = main(texts, query_ids, sampling_params_list)
        print("预测完成")
        print("outputs", len(outputs))

        # Request ids are the stringified query ids; key the texts by them.
        generated_text_dict = {}
        for output in outputs:
            print(output.request_id)
            generated_text_dict[output.request_id] = output.outputs[0].text
        print(generated_text_dict)

        for id_, generated_text in generated_text_dict.items():
            _write_result(id_, generated_text)
if __name__ == '__main__':
    # Launch the Redis-polling worker. It is a non-daemon thread, so it keeps
    # the process alive after the main thread reaches the end of the module.
    worker = Thread(target=classify, kwargs={'batch_size': batch_size})
    worker.start()