Browse Source

增加目录请求

master
majiahui@haimaqingfan.com 2 years ago
parent
commit
64f307bde5
  1. 27
      flask_batch.py
  2. 4
      flask_test.py
  3. 38
      predict.py

27
flask_batch.py

@ -1,5 +1,5 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from flask import Flask, jsonify
from flask import request
from transformers import pipeline
@ -13,6 +13,21 @@ import threading
import time
import concurrent.futures
import requests
import socket
def get_host_ip():
"""
查询本机ip地址
:return: ip
"""
try:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(('8.8.8.8', 80))
ip = s.getsockname()[0]
finally:
s.close()
return ip
app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
@ -25,7 +40,7 @@ db_key_result = 'result'
batch_size = 32
sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=0.9,stop="</s>", max_tokens=4096)
models_path = "/home/majiahui/project/models-llm/openbuddy-llama-7b-finetune"
models_path = "/home/majiahui/model-llm/openbuddy-llama-7b-finetune"
llm = LLM(model=models_path, tokenizer_mode="slow")
@ -71,6 +86,7 @@ def classify(batch_size): # 调用模型,设置最大batch_size
outputs = llm.generate(texts, sampling_params) # 调用模型
generated_text_list = [""] * len(texts)
print("outputs", outputs)
for i, output in enumerate(outputs):
index = output.request_id
generated_text = output.outputs[0].text
@ -98,14 +114,13 @@ def handle_query():
return jsonify(result_text) # 返回结果
@app.route("/articles_directory", methods=["POST"])
def articles_directory():
text = request.json["texts"] # 获取用户query中的文本 例如"I love you"
nums = request.json["nums"]
nums = int(nums)
url = "http://114.116.25.228:18000/predict"
url = "http://{}:18000/predict".format(str(get_host_ip()))
input_data = []
for i in range(nums):
@ -121,8 +136,8 @@ def articles_directory():
return jsonify(results) # 返回结果
t = Thread(target=classify, args=(batch_size,))
t.start()
if __name__ == "__main__":
t = Thread(target=classify, args=(batch_size,))
t.start()
app.run(debug=False, host='0.0.0.0', port=18000)

4
flask_test.py

@ -8,7 +8,7 @@ app.config["JSON_AS_ASCII"] = False
# prompts = [
# "生成论文小标题内容#问:论文题目是“大学生村官管理研究”,目录是“一、大学生村官管理现状分析\\n1.1 村官数量及分布情况\\n1.2 村官岗位设置及职责\\n1.3 村官工作绩效评估\\n\\n二、大学生村官管理存在的问题\\n2.1 村官队伍结构不合理\\n2.2 村官工作能力不足\\n2.3 村官管理制度不健全\\n\\n三、大学生村官管理对策研究\\n3.1 加强村官队伍建设\\n3.2 提高村官工作能力\\n3.3 完善村官管理制度\\n\\n四、大学生村官管理案例分析\\n4.1 案例一:某村大学生村官工作情况分析\\n4.2 案例二:某村大学生村官管理策略探讨\\n\\n五、大学生村官管理的未来发展趋势\\n5.1 多元化村官队伍建设\\n5.2 信息化村官管理模式\\n5.3 村官职业化发展\\n\\n六、大学生村官管理的政策建议\\n6.1 加强对大学生村官的培训和管理\\n6.2 完善大学生村官管理制度\\n6.3 提高大学生村官的待遇和福利\\n\\n七、结论与展望”,请把其中的小标题“3.3 完善村官管理制度”的内容补充完整,补充内容字数在800字左右\n答:\n"
# ] * 10
sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=0.9,stop="</s>")
sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=0.9,stop="</s>",max_tokens=4096)
#
model_path = '/home/majiahui/project/models-llm/openbuddy-llama-7b-finetune'
llm = LLM(model=model_path)
@ -40,4 +40,4 @@ def handle_query():
return jsonify(result_text) # 返回结果
if __name__ == "__main__":
app.run(host="0.0.0.0", port=15001, threaded=True, debug=False)
app.run(host="0.0.0.0", port=18000, threaded=True, debug=False)

38
predict.py

@ -1,14 +1,26 @@
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import time
from vllm import LLM, SamplingParams
prompts = [
"生成论文小标题内容#问:论文题目是“大学生村官管理研究”,目录是“一、大学生村官管理现状分析\\n1.1 村官数量及分布情况\\n1.2 村官岗位设置及职责\\n1.3 村官工作绩效评估\\n\\n二、大学生村官管理存在的问题\\n2.1 村官队伍结构不合理\\n2.2 村官工作能力不足\\n2.3 村官管理制度不健全\\n\\n三、大学生村官管理对策研究\\n3.1 加强村官队伍建设\\n3.2 提高村官工作能力\\n3.3 完善村官管理制度\\n\\n四、大学生村官管理案例分析\\n4.1 案例一:某村大学生村官工作情况分析\\n4.2 案例二:某村大学生村官管理策略探讨\\n\\n五、大学生村官管理的未来发展趋势\\n5.1 多元化村官队伍建设\\n5.2 信息化村官管理模式\\n5.3 村官职业化发展\\n\\n六、大学生村官管理的政策建议\\n6.1 加强对大学生村官的培训和管理\\n6.2 完善大学生村官管理制度\\n6.3 提高大学生村官的待遇和福利\\n\\n七、结论与展望”,请把其中的小标题“3.3 完善村官管理制度”的内容补充完整,补充内容字数在1500字左右\n答:\n"
]
# prompts = [
# "生成论文小标题内容#\n问:论文题目是《我国基层税务机关的税务风险管理研究——以Z市为例》,目录是“ 一、绪论\n1.1 研究背景和意义\n1.2 国内外相关研究综述\n1.3 研究内容和方法\n\n二、Z市税务风险管理现状分析\n2.1 Z市基层税务机关概况\n2.2 Z市税务风险管理现状\n2.3 Z市税务风险管理存在的问题\n\n三、Z市税务风险管理优化策略\n3.1 完善税务风险管理制度\n3.2 加强税务风险管理人员培训\n3.3 建立健全税务风险管理评估体系\n\n四、Z市税务风险管理效果评价\n4.1 数据收集和整理\n4.2 评价指标体系构建\n4.3 评价结果分析\n\n五、Z市税务风险管理改进建议\n5.1 改进策略和措施\n5.2 实施计划和步骤\n5.3 预期效果和成效评价\n\n六、结论与展望\n6.1 研究结论\n6.2 研究不足和展望\n”,请把其中的小标题“1.1 研究背景和意义”的内容补充完整,补充内容字数在1295字左右\n答:\n"
# ] * 50
# prompts = [
# "问:请列出张仲景的所有经方名称\n答:\n"
# ]
# "改写句子#\n问:改写这句话“它的外在影响因素是由环境和载荷功能影响的。”答:\n",
# "改写句子#\n问:改写这句话“它的外在影响因素是由环境和载荷功能影响的。”答:\n"
# ] * 100
prompts = []
with open("data/测试.txt") as f:
data = f.readlines()
for i in data:
data_dan = i.strip("\n")
if 10 < len(data_dan) < 250 :
prompts.append("改写句子#\n问:改写这句话“" + data_dan + "”答:\n")
print(prompts)
sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=0.9,stop="</s>", max_tokens=2048)
models_path = "/home/majiahui/project/models-llm/openbuddy-llama-7b-finetune"
@ -20,13 +32,23 @@ outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
zishu = 0
# t2 = time.time()
generated_text_list = [i for i in prompts]
# for i, output in enumerate(outputs):
# index = output.request_id
# generated_text = output.outputs[0].text
# generated_text_list[int(index)] = generated_text
for i,output in enumerate(outputs):
generated_text = output.outputs[0].text
index = output.request_id
zishu += len(generated_text)
print("================================================================================")
generated_text_list[int(index)] += "\n==============================\n" + generated_text
for i in generated_text_list:
print(i)
print("=================================================================================")
print(f"Generated text: {generated_text}")
print("\n@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n")
t2 = time.time()
time_cost = t2-t1

Loading…
Cancel
Save