You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
64 lines
2.9 KiB
64 lines
2.9 KiB
import os
|
|
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
|
import time
|
|
from vllm import LLM, SamplingParams
|
|
|
|
# prompts = [
|
|
# "生成论文小标题内容#\n问:论文题目是《我国基层税务机关的税务风险管理研究——以Z市为例》,目录是“ 一、绪论\n1.1 研究背景和意义\n1.2 国内外相关研究综述\n1.3 研究内容和方法\n\n二、Z市税务风险管理现状分析\n2.1 Z市基层税务机关概况\n2.2 Z市税务风险管理现状\n2.3 Z市税务风险管理存在的问题\n\n三、Z市税务风险管理优化策略\n3.1 完善税务风险管理制度\n3.2 加强税务风险管理人员培训\n3.3 建立健全税务风险管理评估体系\n\n四、Z市税务风险管理效果评价\n4.1 数据收集和整理\n4.2 评价指标体系构建\n4.3 评价结果分析\n\n五、Z市税务风险管理改进建议\n5.1 改进策略和措施\n5.2 实施计划和步骤\n5.3 预期效果和成效评价\n\n六、结论与展望\n6.1 研究结论\n6.2 研究不足和展望\n”,请把其中的小标题“1.1 研究背景和意义”的内容补充完整,补充内容字数在1295字左右\n答:\n"
|
|
# ] * 50
|
|
|
|
# prompts = [
|
|
# "改写句子#\n问:改写这句话“它的外在影响因素是由环境和载荷功能影响的。”答:\n",
|
|
# "改写句子#\n问:改写这句话“它的外在影响因素是由环境和载荷功能影响的。”答:\n"
|
|
# ] * 100
|
|
|
|
prompts = []
|
|
with open("data/测试.txt") as f:
|
|
data = f.readlines()
|
|
for i in data:
|
|
data_dan = i.strip("\n")
|
|
if 10 < len(data_dan) < 250 :
|
|
prompts.append("改写句子#\n问:改写这句话“" + data_dan + "”答:\n")
|
|
|
|
print(prompts)
|
|
sampling_params = SamplingParams(temperature=0.95, top_p=0.7,presence_penalty=0.9,stop="</s>", max_tokens=2048)
|
|
|
|
models_path = "/home/majiahui/project/models-llm/openbuddy-llama-7b-finetune"
|
|
llm = LLM(model=models_path, tokenizer_mode="slow")
|
|
|
|
t1 = time.time()
|
|
outputs = llm.generate(prompts, sampling_params)
|
|
|
|
# Print the outputs.
|
|
zishu = 0
|
|
# t2 = time.time()
|
|
|
|
generated_text_list = [i for i in prompts]
|
|
# for i, output in enumerate(outputs):
|
|
# index = output.request_id
|
|
# generated_text = output.outputs[0].text
|
|
# generated_text_list[int(index)] = generated_text
|
|
for i,output in enumerate(outputs):
|
|
generated_text = output.outputs[0].text
|
|
index = output.request_id
|
|
zishu += len(generated_text)
|
|
generated_text_list[int(index)] += "\n==============================\n" + generated_text
|
|
|
|
|
|
|
|
for i in generated_text_list:
|
|
print(i)
|
|
print("\n@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@\n")
|
|
|
|
t2 = time.time()
|
|
time_cost = t2-t1
|
|
print(time_cost)
|
|
print("speed", zishu/time_cost)
|
|
#
|
|
zishu_one = zishu/time_cost
|
|
print(f"speed: {zishu_one} tokens/s")
|
|
# # from vllm import LLM
|
|
# #
|
|
# # llm = LLM(model="/home/majiahui/models-LLM/openbuddy-llama-7b-v1.4-fp16") # Name or path of your model
|
|
# # output = llm.generate("Hello, my name is")
|
|
# # print(output)
|
|
|