import os

# Restrict this process to GPU 0. Set before torch/transformers are imported so
# CUDA initialises with only that device visible.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from flask import Flask, jsonify
from flask import request
import requests
import redis
import uuid
import json
from threading import Thread
import time
import re
import logging
from config_llama_api import Config
import numpy as np
import math
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
import torch

config = Config()

# Base causal LM, loaded once at import time in float16.
model_path = '/home/majiahui/models-LLM/openbuddy-llama-7b-finetune'
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

# Fine-tuned weights that are overlaid on the base model's state dict.
WEIGHTS_NAME = "adapter_model.bin"
checkpoint_dir = "/home/majiahui/project2/LLaMA-Efficient-Tuning/path_to_sft_checkpoint_paper_prompt_freeze_checkpoint_new_48000/checkpoint-16000"
weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
if not os.path.exists(weights_file):
    # Explicit raise (not `assert`) so the check still runs under `python -O`.
    raise FileNotFoundError(
        f"Provided path ({checkpoint_dir}) does not contain the pretrained weights."
    )
model_state_dict = torch.load(weights_file, map_location="cuda")
model.load_state_dict(model_state_dict, strict=False)  # skip missing keys
model = model.cuda()

# Redis list used as the work queue between the HTTP handler and the worker.
redis_title = "redis_title"
pool = redis.ConnectionPool(
    host=config.reids_ip,
    port=config.reids_port,
    max_connections=50,
    db=config.reids_db,
)
# redis-py ignores `decode_responses` when an explicit connection_pool is given,
# so replies from this client are bytes; callers decode them explicitly.
redis_ = redis.Redis(connection_pool=pool)

app = Flask(__name__)
app.config["JSON_AS_ASCII"] = False
# Newer Flask versions read this setting from the JSON provider instead of the
# JSON_AS_ASCII config key; set it there too when the provider exists.
if hasattr(app, "json") and hasattr(app.json, "ensure_ascii"):
    app.json.ensure_ascii = False

# Prompt templates (mulu_prompt, first_title_prompt, small_title_prompt,
# references_prompt, thank_prompt, kaitibaogao_prompt, *_abstract_prompt,
# *_keyword_prompt) are read from `config`.


def normal_distribution(x):
    """Return the Gaussian probability density at *x*.

    The mean and standard deviation are ``config.u`` and ``config.sig``.
    Because it uses ``np.exp``, *x* may also be a NumPy array (element-wise).
    """
    y = np.exp(-(x - config.u) ** 2 / (2 * config.sig ** 2)) / (math.sqrt(2 * math.pi) * config.sig)
    return y


def request_chatglm(prompt, max_new_tokens=2000, do_sample=False, top_p=0.7, temperature=0.9):
    """Generate a completion for *prompt* with the loaded model.

    :param prompt: raw prompt text; it is tokenised without a chat template.
    :param max_new_tokens: upper bound on generated tokens.
    :param do_sample: False (default) means greedy decoding.
    :param top_p: nucleus-sampling threshold, used only when ``do_sample``.
    :param temperature: sampling temperature, used only when ``do_sample``.
    :return: the decoded continuation only (prompt tokens are stripped),
        with special tokens removed.
    """
    gen_kwargs = {"do_sample": do_sample}
    if do_sample:
        # top_p / temperature only affect sampling; passing them to greedy
        # decoding changes nothing except emitting transformers warnings.
        gen_kwargs.update(top_p=top_p, temperature=temperature)

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to('cuda')
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            **gen_kwargs,
        )
    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_len = len(input_ids[0])
    new_tokens = output_ids.tolist()[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
def chat_kaitibaogao(main_parameter):
    """Generate the research-proposal section (开题报告) for a paper.

    :param main_parameter: sequence whose first item is the paper title.
    :return: model output for ``config.kaitibaogao_prompt``.
    """
    return request_chatglm(config.kaitibaogao_prompt.format(main_parameter[0]))


def chat_abstract_keyword(main_parameter):
    """Generate Chinese/English abstracts and keywords for a paper.

    The generations are chained: the English abstract and the Chinese keywords
    are derived from the Chinese abstract, and the English keywords from the
    Chinese keywords.

    :param main_parameter: sequence whose first item is the paper title.
    :return: dict whose keys match the corresponding fields of the final
        paper dict built by ``main_prrcess``.
    """
    title = main_parameter[0]
    chinese_abstract = request_chatglm(config.chinese_abstract_prompt.format(title))
    english_abstract = request_chatglm(config.english_abstract_prompt.format(chinese_abstract))
    chinese_keyword = request_chatglm(config.chinese_keyword_prompt.format(chinese_abstract))
    english_keyword = request_chatglm(config.english_keyword_prompt.format(chinese_keyword))
    return {
        "中文摘要": chinese_abstract,
        "英文摘要": english_abstract,
        "中文关键词": chinese_keyword,
        "英文关键词": english_keyword,
    }


def chat_content(main_parameter):
    """Produce the body text for one outline entry.

    :param main_parameter: ``[content_index, title, mulu, subtitle, prompt,
        word_count]``; ``prompt`` is a template formatted with
        ``(title, mulu, subtitle, word_count)``.
    :return: for a chapter heading marked with a leading ``"@@"``, the heading
        text without the marker (no model call). Otherwise the model output,
        prefixed with ``subtitle`` and a newline when the output does not
        already contain it.
    """
    title, mulu, subtitle, prompt, word_count = main_parameter[1:6]
    full_prompt = prompt.format(title, mulu, subtitle, word_count)
    if subtitle[:2] == "@@":
        response = subtitle[2:]
    else:
        response = request_chatglm(full_prompt)
        if subtitle not in response:
            response = subtitle + "\n" + response
    print(full_prompt, response)
    return response


def chat_thanks(main_parameter):
    """Generate the acknowledgements (致谢) section.

    :param main_parameter: ``[title, prompt]``; ``prompt`` is formatted with the title.
    :return: model output.
    """
    title, prompt = main_parameter[0], main_parameter[1]
    return request_chatglm(prompt.format(title))


def chat_references(main_parameter):
    """Generate the references (参考文献) section.

    :param main_parameter: ``[title, mulu, prompt]``; ``prompt`` is formatted
        with ``(title, mulu)``.
    :return: model output.
    """
    title, mulu, prompt = main_parameter[0], main_parameter[1], main_parameter[2]
    return request_chatglm(prompt.format(title, mulu))


def small_title_tesk(small_title):
    """Execute one generation sub-task.

    :param small_title: task dict with at least ``"task_type"`` and
        ``"main_parameter"``.
    :return: ``(result, task_type)``. Only ``"paper_content"`` tasks currently
        call the model. ``"kaitibaogao"``, ``"chat_abstract"``,
        ``"thanks_task"``, ``"references_task"`` and unknown types are
        disabled and yield ``""``; their ``chat_*`` generators are kept above
        for re-enabling.
    """
    task_type = small_title["task_type"]
    main_parameter = small_title["main_parameter"]
    if task_type == "paper_content":
        result = chat_content(main_parameter)
    else:
        result = ""
    print(result, task_type, main_parameter)
    return result, task_type


# Outline-entry levels as classified by config.pantten_biaoti.
_LEVEL_1 = "一级标题"
_LEVEL_2 = "二级标题"
# Word budget for a chapter heading that is generated on its own (first chapter
# without sub-sections, or a trailing chapter heading).
_STANDALONE_CHAPTER_WORDS = 800
# Seconds a finished result is kept in Redis.
_RESULT_TTL_SECONDS = 28800


def _requeue_title(title, query_id):
    """Push *title* back to the front of the work queue for regeneration.

    The message has the same ``{"id", "text"}`` shape that ``handle_query``
    enqueues and ``classify`` reads, so the waiting HTTP request still gets its
    answer under the original id.

    NOTE(review): generation is greedy by default, so a retry may well produce
    the same rejected outline again; there is no retry cap.
    """
    redis_.lpush(redis_title, json.dumps({"id": query_id, "text": title}, ensure_ascii=False))
    redis_.persist(redis_title)


def _parse_outline(mulu):
    """Split the generated outline into ``(line, level)`` pairs, skipping blank lines."""
    entries = []
    for line in mulu.split("\n"):
        line = line.strip()
        if not line:
            continue
        level = _LEVEL_1 if re.findall(config.pantten_biaoti, line) else _LEVEL_2
        entries.append((line, level))
    return entries


def _outline_is_usable(entries):
    """Return True if the outline starts with a chapter and is not flat.

    "Flat" means every one of the first (up to) three entries is a level-1
    heading, i.e. no level-2 heading appears near the start.
    """
    head = [level for _, level in entries[:3]]
    return bool(head) and head[0] == _LEVEL_1 and not all(level == _LEVEL_1 for level in head)


def _drop_back_matter(entries):
    """Remove references / thanks / excursus lines from the last five entries."""
    keywords = (config.references, config.thanks, config.excursus)
    split = max(len(entries) - 5, 0)
    tail = [entry for entry in entries[split:] if not any(k in entry[0] for k in keywords)]
    return entries[:split] + tail


def _group_chapters(entries):
    """Group entries into chapters; chapter titles get the ``"@@"`` marker.

    Assumes ``entries[0]`` is a level-1 heading.
    """
    chapters = []
    for text, level in entries:
        if level == _LEVEL_1:
            chapters.append({"title": "@@" + text, "small_title": [], "word_count": 0})
        else:
            chapters[-1]["small_title"].append(text)
    return chapters


def _allocate_word_counts(num_chapters):
    """Split ``config.paper_word_count`` across chapters using a Gaussian weight.

    Chapter *k* is weighted by ``normal_distribution`` at ``k * step``, where
    ``step = config.zong_gradient / num_chapters``.
    """
    step = config.zong_gradient / num_chapters
    xs = [0]
    for _ in range(num_chapters - 1):
        xs.append(xs[-1] + step)
    weights = [normal_distribution(x) for x in xs]
    scale = config.paper_word_count / sum(weights)
    return [scale * weight for weight in weights]


def _flatten_sections(chapters):
    """Flatten chapters into ``[text, word_count]`` rows.

    Each chapter heading gets 0 words. The chapter's budget is split evenly
    (truncated to int) across its sub-sections.
    """
    sections = []
    for chapter in chapters:
        sections.append([chapter["title"], 0])
        subtitles = chapter["small_title"]
        for subtitle in subtitles:
            sections.append([subtitle, int(chapter["word_count"] / len(subtitles))])
    return sections


def main_prrcess(title, query_id=None):
    """Generate a full paper for *title*.

    :param title: paper title.
    :param query_id: id of the queued request. It is used when the generated
        outline is rejected and the title is pushed back onto the queue.
    :return: the paper dict, or ``None`` when the outline was rejected and
        the title was re-queued.
    """
    mulu = request_chatglm(config.mulu_prompt.format(title))
    entries = _parse_outline(mulu)
    if not _outline_is_usable(entries):
        _requeue_title(title, query_id)
        return None

    entries = _drop_back_matter(entries)
    if not entries or entries[0][1] != _LEVEL_1:
        # Removing back matter left nothing to group under a chapter.
        _requeue_title(title, query_id)
        return None

    chapters = _group_chapters(entries)
    for chapter, count in zip(chapters, _allocate_word_counts(len(chapters))):
        chapter["word_count"] = count
    print(chapters)
    print(len(chapters))

    sections = _flatten_sections(chapters)

    small_task_list = [
        {"task_type": "kaitibaogao", "uuid": query_id, "main_parameter": [title]},
        {"task_type": "chat_abstract", "uuid": query_id, "main_parameter": [title]},
    ]

    last_index = len(sections) - 1
    for content_index, (subtitle, word_count) in enumerate(sections):
        prompt = config.small_title_prompt
        # First chapter has no sub-sections: generate it as a whole chapter.
        if (content_index == 0 and last_index >= 1
                and sections[1][0][:2] == "@@" and subtitle[:2] == "@@"):
            subtitle, prompt, word_count = subtitle[2:], config.first_title_prompt, _STANDALONE_CHAPTER_WORDS
        # Trailing chapter heading with no sub-sections: same treatment.
        if content_index == last_index and subtitle[:2] == "@@":
            subtitle, prompt, word_count = subtitle[2:], config.first_title_prompt, _STANDALONE_CHAPTER_WORDS
        print("请求的所有参数", content_index, title, subtitle, prompt, word_count)
        small_task_list.append({
            "task_type": "paper_content",
            "uuid": query_id,
            "main_parameter": [content_index, title, mulu, subtitle, prompt, word_count],
        })

    small_task_list.append({
        "task_type": "thanks_task",
        "uuid": query_id,
        "main_parameter": [title, config.thank_prompt],
    })
    small_task_list.append({
        "task_type": "references_task",
        "uuid": query_id,
        "main_parameter": [title, mulu, config.references_prompt],
    })

    res = {
        "num_small_task": len(small_task_list),
        "tasking_num": 0,
        "标题": title,
        "目录": mulu,
        "开题报告": "",
        "任务书": "",
        "中文摘要": "",
        "英文摘要": "",
        "中文关键词": "",
        "英文关键词": "",
        "正文": "",
        "致谢": "",
        "参考文献": "",
        "table_of_contents": [""] * len(sections),
    }

    for small_task in small_task_list:
        result, task_type = small_title_tesk(small_task)
        if task_type == "kaitibaogao":
            res["开题报告"] = result
        elif task_type == "chat_abstract":
            # A dict when the abstract task is enabled, "" while it is disabled.
            if result:
                res.update(result)
        elif task_type == "paper_content":
            res["table_of_contents"][small_task["main_parameter"][0]] = result
        elif task_type == "thanks_task":
            res["致谢"] = result
        elif task_type == "references_task":
            res["参考文献"] = result
    return res


def _publish_paper(query_id, response):
    """Write the paper JSON, render .docx files with the Java tool, publish the result."""
    import subprocess

    return_text = str({"texts": response, "probabilities": None, "status_code": 200})
    uuid_path = os.path.join(config.project_data_txt_path, query_id)
    os.makedirs(uuid_path, exist_ok=True)
    paper_content_path = os.path.join(uuid_path, "paper_content.json")
    print(query_id)
    with open(paper_content_path, "w") as outfile:
        json.dump(response, outfile)

    save_word_paper = os.path.join(uuid_path, "paper.docx")
    save_word_paper_start = os.path.join(uuid_path, "paper_start.docx")
    # Argument list instead of a shell string: no quoting/injection issues.
    subprocess.run(
        [
            "java", "-Dfile.encoding=UTF-8", "-jar",
            "/home/majiahui/projert/chatglm/aiXieZuoPro.jar",
            paper_content_path, save_word_paper, save_word_paper_start,
        ],
        check=False,
    )
    redis_.set(query_id, return_text, ex=_RESULT_TTL_SECONDS)


def classify():
    """Background worker: pop queued titles, generate papers, publish results.

    Failures are logged and published as an error result so that a single
    bad request neither kills this thread nor leaves its client waiting.
    """
    while True:
        if redis_.llen(redis_title) == 0:
            time.sleep(3)
            continue
        raw = redis_.lpop(redis_title)
        if raw is None:
            continue
        try:
            query = json.loads(raw.decode('UTF-8'))
            query_id = query['id']
            texts = query["text"]
        except (ValueError, KeyError, TypeError):
            logging.exception("Dropping malformed queue item: %r", raw)
            continue

        try:
            response = main_prrcess(texts, query_id)
            if response is None:
                # Outline rejected; the title was re-queued under the same id.
                continue
            print("res", response)
            _publish_paper(query_id, response)
        except Exception:
            logging.exception("Paper generation failed for %s", query_id)
            redis_.set(
                query_id,
                str({"texts": None, "probabilities": None, "status_code": 500}),
                ex=_RESULT_TTL_SECONDS,
            )


@app.route("/predict", methods=["POST"])
def handle_query():
    """Enqueue a paper request and block until the worker publishes its result."""
    print(request.remote_addr)
    payload = request.get_json(silent=True)
    texts = payload.get("texts") if isinstance(payload, dict) else None
    if texts is None:
        return jsonify({"texts": "输入了空值", "probabilities": None, "status_code": 402})

    id_ = str(uuid.uuid1())  # unique id for this request
    redis_.rpush(redis_title, json.dumps({'id': id_, 'text': texts}))
    # NOTE(review): no timeout; the handler waits as long as generation takes.
    while True:
        result = redis_.get(id_)
        if result is not None:
            return jsonify({'code': "200", 'data': result.decode('UTF-8')})
        time.sleep(1)


t = Thread(target=classify)
t.start()

if __name__ == "__main__":
    fh = logging.FileHandler(mode='a', encoding='utf-8', filename='chitchat.log')
    logging.basicConfig(
        handlers=[fh],
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
    )
    app.run(host="0.0.0.0", port=15000, threaded=True, debug=False)