From 7d6c636aff0330b4ba1d22a29a5bb1909b3fa123 Mon Sep 17 00:00:00 2001
From: "majiahui@haimaqingfan.com" <majiahui@haimaqingfan.com>
Date: Mon, 6 Nov 2023 17:43:30 +0800
Subject: [PATCH] Add article directory request endpoints
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
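
The /articles_directory endpoint fans one prompt out to the local /predict
service "nums" times and returns the collected generations. A minimal smoke
test, assuming both services run on one host with the ports configured in the
code below (18000 for the fan-out service, 18001 for the vLLM worker); the
prompt is only illustrative:

    import requests

    # Request three generations of the same prompt from the fan-out service.
    # "texts" and "nums" are the request fields read by articles_directory().
    resp = requests.post(
        "http://127.0.0.1:18000/articles_directory",
        json={"texts": "I love you", "nums": 3},
    )
    print(resp.json())  # list of per-request results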
Response: {}" + "".format(url, response.status_code, response.text)) + print(text) + return [] + +@app.route("/articles_directory", methods=["POST"]) +def articles_directory(): + text = request.json["texts"] # 获取用户query中的文本 例如"I love you" + nums = request.json["nums"] + + nums = int(nums) + url = "http://{}:18001/predict".format(str(get_host_ip())) + + input_data = [] + for i in range(nums): + input_data.append([url, {"texts": text}]) + + + with concurrent.futures.ThreadPoolExecutor() as executor: + # 使用submit方法将任务提交给线程池,并获取Future对象 + futures = [executor.submit(dialog_line_parse, i[0], i[1]) for i in input_data] + + # 使用as_completed获取已完成的任务,并获取返回值 + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + return jsonify(results) # 返回结果 + +if __name__ == "__main__": + app.run(debug=False, host='0.0.0.0', port=18000) \ No newline at end of file diff --git a/vllm_predict_batch.py b/vllm_predict_batch.py new file mode 100644 index 0000000..82694e5 --- /dev/null +++ b/vllm_predict_batch.py @@ -0,0 +1,122 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "3" +from flask import Flask, jsonify +from flask import request +from transformers import pipeline +import redis +import uuid +import json +from threading import Thread +from vllm import LLM, SamplingParams +import time +import threading +import time +import concurrent.futures +import requests +import socket + +def get_host_ip(): + """ + 查询本机ip地址 + :return: ip + """ + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(('8.8.8.8', 80)) + ip = s.getsockname()[0] + finally: + s.close() + + return ip + +app = Flask(__name__) +app.config["JSON_AS_ASCII"] = False +pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=50,db=11, password="zhicheng123*") +redis_ = redis.Redis(connection_pool=pool, decode_responses=True) + +db_key_query = 'query' +db_key_query_articles_directory = 'query_articles_directory' +db_key_result = 'result' +batch_size = 32 + +sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=0.9,stop="", max_tokens=4096) +models_path = "/home/majiahui/model-llm/openbuddy-llama-7b-finetune" +llm = LLM(model=models_path, tokenizer_mode="slow") + + +def dialog_line_parse(url, text): + """ + 将数据输入模型进行分析并输出结果 + :param url: 模型url + :param text: 进入模型的数据 + :return: 模型返回结果 + """ + + response = requests.post( + url, + json=text, + timeout=1000 + ) + if response.status_code == 200: + return response.json() + else: + # logger.error( + # "【{}】 Failed to get a proper response from remote " + # "server. Status Code: {}. Response: {}" + # "".format(url, response.status_code, response.text) + # ) + print("【{}】 Failed to get a proper response from remote " + "server. Status Code: {}. 
Response: {}" + "".format(url, response.status_code, response.text)) + print(text) + return [] + + +def classify(batch_size): # 调用模型,设置最大batch_size + while True: + texts = [] + query_ids = [] + if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 + time.sleep(2) + continue + for i in range(min(redis_.llen(db_key_query), batch_size)): + query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text + query_ids.append(json.loads(query)['id']) + texts.append(json.loads(query)['text']) # 拼接若干text 为batch + outputs = llm.generate(texts, sampling_params) # 调用模型 + + generated_text_list = [""] * len(texts) + print("outputs", outputs) + for i, output in enumerate(outputs): + index = output.request_id + generated_text = output.outputs[0].text + generated_text_list[int(index)] = generated_text + + + for (id_, output) in zip(query_ids, generated_text_list): + res = output + redis_.set(id_, json.dumps(res)) # 将模型结果送回队列 + + +@app.route("/predict", methods=["POST"]) +def handle_query(): + text = request.json["texts"] # 获取用户query中的文本 例如"I love you" + id_ = str(uuid.uuid1()) # 为query生成唯一标识 + d = {'id': id_, 'text': text} # 绑定文本和query id + redis_.rpush(db_key_query, json.dumps(d)) # 加入redis + while True: + result = redis_.get(id_) # 获取该query的模型结果 + if result is not None: + redis_.delete(id_) + result_text = {'code': "200", 'data': json.loads(result)} + break + time.sleep(1) + + return jsonify(result_text) # 返回结果 + + +t = Thread(target=classify, args=(batch_size,)) +t.start() + +if __name__ == "__main__": + app.run(debug=False, host='0.0.0.0', port=18001) \ No newline at end of file