Browse Source

首次提交

master
majiahui@haimaqingfan.com 8 months ago
commit
4e65ce94b0
  1. 21
      gunicorn_config.py
  2. 131
      mistral_api.py
  3. 181
      qwen_model_perdict_vllm_4.py
  4. 1
      run_api_gunicorn.sh
  5. 1
      run_model.sh

21
gunicorn_config.py

@ -0,0 +1,21 @@
# Number of parallel gunicorn workers
workers = 8
# Address and port to listen on (change as needed)
bind = '0.0.0.0:12003'
loglevel = 'debug'
worker_class = "gevent"
# Run as a daemon (the program keeps running after the launching session is closed)
daemon = True
# Worker timeout in seconds: 120 here (the original note gives the default as 30s); adjust as needed
timeout = 120
# Paths of the access log and the error log
# NOTE(review): 'acess1.log' looks like a typo for 'access1.log' — confirm before renaming,
# since anything that reads the log depends on this exact path.
accesslog = './logs/acess1.log'
errorlog = './logs/error1.log'
# access_log_format = '%(h) - %(t)s - %(u)s - %(s)s %(H)s'
# errorlog = '-' # log to standard output
# Maximum number of concurrent connections
worker_connections = 20000

131
mistral_api.py

@ -0,0 +1,131 @@
from flask import Flask, jsonify
from flask import request
from transformers import pipeline
import redis
import uuid
import json
from threading import Thread
from vllm import LLM, SamplingParams
import time
import threading
import time
import concurrent.futures
import requests
import socket
# Flask application exposing the /predict and /search endpoints defined below.
app = Flask(__name__)
# Flask JSON setting: keep non-ASCII characters unescaped in JSON responses.
app.config["JSON_AS_ASCII"] = False
# NOTE(review): the Redis password is hard-coded in source — consider loading it from the environment.
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=50,db=2, password="zhicheng123*")
# NOTE(review): decode_responses is passed together with an explicit connection_pool; the handlers
# below call .decode() on values read from Redis, so replies evidently arrive as bytes — confirm
# whether this flag has any effect here.
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
# Redis list that /predict pushes new requests onto and /search scans.
db_key_query = 'query'
# Redis set that /predict adds request ids to and /search checks for membership.
db_key_querying = 'querying'
# Redis set that /predict also adds every request id to.
db_key_queryset = 'queryset'
# Not referenced by the handlers visible in this file.
db_key_result = 'result'
# Redis list that /search pushes failed request ids onto.
db_key_error = 'error'
def smtp_f(name):
    """Send an alert e-mail whose subject is *name*.

    The message body is a fixed alert text. Delivery is best-effort: any
    SMTP or network failure is reported on stdout and never raised, so the
    request handlers that call this from their own error paths cannot be
    broken by an unreachable mail server.

    Args:
        name: Subject line of the alert e-mail.
    """
    import smtplib
    from email.mime.text import MIMEText
    from email.header import Header

    # NOTE(review): the mailbox and its authorisation code are hard-coded in
    # source — consider moving them to configuration.
    sender = '838878981@qq.com'  # sending mailbox
    receivers = ['838878981@qq.com']  # receiving mailboxes
    auth_code = "jfqtutaiwrtdbcge"  # authorisation code used to log in
    message = MIMEText('基础大模型出现错误,紧急', 'plain', 'utf-8')
    message['From'] = Header("Sender<%s>" % sender)  # sender
    message['To'] = Header("Receiver<%s>" % receivers[0])  # recipient
    message['Subject'] = Header(name, 'utf-8')
    try:
        # Bound the wait so a stalled mail server cannot hang the caller.
        server = smtplib.SMTP_SSL('smtp.qq.com', 465, timeout=30)
        try:
            server.login(sender, auth_code)
            server.sendmail(sender, receivers, message.as_string())
            print("邮件发送成功")
        finally:
            # Release the socket even when login or sendmail fails.
            server.close()
    except (smtplib.SMTPException, OSError):
        # OSError covers DNS failures, refused connections and timeouts,
        # which are not SMTPException instances.
        print("Error: 无法发送邮件")
@app.route("/predict", methods=["POST"])
def predict():
    """Accept a generation request and enqueue it for processing.

    Reads ``texts`` from the JSON request body, assigns the request a new
    id, stores the payload in ``./request_data_logs/<id>.json`` and
    registers the id in Redis (the query list plus the querying and
    queryset sets).

    Returns:
        A JSON response ``{"texts": {"id": <id>}, "probabilities": None,
        "status_code": 200}``; ``status_code`` is 400 when storing or
        enqueueing failed, in which case an alert e-mail is also sent.
    """
    text = request.json["texts"]  # text of the user's query, e.g. "I love you"
    id_ = str(uuid.uuid1())  # unique identifier for this query
    # Print the generated id (the original printed the uuid module object).
    print("uuid: ", id_)
    d = {'id': id_, 'text': text}  # bind the text to the query id
    try:
        load_request_path = './request_data_logs/{}.json'.format(id_)
        with open(load_request_path, 'w', encoding='utf8') as f2:
            # ensure_ascii=False keeps Chinese text readable instead of \u escapes
            json.dump(d, f2, ensure_ascii=False, indent=4)
        # Only the id and the payload path go through Redis
        redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path}))
        redis_.sadd(db_key_querying, id_)
        redis_.sadd(db_key_queryset, id_)
        return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200}
    except Exception:
        # `Exception` rather than a bare except, so interpreter-level signals
        # (SystemExit, KeyboardInterrupt, ...) are not swallowed.
        return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 400}
        smtp_f("vllm-main-drop")
    return jsonify(return_text)  # return the result
@app.route("/search", methods=["POST"])
def search():
    """Report the result or processing state of a previously submitted request.

    Reads ``id`` from the JSON request body and answers with a JSON object
    ``{"code", "text", "probabilities"}``:

    * a stored result exists -> the ``status_code``, ``texts`` and
      ``probabilities`` read from the result file that Redis points to;
      when that status code is 400 the id is also pushed onto the error list
    * ``"201"`` -> id is in the querying set and still in the query list
    * ``"202"`` -> id is in the querying set but no longer in the query list
    * ``"203"`` -> id is not in the querying set; the response is also saved
      under ``./request_data_logs_203/``
    * ``"400"`` -> the lookup itself failed (an alert e-mail is sent)
    """
    id_ = request.json['id']  # id of the query to look up
    try:
        # Inside the try so that a Redis failure yields the "400" response
        # like every other lookup error in this handler.
        result = redis_.get(id_)  # path of the result file, if any
        if result is not None:
            result_path = result.decode('UTF-8')
            with open(result_path, encoding='utf8') as f1:
                result_dict = json.load(f1)
            code = result_dict["status_code"]
            texts = result_dict["texts"]
            probabilities = result_dict["probabilities"]
            # The original compared str(code) with the int 400, which is never
            # true, and then returned False, which is not a valid Flask
            # response. Record the failure and still answer with the stored
            # fields, exactly as callers have been receiving them.
            if str(code) == "400":
                redis_.rpush(db_key_error, json.dumps({"id": id_}))
            result_text = {'code': code, 'text': texts, 'probabilities': probabilities}
        elif not redis_.sismember(db_key_querying, id_):
            result_text = {'code': "203", 'text': "", 'probabilities': None}
            # The id comes straight from the client and is used in a file
            # name: only write the diagnostic file when it parses as a UUID,
            # so it cannot contain path separators and escape the directory.
            try:
                uuid.UUID(str(id_))
            except ValueError:
                pass
            else:
                load_request_path = './request_data_logs_203/{}.json'.format(id_)
                with open(load_request_path, 'w', encoding='utf8') as f2:
                    # ensure_ascii=False keeps Chinese text readable
                    json.dump(result_text, f2, ensure_ascii=False, indent=4)
        else:
            # Still registered as querying: is it also still in the query list?
            queued = any(
                json.loads(entry)['id'] == id_
                for entry in redis_.lrange(db_key_query, 0, -1)
            )
            result_text = {'code': "201" if queued else "202", 'text': "", 'probabilities': None}
    except Exception:
        smtp_f("vllm-main")
        result_text = {'code': "400", 'text': "", 'probabilities': None}
    return jsonify(result_text)  # return the result
if __name__ == "__main__":
    # Direct execution: serve with Flask's built-in server on all interfaces.
    app.run(
        host='0.0.0.0',
        port=12006,
        debug=False,
    )

181
qwen_model_perdict_vllm_4.py

@ -0,0 +1,181 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import argparse
from typing import List, Tuple
from threading import Thread
from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
# from vllm.utils import FlexibleArgumentParser
from flask import Flask, jsonify
from flask import request
import redis
import time
import json
# HTTP interface service
# app = FastAPI()
app = Flask(__name__)
# Flask JSON setting: keep non-ASCII characters unescaped in JSON responses.
app.config["JSON_AS_ASCII"] = False
# NOTE(review): the Redis password is hard-coded in source — consider loading it from the environment.
pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=50,db=2, password="zhicheng123*")
# NOTE(review): decode_responses is passed together with an explicit connection_pool; classify()
# calls .decode() on popped values, so replies evidently arrive as bytes — confirm whether this
# flag has any effect here.
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)
# Redis list that classify() pops pending requests from.
db_key_query = 'query'
# Redis set that classify() removes an id from once its result is stored.
db_key_querying = 'querying'
# Not referenced by the code visible in this file.
db_key_result = 'result'
# Upper bound on requests pulled per batch; passed to classify() at startup.
batch_size = 32
class log:
    """Minimal append-only file logger with one file per calendar day."""

    def __init__(self):
        pass

    @staticmethod
    def log(*args, **kwargs):
        """Append one timestamped line to ``log_file/access-YYYY-MM-DD.log``.

        Args:
            *args: Values written after the timestamp, as ``print`` would.
            **kwargs: Extra keyword arguments forwarded to ``print``
                (for example ``sep`` or ``end``).
        """
        stamp_pattern = '%Y/%m/%d-%H:%M:%S'
        day_pattern = '%Y-%m-%d'
        value = time.localtime(int(time.time()))
        dt = time.strftime(stamp_pattern, value)
        dt_log_file = time.strftime(day_pattern, value)
        log_file = 'log_file/access-%s' % dt_log_file + ".log"
        # Create the log directory on first use instead of failing with
        # FileNotFoundError when it does not exist yet.
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        # Append mode creates the file when missing, so no exists() check is needed.
        with open(log_file, 'a', encoding='utf-8') as f:
            print(dt, *args, file=f, **kwargs)
def initialize_engine() -> LLMEngine:
    """Build the LLMEngine for the model stored at a fixed local path.

    The engine arguments are created from the model directory and then
    adjusted (``max_num_seqs`` and ``gpu_memory_utilization``) before the
    engine is constructed from them.
    """
    # Earlier model, kept for reference:
    # "/home/majiahui/project/models-llm/Qwen-0_5B-Chat"
    model_path = "/home/majiahui/project/models-llm/qwen2_0_5B_rewrite_lora_hebing"
    engine_args = EngineArgs(model_path)
    engine_args.max_num_seqs = 16  # cap on sequences per batch
    engine_args.gpu_memory_utilization = 0.3
    # Load the model
    llm_engine = LLMEngine.from_engine_args(engine_args)
    return llm_engine
# Module-level engine instance, created once at import time and used by main().
engine = initialize_engine()
def create_test_prompts(prompt_texts, query_ids, sampling_params) -> List[Tuple[str,str, SamplingParams]]:
    """Pair each prompt with its query id and the shared sampling parameters.

    Returns a list of ``(prompt, query_id, sampling_params)`` tuples, one per
    position; pairing stops at the shorter of the two input sequences.
    """
    return [
        (prompt, query_id, sampling_params)
        for prompt, query_id in zip(prompt_texts, query_ids)
    ]
def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, str, SamplingParams]]):
    """Feed the prompts to the engine and gather every finished output.

    On each iteration at most one pending prompt is taken from the front of
    ``test_prompts`` (the list is consumed in place) and submitted under its
    stringified query id, then the engine is stepped once. The loop ends when
    no prompts remain and the engine reports no unfinished requests.

    Returns:
        The finished request outputs, in the order the engine produced them.
    """
    finished_outputs = []
    while True:
        if test_prompts:
            prompt, query_id, sampling_params = test_prompts.pop(0)
            engine.add_request(str(query_id), prompt, sampling_params)
        elif not engine.has_unfinished_requests():
            break
        step_outputs: List[RequestOutput] = engine.step()
        finished_outputs.extend(
            step_output for step_output in step_outputs if step_output.finished
        )
    return finished_outputs
def main(prompt_texts, query_ids, sampling_params):
    """Run one batch of prompts through the module-level engine.

    Builds the ``(prompt, query_id, sampling_params)`` work list and returns
    the finished outputs produced while processing it.
    """
    return process_requests(
        engine,
        create_test_prompts(prompt_texts, query_ids, sampling_params),
    )
# chat对话接口
# @app.route("/predict/", methods=["POST"])
# def chat():
# # request = request.json()
# # query = request.get('query', None)
# # history = request.get('history', [])
# # system = request.get('system', 'You are a helpful assistant.')
# # stream = request.get("stream", False)
# # user_stop_words = request.get("user_stop_words",
# # []) # list[str],用户自定义停止句,例如:['Observation: ', 'Action: ']定义了2个停止句,遇到任何一个都会停止
#
# query = request.json['query']
#
#
# # 构造prompt
# # prompt_text, prompt_tokens = _build_prompt(generation_config, tokenizer, query, history=history, system=system)
#
# prompt_text = f"<|im_start|>user\n{query}\n<|im_end|>\n<|im_start|>assistant\n"
#
#
# return_output = main(prompt_text, sampling_params)
# return_info = {
# "request_id": return_output.request_id,
# "text": return_output.outputs[0].text
# }
#
# return jsonify(return_info)
def _collect_batch(batch_size):
    """Pop queued requests from Redis and load their payload files.

    Popping stops when the queue is empty or ``batch_size`` requests have been
    loaded. An entry that cannot be decoded, or whose payload file is missing
    or incomplete, is reported on stdout and skipped.

    Returns:
        A ``(texts, query_ids)`` pair of equally long lists.
    """
    texts = []
    query_ids = []
    while True:
        query = redis_.lpop(db_key_query)  # next queued request, or None
        if query is None:
            break
        try:
            data_dict_path = json.loads(query.decode('UTF-8'))
            path = data_dict_path['path']
            with open(path, encoding='utf8') as f1:
                data_dict = json.load(f1)
            query_id = data_dict['id']
            text = data_dict["text"]
        except (OSError, ValueError, KeyError, TypeError) as err:
            # One bad entry must not take the whole worker thread down.
            print("skipping malformed queue entry {!r}: {}".format(query, err))
            continue
        query_ids.append(query_id)
        texts.append(text)
        if len(texts) == batch_size:
            break
    return texts, query_ids


def classify(batch_size):
    """Worker loop: pull queued requests, run the model, publish the results.

    Runs forever. For every generated output the result is written to
    ``./new_data_logs/<id>.json``, the file path is stored in Redis under the
    request id with a 86400-second expiry, the id is removed from the
    querying set and a log line is written.

    Args:
        batch_size: Maximum number of requests handled per model call.
    """
    while True:
        if redis_.llen(db_key_query) == 0:  # nothing queued: wait and poll again
            time.sleep(2)
            continue
        texts, query_ids = _collect_batch(batch_size)
        if not texts:
            # Every popped entry was unusable; nothing to run.
            continue
        sampling_params = SamplingParams(temperature=0.8, top_p=0.95, frequency_penalty=0.5, max_tokens=8192)
        outputs = main(texts, query_ids, sampling_params)
        print("预测完成")
        print("outputs", len(outputs))
        # Map request id -> generated text of the first completion.
        generated_text_dict = {
            output.request_id: output.outputs[0].text for output in outputs
        }
        for id_, output in generated_text_dict.items():
            return_text = {"texts": output, "probabilities": None, "status_code": 200}
            load_result_path = "./new_data_logs/{}.json".format(id_)
            with open(load_result_path, 'w', encoding='utf8') as f2:
                # ensure_ascii=False keeps Chinese text readable
                json.dump(return_text, f2, ensure_ascii=False, indent=4)
            redis_.set(id_, load_result_path, ex=86400)  # result pointer expires after 86400 s
            redis_.srem(db_key_querying, id_)
            log.log('start at',
                    'query_id:{},load_result_path:{},return_text:{}'.format(
                        id_, load_result_path, return_text))
if __name__ == '__main__':
    # Run the batch worker loop on its own thread.
    t = Thread(
        target=classify,
        args=(batch_size,),
    )
    t.start()

1
run_api_gunicorn.sh

@ -0,0 +1 @@
# Start the Flask app object `app` from mistral_api under gunicorn, using the settings in gunicorn_config.py
gunicorn mistral_api:app -c gunicorn_config.py

1
run_model.sh

@ -0,0 +1 @@
# Launch the model worker in the background, immune to hangups; stdout and stderr go to myout_model_qwen.file
nohup python qwen_model_perdict_vllm_4.py > myout_model_qwen.file 2>&1 &
Loading…
Cancel
Save