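"""Flask service that writes a full Chinese academic paper from a single title.

POST /predict enqueues the title in Redis; the background classify() thread pops
it, has main_process() build a table of contents and per-section word budgets,
fills every section through a local vLLM model, exports the result to .docx via
an external Java tool, and caches the JSON response in Redis under the query id.
"""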
import os

# Must be set before torch / vLLM are imported so they see the intended GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from flask import Flask, jsonify
from flask import request
# from linshi import autotitle
import requests
import redis
import uuid
import json
from threading import Thread
import time
import re
import logging
from config_llama_api import Config
import numpy as np
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch
from vllm import LLM, SamplingParams

config = Config()

model_path = '/home/majiahui/models-LLM/openbuddy-llama-7b-finetune-v3'
# model_path = '/home/majiahui/models-LLM/openbuddy-openllama-7b-v5-fp16'
# model_path = '/home/majiahui/models-LLM/baichuan-vicuna-chinese-7b'
# model_path = '/home/majiahui/models-LLM/openbuddy-llama-7b-v1.4-fp16'

sampling_params = SamplingParams(temperature=0.95, top_p=0.7, presence_penalty=0.9, stop="</s>")
models_path = model_path
llm = LLM(model=models_path)

# WEIGHTS_NAME = "adapter_model.bin"
# checkpoint_dir = "/home/majiahui/project2/LLaMA-Efficient-Tuning/path_to_sft_checkpoint_paper_prompt_freeze_checkpoint_new_48000/checkpoint-16000"
# weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
# assert os.path.exists(weights_file), f"Provided path ({checkpoint_dir}) does not contain the pretrained weights."
# model_state_dict = torch.load(weights_file, map_location="cuda")
# model.load_state_dict(model_state_dict, strict=False)  # skip missing keys
# model = model.cuda()

redis_title = "redis_title"
# NOTE: attribute names (including the "reids" spelling) follow config_llama_api.Config.
pool = redis.ConnectionPool(host=config.reids_ip, port=config.reids_port, max_connections=50, db=config.reids_db)
# NOTE: kwargs such as decode_responses are ignored when an explicit connection_pool
# is supplied, so responses come back as bytes and are decoded manually below.
redis_ = redis.Redis(connection_pool=pool, decode_responses=True)

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
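
# Legacy prompt templates, kept for reference; the live copies are read from
# config_llama_api.Config (config.mulu_prompt, config.first_title_prompt, ...):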
# mulu_prompt = "为论文题目“{}”生成目录,要求只有一级标题和二级标题,一级标题使用中文数字 例如一、xxx;二级标题使用阿拉伯数字 例如1.1 xxx;一级标题不少于7个;每个一级标题至少包含3个二级标题"
# first_title_prompt = "论文题目是“{}”,目录是“{}”,请把其中的大标题“{}”的内容补充完整,补充内容字数在{}字左右"
# small_title_prompt = "论文题目是“{}”,目录是“{}”,请把其中的小标题“{}”的内容补充完整,补充内容字数在{}字左右"
# references_prompt = "论文题目是“{}”,目录是“{}”,请为这篇论文生成15篇左右的参考文献,要求其中有中文参考文献不低于12篇,英文参考文献不低于2篇"
# thank_prompt = "请以“{}”为题写一篇论文的致谢"
# kaitibaogao_prompt = "请以《{}》为题目生成研究的主要的内容、背景、目的、意义,要求不少于1500字"
# chinese_abstract_prompt = "请以《{}》为题目生成论文摘要,要求生成的字数在600字左右"
# english_abstract_prompt = "请把“{}”这段文字翻译成英文"
# chinese_keyword_prompt = "请为“{}”这段论文摘要生成3-5个关键字,使用阿拉伯数字作为序号标注,例如“1.xxx \\n2.xxx \\n3.xxx \\n4.xxx \\n5.xxx \\n"
# english_keyword_prompt = "请把“{}”这几个关键字翻译成英文"


def normal_distribution(x):
    # Gaussian density with mean config.u and standard deviation config.sig;
    # used below to weight per-chapter word counts
    y = np.exp(-(x - config.u) ** 2 / (2 * config.sig ** 2)) / (math.sqrt(2 * math.pi) * config.sig)
    return y
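
# Word-count allocation (a sketch of the idea, not extra behavior): main_process
# samples this density at evenly spaced points x_0 = 0, x_{i+1} = x_i + zong_gradient / n
# for n chapters, then rescales the samples so they sum to config.paper_word_count:
#
#   weights       y_i = normal_distribution(x_i)
#   chapter words w_i = paper_word_count * y_i / sum(y)
#
# so chapters whose x_i falls near the mean config.u receive the largest budgets.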


def request_chatglm(prompt):
    # Run a single prompt through the in-process vLLM engine and return the text
    outputs = llm.generate([prompt], sampling_params)
    generated_text = outputs[0].outputs[0].text
    return generated_text


def chat_kaitibaogao(main_parameter):
    # Generate the opening report (开题报告) from the paper title
    response = request_chatglm(config.kaitibaogao_prompt.format(main_parameter[0]))

    return response


def chat_abstract_keyword(main_parameter):
    # Generate the Chinese abstract
    chinese_abstract = request_chatglm(config.chinese_abstract_prompt.format(main_parameter[0], main_parameter[1]))

    # Generate the English abstract
    english_abstract = request_chatglm(config.english_abstract_prompt.format(chinese_abstract))

    # Generate the Chinese keywords
    chinese_keyword = request_chatglm(config.chinese_keyword_prompt.format(chinese_abstract))

    # Generate the English keywords
    english_keyword = request_chatglm(config.english_keyword_prompt.format(chinese_keyword))

    paper_abstract_keyword = {
        "中文摘要": chinese_abstract,
        "英文摘要": english_abstract,
        "中文关键词": chinese_keyword,
        "英文关键词": english_keyword
    }

    return paper_abstract_keyword


def chat_content(main_parameter):
    '''
    Fill in one (sub)section of the paper body.
    :param main_parameter: [content_index, title, mulu, subtitle, prompt, word_count]
    :return: the generated section text, prefixed with its heading
    '''
    content_index = main_parameter[0]
    title = main_parameter[1]
    mulu = main_parameter[2]
    subtitle = main_parameter[3]
    prompt = main_parameter[4]
    word_count = main_parameter[5]

    if subtitle[:2] == "@@":
        # Chapter headings still prefixed with "@@" are emitted as-is, without generation
        response = subtitle[2:]
    else:
        response = request_chatglm(prompt.format(title, mulu, subtitle, word_count))
        if subtitle not in response:
            response = subtitle + "\n" + response

    print(prompt.format(title, mulu, subtitle, word_count), response)
    return response


def chat_thanks(main_parameter):
    '''
    Generate the acknowledgements (致谢) section.
    :param main_parameter: [title, prompt]
    :return: the generated text
    '''
    title = main_parameter[0]
    prompt = main_parameter[1]

    response = request_chatglm(prompt.format(title))

    return response


def chat_references(main_parameter):
    '''
    Generate the reference list (参考文献).
    :param main_parameter: [title, mulu, prompt]
    :return: the generated text
    '''
    title = main_parameter[0]
    mulu = main_parameter[1]
    prompt = main_parameter[2]

    response = request_chatglm(prompt.format(title, mulu))

    # Acquire a lock, read Redis, and store the result (not implemented here)
    return response
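
# Each small task handed to small_title_task is a dict of the form
#   {"task_type": one of "kaitibaogao", "chat_abstract", "paper_content",
#                 "thanks_task", "references_task",
#    "uuid": <query id>,
#    "main_parameter": [...]}  # positional arguments for the matching chat_* function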
def small_title_task(small_title):
    '''
    Process one subtask, dispatching on its task_type.
    :return: (result, task_type)
    '''
    task_type = small_title["task_type"]
    main_parameter = small_title["main_parameter"]

    if task_type == "kaitibaogao":
        # result = chat_kaitibaogao(main_parameter)
        result = ""

    elif task_type == "chat_abstract":
        result = chat_abstract_keyword(main_parameter)

    elif task_type == "paper_content":
        result = chat_content(main_parameter)

    elif task_type == "thanks_task":
        # result = chat_thanks(main_parameter)
        result = ""

    elif task_type == "references_task":
        # result = chat_references(main_parameter)
        result = ""

    else:
        result = ""

    print(result, task_type, main_parameter)
    return result, task_type


def main_process(title, query_id):
    # Generate the table of contents (目录) for the title
    mulu = request_chatglm(config.mulu_prompt.format(title))
    mulu_list = mulu.split("\n")
    mulu_list = [i.strip() for i in mulu_list if i != ""]
    # mulu_str = "@".join(mulu_list)

    # Classify each line as a chapter heading (一级标题) or section heading (二级标题)
    mulu_list_bool = []
    for i in mulu_list:
        result_biaoti_list = re.findall(config.pantten_biaoti, i)
        if result_biaoti_list != []:
            mulu_list_bool.append((i, "一级标题"))
        else:
            mulu_list_bool.append((i, "二级标题"))

    mulu_list_bool_part = mulu_list_bool[:3]

    # Sanity-check the generated structure; on failure, re-enqueue the task and bail out
    if mulu_list_bool_part[0][1] != "一级标题":
        redis_.lpush(redis_title, json.dumps({"id": query_id, "text": title}, ensure_ascii=False))  # push back onto the Redis queue
        redis_.persist(redis_title)
        return
    # A TOC whose first three entries are all chapter headings is also treated as malformed
    if mulu_list_bool_part[0][1] == mulu_list_bool_part[1][1] == mulu_list_bool_part[2][1] == "一级标题":
        redis_.lpush(redis_title, json.dumps({"id": query_id, "text": title}, ensure_ascii=False))  # push back onto the Redis queue
        redis_.persist(redis_title)
        return

    table_of_contents = []

    # Remove reference / acknowledgement / appendix entries (config.references,
    # config.thanks, config.excursus) from the tail of the TOC
    thanks_references_bool_table = mulu_list_bool[-5:]

    # thanks = "致谢"
    # references = "参考文献"
    for i in thanks_references_bool_table:
        if config.references in i[0]:
            mulu_list_bool.remove(i)
        if config.thanks in i[0]:
            mulu_list_bool.remove(i)
        if config.excursus in i[0]:
            mulu_list_bool.remove(i)

    title_key = ""
    # for i in mulu_list_bool:
    #     if i[1] == "一级标题":
    #         table_of_contents["@@" + i[0]] = []
    #         title_key = "@@" + i[0]
    #     else:
    #         table_of_contents[title_key].append(i[0])

    # Group section headings under their chapter heading
    for i in mulu_list_bool:
        if i[1] == "一级标题":
            paper_dan = {
                "title": "@@" + i[0],
                "small_title": [],
                "word_count": 0
            }
            table_of_contents.append(paper_dan)
        else:
            table_of_contents[-1]["small_title"].append(i[0])

    # Weight per-chapter word counts with the Gaussian defined above
    x_list = [0]
    y_list = [normal_distribution(0)]

    gradient = config.zong_gradient / len(table_of_contents)
    for i in range(len(table_of_contents) - 1):
        x_gradient = x_list[-1] + gradient
        x_list.append(x_gradient)
        y_list.append(normal_distribution(x_list[-1]))

    dan_gradient = config.paper_word_count / sum(y_list)

    for i in range(len(y_list)):
        table_of_contents[i]["word_count"] = dan_gradient * y_list[i]

    print(table_of_contents)
    print(len(table_of_contents))

    # Flatten to [heading, word_count] pairs; each chapter budget is split
    # evenly across its sections
    table_of_contents_new = []
    for dabiaoti_index in range(len(table_of_contents)):
        dabiaoti_dict = table_of_contents[dabiaoti_index]
        table_of_contents_new.append([dabiaoti_dict["title"], 0])
        for xiaobiaoti in dabiaoti_dict["small_title"]:
            table_of_contents_new.append(
                [xiaobiaoti, int(dabiaoti_dict["word_count"] / len(dabiaoti_dict["small_title"]))])

    # Per-request subtask queue; see the schema comment above small_title_task
    small_task_list = []

    kaitibaogao_task = {
        "task_type": "kaitibaogao",
        "uuid": query_id,
        "main_parameter": [title]
    }

    chat_abstract_task = {
        "task_type": "chat_abstract",
        "uuid": query_id,
        "main_parameter": [title, mulu]
    }
    small_task_list.append(kaitibaogao_task)
    small_task_list.append(chat_abstract_task)

    content_index = 0
    while True:
        if content_index == len(table_of_contents_new):
            break
        subtitle, word_count = table_of_contents_new[content_index]
        prompt = config.small_title_prompt
        # The first and last "@@" chapter headings are generated with the
        # chapter-level prompt and a fixed 800-word budget
        if (content_index == 0 and len(table_of_contents_new) > 1
                and table_of_contents_new[1][0][:2] == "@@" and subtitle[:2] == "@@"):
            subtitle, prompt, word_count = subtitle[2:], config.first_title_prompt, 800

        if content_index == len(table_of_contents_new) - 1 and subtitle[:2] == "@@":
            subtitle, prompt, word_count = subtitle[2:], config.first_title_prompt, 800

        print("All request parameters",
              content_index,
              title,
              subtitle,
              prompt,
              word_count)

        paper_content = {
            "task_type": "paper_content",
            "uuid": query_id,
            "main_parameter": [
                content_index,
                title,
                mulu,
                subtitle,
                prompt,
                word_count
            ]
        }

        small_task_list.append(paper_content)
        content_index += 1

    thanks_task = {
        "task_type": "thanks_task",
        "uuid": query_id,
        "main_parameter": [
            title,
            config.thank_prompt
        ]
    }

    references_task = {
        "task_type": "references_task",
        "uuid": query_id,
        "main_parameter": [
            title,
            mulu,
            config.references_prompt
        ]
    }

    small_task_list.append(thanks_task)
    small_task_list.append(references_task)

    res = {
        "num_small_task": len(small_task_list),
        "tasking_num": 0,
        "标题": title,
        "目录": mulu,
        "开题报告": "",
        "任务书": "",
        "中文摘要": "",
        "英文摘要": "",
        "中文关键词": "",
        "英文关键词": "",
        "正文": "",
        "致谢": "",
        "参考文献": "",
        "table_of_contents": [""] * len(table_of_contents_new)
    }

    # Run the subtasks sequentially and slot each result into the response dict
    for small_task in small_task_list:
        result, task_type = small_title_task(small_task)

        if task_type == "kaitibaogao":
            res["开题报告"] = result

        elif task_type == "chat_abstract":
            for i in result:
                res[i] = result[i]

        elif task_type == "paper_content":
            content_index = small_task["main_parameter"][0]
            res["table_of_contents"][content_index] = result

        elif task_type == "thanks_task":
            res["致谢"] = result

        elif task_type == "references_task":
            res["参考文献"] = result

    return res
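
# Background worker: queue items are JSON objects {"id": <uuid>, "text": <paper title>}
# pushed by handle_query below; the finished paper is stored back in Redis under the
# query id (28800 s TTL) and exported to .docx by an external Java tool.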
def classify():  # Invoke the model; set the maximum batch_size
    while True:
        if redis_.llen(redis_title) == 0:  # If the queue is empty, keep polling
            time.sleep(3)
            continue
        query = redis_.lpop(redis_title).decode('UTF-8')  # Pop the query text
        query = json.loads(query)

        query_id = query['id']
        texts = query["text"]

        response = main_process(texts, query_id)
        if response is None:  # TOC sanity check failed; the task was re-enqueued
            continue
        print("res", response)
        return_text = str({"texts": response, "probabilities": None, "status_code": 200})

        uuid_path = os.path.join(config.project_data_txt_path, query_id)

        os.makedirs(uuid_path, exist_ok=True)

        paper_content_path = os.path.join(uuid_path, "paper_content.json")
        print(query_id)
        with open(paper_content_path, "w") as outfile:
            json.dump(response, outfile)

        # Render the JSON into Word documents with the external converter
        save_word_paper = os.path.join(uuid_path, "paper.docx")
        save_word_paper_start = os.path.join(uuid_path, "paper_start.docx")
        os.system(
            "java -Dfile.encoding=UTF-8 -jar '/home/majiahui/projert/chatglm/aiXieZuoPro.jar' '{}' '{}' '{}'".format(
                paper_content_path,
                save_word_paper,
                save_word_paper_start))
        redis_.set(query_id, return_text, 28800)


@app.route("/predict", methods=["POST"])
def handle_query():
    print(request.remote_addr)
    texts = request.json.get("texts")
    if texts is None:
        return_text = {"texts": "输入了空值", "probabilities": None, "status_code": 402}
        return jsonify(return_text)

    id_ = str(uuid.uuid1())  # Generate a unique id for this query
    d = {'id': id_, 'text': texts}  # Bind the text to the query id

    redis_.rpush(redis_title, json.dumps(d))  # Push onto the Redis queue
    while True:
        result = redis_.get(id_)  # Poll for this query's result
        if result is not None:
            result_text = {'code': "200", 'data': result.decode('UTF-8')}
            break
        else:
            time.sleep(1)

    return jsonify(result_text)  # Return the result


# Start the background worker that drains the Redis queue
t = Thread(target=classify)
t.start()

if __name__ == "__main__":
    fh = logging.FileHandler(mode='a', encoding='utf-8', filename='chitchat.log')
    logging.basicConfig(
        handlers=[fh],
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
    )
    app.run(host="0.0.0.0", port=15000, threaded=True, debug=False)
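
# Example client call (a minimal sketch; host/port assume the app.run() settings above):
#
#   import requests
#   resp = requests.post("http://127.0.0.1:15000/predict", json={"texts": "论文题目"})
#   print(resp.json())  # {'code': '200', 'data': '...'} once generation completes
#
# Note that /predict blocks until the whole paper has been generated and cached in Redis.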