@@ -1,5 +1,5 @@
 import os
-os.environ["CUDA_VISIBLE_DEVICES"] = "3"
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
 from flask import Flask, jsonify
 from flask import request
 from transformers import pipeline
@@ -13,6 +13,21 @@ import threading
 import time
 import concurrent.futures
 import requests
+import socket
+
+def get_host_ip():
+    """
+    Query the local machine's IP address
+    :return: ip
+    """
+    try:
+        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        s.connect(('8.8.8.8', 80))
+        ip = s.getsockname()[0]
+    finally:
+        s.close()
+
+    return ip
 
 app = Flask(__name__)
 app.config["JSON_AS_ASCII"] = False
@@ -25,7 +40,7 @@ db_key_result = 'result'
 batch_size = 32
 
 sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=0.9,stop="</s>", max_tokens=4096)
-models_path = "/home/majiahui/project/models-llm/openbuddy-llama-7b-finetune"
+models_path = "/home/majiahui/model-llm/openbuddy-llama-7b-finetune"
 llm = LLM(model=models_path, tokenizer_mode="slow")
 
 
@@ -71,6 +86,7 @@ def classify(batch_size):  # invoke the model; set the maximum batch_size
         outputs = llm.generate(texts, sampling_params)  # invoke the model
 
         generated_text_list = [""] * len(texts)
+        print("outputs", outputs)
         for i, output in enumerate(outputs):
             index = output.request_id
             generated_text = output.outputs[0].text
@@ -98,14 +114,13 @@ def handle_query():
 
     return jsonify(result_text)  # return the result
 
-
 @app.route("/articles_directory", methods=["POST"])
 def articles_directory():
     text = request.json["texts"]  # get the text from the user query, e.g. "I love you"
     nums = request.json["nums"]
 
     nums = int(nums)
-    url = "http://114.116.25.228:18000/predict"
+    url = "http://{}:18000/predict".format(str(get_host_ip()))
 
     input_data = []
     for i in range(nums):
@@ -121,8 +136,8 @@ def articles_directory():
 
     return jsonify(results)  # return the result
 
-t = Thread(target=classify, args=(batch_size,))
-t.start()
 
 if __name__ == "__main__":
+    t = Thread(target=classify, args=(batch_size,))
+    t.start()
     app.run(debug=False, host='0.0.0.0', port=18000)
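
For reference, a minimal client-side sketch of how the /articles_directory endpoint touched by this diff might be exercised once the service is running. The route name, the "texts" and "nums" fields, and port 18000 come from the code above; the host address, the sample payload values, and the assumption that the endpoint returns JSON are illustrative only, not part of the diff.

# Minimal client sketch (assumed usage; not part of the diff above).
# Posts to the /articles_directory route defined in the service, using
# the "texts" and "nums" request fields and port 18000 seen in the code.
import requests

if __name__ == "__main__":
    payload = {
        "texts": "Write an outline for a paper on large language models",  # example prompt (illustrative)
        "nums": 2,  # number of generations to request (illustrative)
    }
    # "127.0.0.1" assumes the client runs on the same machine as the service.
    resp = requests.post("http://127.0.0.1:18000/articles_directory", json=payload)
    print(resp.json())  # the exact response shape depends on the service's `results`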