From b19375c6ff36171684937a74156288b6afb1e3f0 Mon Sep 17 00:00:00 2001 From: "majiahui@haimaqingfan.com" Date: Mon, 30 Oct 2023 11:11:24 +0800 Subject: [PATCH] =?UTF-8?q?=E7=AC=AC=E4=B8=80=E6=AC=A1=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 0 articles_directory_predict.py | 83 ++++++ flask_predict_batch_mistral.py | 414 +++++++++++++++++++++++++++ flask_predict_mistral_vllm.py | 67 +++++ main.py | 24 ++ redis_check_uuid_mistral.py | 87 ++++++ run_app_nohub_flask_predict_batch_mistral.sh | 1 + run_app_nohub_search_redis.sh | 1 + vllm_predict_batch.py | 122 ++++++++ 9 files changed, 799 insertions(+) create mode 100644 README.md create mode 100644 articles_directory_predict.py create mode 100644 flask_predict_batch_mistral.py create mode 100644 flask_predict_mistral_vllm.py create mode 100644 main.py create mode 100644 redis_check_uuid_mistral.py create mode 100644 run_app_nohub_flask_predict_batch_mistral.sh create mode 100644 run_app_nohub_search_redis.sh create mode 100644 vllm_predict_batch.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/articles_directory_predict.py b/articles_directory_predict.py new file mode 100644 index 0000000..4211db0 --- /dev/null +++ b/articles_directory_predict.py @@ -0,0 +1,83 @@ +from flask import Flask, jsonify +from flask import request +from transformers import pipeline +import redis +import uuid +import json +from threading import Thread +from vllm import LLM, SamplingParams +import time +import threading +import time +import concurrent.futures +import requests +import socket + +def get_host_ip(): + """ + 查询本机ip地址 + :return: ip + """ + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(('8.8.8.8', 80)) + ip = s.getsockname()[0] + finally: + s.close() + + return ip + +app = Flask(__name__) +app.config["JSON_AS_ASCII"] = False + +def dialog_line_parse(url, text): + """ + 将数据输入模型进行分析并输出结果 + :param url: 模型url + :param text: 进入模型的数据 + :return: 模型返回结果 + """ + + response = requests.post( + url, + json=text, + timeout=1000 + ) + if response.status_code == 200: + return response.json() + else: + # logger.error( + # "【{}】 Failed to get a proper response from remote " + # "server. Status Code: {}. Response: {}" + # "".format(url, response.status_code, response.text) + # ) + print("【{}】 Failed to get a proper response from remote " + "server. Status Code: {}. Response: {}" + "".format(url, response.status_code, response.text)) + print(text) + return [] + +@app.route("/articles_directory", methods=["POST"]) +def articles_directory(): + text = request.json["texts"] # 获取用户query中的文本 例如"I love you" + nums = request.json["nums"] + + nums = int(nums) + url = "http://{}:18001/predict".format(str(get_host_ip())) + + input_data = [] + for i in range(nums): + input_data.append([url, {"texts": "You are a helpful assistant.\n\nUser:{}\nAssistant:".format(text)}]) + + + with concurrent.futures.ThreadPoolExecutor() as executor: + # 使用submit方法将任务提交给线程池,并获取Future对象 + futures = [executor.submit(dialog_line_parse, i[0], i[1]) for i in input_data] + + # 使用as_completed获取已完成的任务,并获取返回值 + results = [future.result() for future in concurrent.futures.as_completed(futures)] + + return jsonify(results) # 返回结果 + +if __name__ == "__main__": + app.run(debug=False, host='0.0.0.0', port=18000) \ No newline at end of file diff --git a/flask_predict_batch_mistral.py b/flask_predict_batch_mistral.py new file mode 100644 index 0000000..b60e29a --- /dev/null +++ b/flask_predict_batch_mistral.py @@ -0,0 +1,414 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "2" +from flask import Flask, jsonify +from flask import request +import requests +import redis +import uuid +import json +from threading import Thread +import time +import re +import logging +from vllm import LLM, SamplingParams + + +logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 + filename='rewrite.log', + filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 + # a是追加模式,默认如果不写的话,就是追加模式 + format= + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' + # 日志格式 + ) + +pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=7, password="zhicheng123*") +redis_ = redis.Redis(connection_pool=pool, decode_responses=True) + +db_key_query = 'query' +db_key_querying = 'querying' +db_key_queryset = 'queryset' +batch_size = 32 + +app = Flask(__name__) +app.config["JSON_AS_ASCII"] = False + +import logging + +pattern = r"[。]" +RE_DIALOG = re.compile(r"\".*?\"|\'.*?\'|“.*?”") +fuhao_end_sentence = ["。", ",", "?", "!", "…"] + +# 加载模型 +sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=1.1,stop="", max_tokens=4096) +models_path = "/home/majiahui/model-llm/openbuddy-mistral-7b-v13.1" +llm = LLM(model=models_path, tokenizer_mode="slow") + + +class log: + def __init__(self): + pass + + def log(*args, **kwargs): + format = '%Y/%m/%d-%H:%M:%S' + format_h = '%Y-%m-%d' + value = time.localtime(int(time.time())) + dt = time.strftime(format, value) + dt_log_file = time.strftime(format_h, value) + log_file = 'log_file/access-%s' % dt_log_file + ".log" + if not os.path.exists(log_file): + with open(os.path.join(log_file), 'w', encoding='utf-8') as f: + print(dt, *args, file=f, **kwargs) + else: + with open(os.path.join(log_file), 'a+', encoding='utf-8') as f: + print(dt, *args, file=f, **kwargs) + + +def get_dialogs_index(line: str): + """ + 获取对话及其索引 + :param line 文本 + :return dialogs 对话内容 + dialogs_index: 对话位置索引 + other_index: 其他内容位置索引 + """ + dialogs = re.finditer(RE_DIALOG, line) + dialogs_text = re.findall(RE_DIALOG, line) + dialogs_index = [] + for dialog in dialogs: + all_ = [i for i in range(dialog.start(), dialog.end())] + dialogs_index.extend(all_) + other_index = [i for i in range(len(line)) if i not in dialogs_index] + + return dialogs_text, dialogs_index, other_index + + +def chulichangju_1(text, snetence_id, chulipangban_return_list, short_num): + fuhao = [",", "?", "!", "…"] + dialogs_text, dialogs_index, other_index = get_dialogs_index(text) + text_1 = text[:120] + text_2 = text[120:] + text_1_new = "" + if text_2 == "": + chulipangban_return_list.append([text_1, snetence_id, short_num]) + return chulipangban_return_list + for i in range(len(text_1) - 1, -1, -1): + if text_1[i] in fuhao: + if i in dialogs_index: + continue + text_1_new = text_1[:i] + text_1_new += text_1[i] + chulipangban_return_list.append([text_1_new, snetence_id, short_num]) + if text_2 != "": + if i + 1 != 120: + text_2 = text_1[i + 1:] + text_2 + break + # else: + # chulipangban_return_list.append(text_1) + if text_1_new == "": + chulipangban_return_list.append([text_1, snetence_id, short_num]) + if text_2 != "": + short_num += 1 + chulipangban_return_list = chulichangju_1(text_2, snetence_id, chulipangban_return_list, short_num) + return chulipangban_return_list + + +def chulipangban_test_1(snetence_id, text): + # 引号处理 + + dialogs_text, dialogs_index, other_index = get_dialogs_index(text) + for dialogs_text_dan in dialogs_text: + text_dan_list = text.split(dialogs_text_dan) + text = dialogs_text_dan.join(text_dan_list) + + # text_new_str = "".join(text_new) + + sentence_list = text.split("。") + # sentence_list_new = [] + # for i in sentence_list: + # if i != "": + # sentence_list_new.append(i) + # sentence_list = sentence_list_new + sentence_batch_list = [] + sentence_batch_one = [] + sentence_batch_length = 0 + return_list = [] + + for sentence in sentence_list[:-1]: + if len(sentence) < 120: + sentence_batch_length += len(sentence) + sentence_batch_list.append([sentence + "。", snetence_id, 0]) + # sentence_pre = autotitle.gen_synonyms_short(sentence) + # return_list.append(sentence_pre) + else: + sentence_split_list = chulichangju_1(sentence, snetence_id, [], 0) + for sentence_short in sentence_split_list[:-1]: + sentence_batch_list.append(sentence_short) + sentence_batch_list.append(sentence_split_list[-1] + "。") + + if sentence_list[:-1] != "": + if len(sentence_list[-1]) < 120: + sentence_batch_length += len(sentence_list[-1]) + sentence_batch_list.append([sentence_list[-1], snetence_id, 0]) + # sentence_pre = autotitle.gen_synonyms_short(sentence) + # return_list.append(sentence_pre) + else: + sentence_split_list = chulichangju_1(sentence_list[-1], snetence_id, [], 0) + for sentence_short in sentence_split_list: + sentence_batch_list.append(sentence_short) + + return sentence_batch_list + + +def paragraph_test(texts: dict): + text_new = [] + for i, text in texts.items(): + text_list = chulipangban_test_1(i, text) + text_new.extend(text_list) + + # text_new_str = "".join(text_new) + return text_new + + +def batch_data_process(text_list): + sentence_batch_length = 0 + sentence_batch_one = [] + sentence_batch_list = [] + + for sentence in text_list: + sentence_batch_length += len(sentence[0]) + sentence_batch_one.append(sentence) + if sentence_batch_length > 500: + sentence_batch_length = 0 + sentence_ = sentence_batch_one.pop(-1) + sentence_batch_list.append(sentence_batch_one) + sentence_batch_one = [] + sentence_batch_one.append(sentence_) + sentence_batch_list.append(sentence_batch_one) + return sentence_batch_list + + +def batch_predict(batch_data_list): + ''' + 一个bacth数据预测 + @param data_text: + @return: + ''' + batch_data_list_new = [] + batch_data_text_list = [] + batch_data_snetence_id_list = [] + for i in batch_data_list: + batch_data_text_list.append(i[0]) + batch_data_snetence_id_list.append(i[1:]) + # batch_pre_data_list = autotitle.generate_beam_search_batch(batch_data_text_list) + batch_pre_data_list = batch_data_text_list + for text, sentence_id in zip(batch_pre_data_list, batch_data_snetence_id_list): + batch_data_list_new.append([text] + sentence_id) + + return batch_data_list_new + + +def predict_data_post_processing(text_list): + text_list_sentence = [] + # text_list_sentence.append([text_list[0][0], text_list[0][1]]) + + for i in range(len(text_list)): + if text_list[i][2] != 0: + text_list_sentence[-1][0] += text_list[i][0] + else: + text_list_sentence.append([text_list[i][0], text_list[i][1]]) + + return_list = {} + sentence_one = [] + sentence_id = text_list_sentence[0][1] + for i in text_list_sentence: + if i[1] == sentence_id: + sentence_one.append(i[0]) + else: + return_list[sentence_id] = "".join(sentence_one) + sentence_id = i[1] + sentence_one = [] + sentence_one.append(i[0]) + if sentence_one != []: + return_list[sentence_id] = "".join(sentence_one) + return return_list + + +# def main(text:list): +# # text_list = paragraph_test(text) +# # batch_data = batch_data_process(text_list) +# # text_list = [] +# # for i in batch_data: +# # text_list.extend(i) +# # return_list = predict_data_post_processing(text_list) +# # return return_list +def pre_sentence_ulit(sentence): + if "改写后:" in sentence: + sentence_lable_index = sentence.index("改写后:") + sentence = sentence[sentence_lable_index + 4:] + + return sentence + +def main(texts: dict): + text_list = paragraph_test(texts) + + text_info = [] + text_sentence = [] + text_list_new = [] + + # for i in text_list: + # pre = one_predict(i) + # text_list_new.append(pre) + + # vllm预测 + for i in text_list: + if len(i[0]) > 7: + text = "You are a helpful assistant.\n\nUser:改写下面这句话,要求意思接近但是改动幅度比较大,字数只能多不能少:\n{}\nAssistant:".format(i[0]) + else: + text = "You are a helpful assistant.\n\nUser:下面词不做任何变化:\n{}\nAssistant:".format(i[0]) + text_sentence.append(text) + text_info.append([i[1], i[2]]) + + + outputs = llm.generate(text_sentence, sampling_params) # 调用模型 + + generated_text_list = [""] * len(text_sentence) + + # generated_text_list = ["" if len(i[0]) > 5 else i[0] for i in text_list] + + for i, output in enumerate(outputs): + index = output.request_id + generated_text = output.outputs[0].text + generated_text = pre_sentence_ulit(generated_text) + generated_text_list[int(index)] = generated_text + + + for i in range(len(text_list)): + if len(text_list[i][0]) > 7: + continue + else: + generated_text_list[i] = text_list[i][0] + + for i, j in zip(generated_text_list, text_info): + text_list_new.append([i] + j) + + return_list = predict_data_post_processing(text_list_new) + return return_list + + +# @app.route('/droprepeat/', methods=['POST']) +# def sentence(): +# print(request.remote_addr) +# texts = request.json["texts"] +# text_type = request.json["text_type"] +# print("原始语句" + str(texts)) +# # question = question.strip('。、!??') +# +# if isinstance(texts, dict): +# texts_list = [] +# y_pred_label_list = [] +# position_list = [] +# +# # texts = texts.replace('\'', '\"') +# if texts is None: +# return_text = {"texts": "输入了空值", "probabilities": None, "status_code": False} +# return jsonify(return_text) +# else: +# assert text_type in ['focus', 'chapter'] +# if text_type == 'focus': +# texts_list = main(texts) +# if text_type == 'chapter': +# texts_list = main(texts) +# return_text = {"texts": texts_list, "probabilities": None, "status_code": True} +# else: +# return_text = {"texts": "输入格式应该为list", "probabilities": None, "status_code": False} +# return jsonify(return_text) + + +def classify(): # 调用模型,设置最大batch_size + while True: + if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 + time.sleep(3) + continue + query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text + data_dict_path = json.loads(query) + path = data_dict_path['path'] + # text_type = data_dict["text_type"] + + with open(path, encoding='utf8') as f1: + # 加载文件的对象 + data_dict = json.load(f1) + + query_id = data_dict['id'] + texts = data_dict["text"] + text_type = data_dict["text_type"] + + assert text_type in ['focus', 'chapter'] + if text_type == 'focus': + texts_list = main(texts) + elif text_type == 'chapter': + texts_list = main(texts) + else: + texts_list = [] + + return_text = {"texts": texts_list, "probabilities": None, "status_code": 200} + load_result_path = "./new_data_logs/{}.json".format(query_id) + + print("query_id: ", query_id) + print("load_result_path: ", load_result_path) + + with open(load_result_path, 'w', encoding='utf8') as f2: + # ensure_ascii=False才能输入中文,否则是Unicode字符 + # indent=2 JSON数据的缩进,美观 + json.dump(return_text, f2, ensure_ascii=False, indent=4) + debug_id_1 = 1 + redis_.set(query_id, load_result_path, 86400) + debug_id_2 = 2 + redis_.srem(db_key_querying, query_id) + debug_id_3 = 3 + log.log('start at', + 'query_id:{},load_result_path:{},return_text:{}, debug_id_1:{}, debug_id_2:{}, debug_id_3:{}'.format( + query_id, load_result_path, return_text, debug_id_1, debug_id_2, debug_id_3)) + + +@app.route("/predict", methods=["POST"]) +def handle_query(): + print(request.remote_addr) + texts = request.json["texts"] + text_type = request.json["text_type"] + if texts is None: + return_text = {"texts": "输入了空值", "probabilities": None, "status_code": 402} + return jsonify(return_text) + if isinstance(texts, dict): + id_ = str(uuid.uuid1()) # 为query生成唯一标识 + print("uuid: ", uuid) + d = {'id': id_, 'text': texts, "text_type": text_type} # 绑定文本和query id + + load_request_path = './request_data_logs/{}.json'.format(id_) + with open(load_request_path, 'w', encoding='utf8') as f2: + # ensure_ascii=False才能输入中文,否则是Unicode字符 + # indent=2 JSON数据的缩进,美观 + json.dump(d, f2, ensure_ascii=False, indent=4) + redis_.rpush(db_key_query, json.dumps({"id": id_, "path": load_request_path})) # 加入redis + redis_.sadd(db_key_querying, id_) + redis_.sadd(db_key_queryset, id_) + return_text = {"texts": {'id': id_, }, "probabilities": None, "status_code": 200} + print("ok") + else: + return_text = {"texts": "输入格式应该为字典", "probabilities": None, "status_code": 401} + return jsonify(return_text) # 返回结果 + + +t = Thread(target=classify) +t.start() + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG, # 控制台打印的日志级别 + filename='rewrite.log', + filemode='a', ##模式,有w和a,w就是写模式,每次都会重新写日志,覆盖之前的日志 + # a是追加模式,默认如果不写的话,就是追加模式 + format= + '%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s' + # 日志格式 + ) + app.run(host="0.0.0.0", port=14002, threaded=True, debug=False) diff --git a/flask_predict_mistral_vllm.py b/flask_predict_mistral_vllm.py new file mode 100644 index 0000000..2367ae0 --- /dev/null +++ b/flask_predict_mistral_vllm.py @@ -0,0 +1,67 @@ +import flask +from transformers import pipeline +import redis +import uuid +import json +from threading import Thread +import time + +app = flask.Flask(__name__) +pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=5, password="zhicheng123*") +redis_ = redis.Redis(connection_pool=pool, decode_responses=True) + +db_key_query = 'query' +db_key_result = 'result' + +sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=1.1,stop="", max_tokens=4096) +models_path = "/home/majiahui/model-llm/openbuddy-mistral-7b-v13.1" +llm = LLM(model=models_path, tokenizer_mode="slow") + +def mistral_vllm_models(texts): + outputs = llm.generate(texts, sampling_params) # 调用模型 + + generated_text_list = [""] * len(texts) + + # generated_text_list = ["" if len(i[0]) > 5 else i[0] for i in text_list] + + for i, output in enumerate(outputs): + index = output.request_id + generated_text = output.outputs[0].text + generated_text_list[int(index)] = generated_text + + +def classify(batch_size): # 调用模型,设置最大batch_size + while True: + texts = [] + query_ids = [] + if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 + continue + for i in range(min(redis_.llen(db_key_query), batch_size)): + query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text + query_ids.append(json.loads(query)['id']) + texts.append(json.loads(query)['text']) # 拼接若干text 为batch + result = mistral_vllm_models(texts) # 调用模型 + for (id_, res) in zip(query_ids, result): + res['score'] = str(res['score']) + redis_.set(id_, json.dumps(res)) # 将模型结果送回队列 + + +@app.route("/predict", methods=["POST"]) +def handle_query(): + text = flask.request.form['text'] # 获取用户query中的文本 例如"I love you" + id_ = str(uuid.uuid1()) # 为query生成唯一标识 + d = {'id': id_, 'text': text} # 绑定文本和query id + redis_.rpush(db_key_query, json.dumps(d)) # 加入redis + while True: + result = redis_.get(id_) # 获取该query的模型结果 + if result is not None: + redis_.delete(id_) + result_text = {'code': "200", 'data': result.decode('UTF-8')} + break + return flask.jsonify(result_text) # 返回结果 + + +if __name__ == "__main__": + t = Thread(target=classify) + t.start() + app.run(debug=False, host='127.0.0.1', port=9000) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..96b1c96 --- /dev/null +++ b/main.py @@ -0,0 +1,24 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "3" +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "You are a helpful assistant.\n\nUser:张亮的爸爸叫张明,张明的爸爸有三个孩子,大儿子叫张大,二儿子叫张昊,三儿子叫什么?\nAssistant:", + "You are a helpful assistant.\n\nUser:你好\nAssistant:", + "You are a helpful assistant.\n\nUser:1+1等于几\nAssistant:", + "You are a helpful assistant.\n\nUser:你是谁\nAssistant:", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0, top_p=1, presence_penalty=0.9, max_tokens=1024) + +# Create an LLM. +llm = LLM(model="/home/majiahui/project/models-llm/openbuddy-mistral-7b-v13.1", trust_remote_code=True) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/redis_check_uuid_mistral.py b/redis_check_uuid_mistral.py new file mode 100644 index 0000000..863ad97 --- /dev/null +++ b/redis_check_uuid_mistral.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- + +""" +@Time : 2023/3/2 19:31 +@Author : +@FileName: +@Software: +@Describe: +""" +# +# import redis +# +# redis_pool = redis.ConnectionPool(host='127.0.0.1', port=6379, password='', db=0) +# redis_conn = redis.Redis(connection_pool=redis_pool) +# +# +# name_dict = { +# 'name_4' : 'Zarten_4', +# 'name_5' : 'Zarten_5' +# } +# redis_conn.mset(name_dict) + +import flask +import redis +import uuid +import json +from threading import Thread +import time + +app = flask.Flask(__name__) +pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=100, db=7, password="zhicheng123*") +redis_ = redis.Redis(connection_pool=pool, decode_responses=True) + +db_key_query = 'query' +db_key_querying = 'querying' + +@app.route("/search", methods=["POST"]) +def handle_query(): + id_ = flask.request.json['id'] # 获取用户query中的文本 例如"I love you" + result = redis_.get(id_) # 获取该query的模型结果 + if result is not None: + # redis_.delete(id_) + result_path = result.decode('UTF-8') + with open(result_path, encoding='utf8') as f1: + # 加载文件的对象 + result_dict = json.load(f1) + texts = result_dict["texts"] + probabilities = result_dict["probabilities"] + result_text = {'code': 200, 'text': texts, 'probabilities': probabilities} + else: + querying_list = list(redis_.smembers("querying")) + querying_set = set() + for i in querying_list: + querying_set.add(i.decode()) + + querying_bool = False + if id_ in querying_set: + querying_bool = True + + query_list_json = redis_.lrange(db_key_query, 0, -1) + query_set_ids = set() + for i in query_list_json: + data_dict = json.loads(i) + query_id = data_dict['id'] + query_set_ids.add(query_id) + + query_bool = False + if id_ in query_set_ids: + query_bool = True + + if querying_bool == True and query_bool == True: + result_text = {'code': "201", 'text': "", 'probabilities': None} + elif querying_bool == True and query_bool == False: + result_text = {'code': "202", 'text': "", 'probabilities': None} + else: + result_text = {'code': "203", 'text': "", 'probabilities': None} + load_request_path = './request_data_logs_203/{}.json'.format(id_) + with open(load_request_path, 'w', encoding='utf8') as f2: + # ensure_ascii=False才能输入中文,否则是Unicode字符 + # indent=2 JSON数据的缩进,美观 + json.dump(result_text, f2, ensure_ascii=False, indent=4) + + return flask.jsonify(result_text) # 返回结果 + + +if __name__ == "__main__": + app.run(debug=False, host='0.0.0.0', port=14003) diff --git a/run_app_nohub_flask_predict_batch_mistral.sh b/run_app_nohub_flask_predict_batch_mistral.sh new file mode 100644 index 0000000..f02942c --- /dev/null +++ b/run_app_nohub_flask_predict_batch_mistral.sh @@ -0,0 +1 @@ +nohup python flask_predict_batch_mistral.py > myout.flask_predict_batch_mistral.logs 2>&1 & diff --git a/run_app_nohub_search_redis.sh b/run_app_nohub_search_redis.sh new file mode 100644 index 0000000..84b0dcb --- /dev/null +++ b/run_app_nohub_search_redis.sh @@ -0,0 +1 @@ +nohup python redis_check_uuid_mistral.py > myout.redis_check_uuid_mistral.logs 2>&1 & diff --git a/vllm_predict_batch.py b/vllm_predict_batch.py new file mode 100644 index 0000000..bc8dad6 --- /dev/null +++ b/vllm_predict_batch.py @@ -0,0 +1,122 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "3" +from flask import Flask, jsonify +from flask import request +from transformers import pipeline +import redis +import uuid +import json +from threading import Thread +from vllm import LLM, SamplingParams +import time +import threading +import time +import concurrent.futures +import requests +import socket + +def get_host_ip(): + """ + 查询本机ip地址 + :return: ip + """ + try: + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + s.connect(('8.8.8.8', 80)) + ip = s.getsockname()[0] + finally: + s.close() + + return ip + +app = Flask(__name__) +app.config["JSON_AS_ASCII"] = False +pool = redis.ConnectionPool(host='localhost', port=63179, max_connections=50,db=11, password="zhicheng123*") +redis_ = redis.Redis(connection_pool=pool, decode_responses=True) + +db_key_query = 'query' +db_key_query_articles_directory = 'query_articles_directory' +db_key_result = 'result' +batch_size = 32 + +sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=0.9,stop="", max_tokens=4096) +models_path = "/home/majiahui/project/models-llm/openbuddy-mistral-7b-v13.1" +llm = LLM(model=models_path, tokenizer_mode="slow") + + +def dialog_line_parse(url, text): + """ + 将数据输入模型进行分析并输出结果 + :param url: 模型url + :param text: 进入模型的数据 + :return: 模型返回结果 + """ + + response = requests.post( + url, + json=text, + timeout=1000 + ) + if response.status_code == 200: + return response.json() + else: + # logger.error( + # "【{}】 Failed to get a proper response from remote " + # "server. Status Code: {}. Response: {}" + # "".format(url, response.status_code, response.text) + # ) + print("【{}】 Failed to get a proper response from remote " + "server. Status Code: {}. Response: {}" + "".format(url, response.status_code, response.text)) + print(text) + return [] + + +def classify(batch_size): # 调用模型,设置最大batch_size + while True: + texts = [] + query_ids = [] + if redis_.llen(db_key_query) == 0: # 若队列中没有元素就继续获取 + time.sleep(2) + continue + for i in range(min(redis_.llen(db_key_query), batch_size)): + query = redis_.lpop(db_key_query).decode('UTF-8') # 获取query的text + query_ids.append(json.loads(query)['id']) + texts.append(json.loads(query)['text']) # 拼接若干text 为batch + outputs = llm.generate(texts, sampling_params) # 调用模型 + + generated_text_list = [""] * len(texts) + print("outputs", outputs) + for i, output in enumerate(outputs): + index = output.request_id + generated_text = output.outputs[0].text + generated_text_list[int(index)] = generated_text + + + for (id_, output) in zip(query_ids, generated_text_list): + res = output + redis_.set(id_, json.dumps(res)) # 将模型结果送回队列 + + +@app.route("/predict", methods=["POST"]) +def handle_query(): + text = request.json["texts"] # 获取用户query中的文本 例如"I love you" + id_ = str(uuid.uuid1()) # 为query生成唯一标识 + d = {'id': id_, 'text': text} # 绑定文本和query id + redis_.rpush(db_key_query, json.dumps(d)) # 加入redis + while True: + result = redis_.get(id_) # 获取该query的模型结果 + if result is not None: + redis_.delete(id_) + result_text = {'code': "200", 'data': json.loads(result)} + break + time.sleep(1) + + return jsonify(result_text) # 返回结果 + + +t = Thread(target=classify, args=(batch_size,)) +t.start() + +if __name__ == "__main__": + app.run(debug=False, host='0.0.0.0', port=18001) \ No newline at end of file