From 36049bbe77e3f456b2a9c50c4eb27d22e7b6f250 Mon Sep 17 00:00:00 2001
From: "majiahui@haimaqingfan.com" <majiahui@haimaqingfan.com>
Date: Mon, 17 Jul 2023 12:18:37 +0800
Subject: [PATCH] Training data integration
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 训练数据整合.py | 312 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 312 insertions(+)
 create mode 100644 训练数据整合.py

diff --git a/训练数据整合.py b/训练数据整合.py
new file mode 100644
index 0000000..db92f82
--- /dev/null
+++ b/训练数据整合.py
@@ -0,0 +1,312 @@
+
+import os
+import json
+import re
+import random
+
+
+patten = "目录是“(.*)”,请把其中的"
+# Matches first-level headings such as "一、绪论" or "1.方法"; anything that
+# does not match is treated as a second-level heading.
+pantten_biaoti = r'[1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][、.]\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
+thanks = "致谢"
+references = "参考文献"
+excursus = "附录"
+
+def ulit_mulu(mulu):
+
+    # Normalise a table of contents (TOC) into "first-level heading plus its
+    # sub-headings" blocks for use in prompt text; returns (False, "") for
+    # malformed TOCs.
+
+    table_of_contents = []
+    if "\\n" in mulu:
+        mulu_list = str(mulu).split("\\n")
+    elif "\n" in mulu:
+        mulu_list = str(mulu).split("\n")
+    else:
+        return False, ""
+    if mulu_list[0] == "目录":
+        mulu_list = mulu_list[1:]
+    mulu_list = [i.strip() for i in mulu_list if i != ""]
+
+    mulu_list_bool = []
+    for i in mulu_list:
+        result_biaoti_list = re.findall(pantten_biaoti, i)
+        if result_biaoti_list != []:
+            mulu_list_bool.append((i, "一级标题"))
+        else:
+            mulu_list_bool.append((i, "二级标题"))
+
+    # Guard against short TOCs before inspecting the first three entries.
+    if len(mulu_list_bool) < 3:
+        return False, ""
+    mulu_list_bool_part = mulu_list_bool[:3]
+
+    # The TOC must start with a first-level heading, and the first three
+    # entries must not all be first-level headings.
+    if mulu_list_bool_part[0][1] != "一级标题":
+        return False, ""
+    if mulu_list_bool_part[0][1] == mulu_list_bool_part[1][1] == mulu_list_bool_part[2][1] == "一级标题":
+        return False, ""
+
+    # Drop trailing references / acknowledgements / appendix entries.
+    thanks_references_bool_table = mulu_list_bool[-5:]
+
+    for i in thanks_references_bool_table:
+        try:
+            if references in i[0]:
+                mulu_list_bool.remove(i)
+            if thanks in i[0]:
+                mulu_list_bool.remove(i)
+            if excursus in i[0]:
+                mulu_list_bool.remove(i)
+        except ValueError:
+            print(thanks_references_bool_table)
+            return False, ""
+
+    for i in mulu_list_bool:
+        if i[1] == "一级标题":
+            paper_dan = {
+                "title": "@@" + i[0],
+                "small_title": [],
+                "word_count": 0
+            }
+            table_of_contents.append(paper_dan)
+        else:
+            table_of_contents[-1]["small_title"].append(i[0])
+
+    # Every first-level heading except the last must own 2-5 sub-headings.
+    is_contine = False
+    for big_title in table_of_contents[:-1]:
+        if len(big_title["small_title"]) < 2 or len(big_title["small_title"]) > 5:
+            is_contine = True
+            break
+    if is_contine:
+        return False, ""
+
+    table_of_contents_new = []
+    for dabiaoti_index in range(len(table_of_contents)):
+        dabiaoti_dict = table_of_contents[dabiaoti_index]
+        dan_str_list = [dabiaoti_dict["title"][2:]] + dabiaoti_dict["small_title"]
+        dan_str = "\n".join(dan_str_list)
+        table_of_contents_new.append(dan_str)
+
+    mulu_txt = "\n\n".join(table_of_contents_new)
+    mulu_txt = mulu_txt.replace("\n", "\\n")
+    return True, mulu_txt
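+
+# Illustrative example (hypothetical TOC, not taken from the data set):
+#   ulit_mulu("目录\n一、绪论\n1.1 研究背景\n1.2 研究意义\n二、相关研究\n2.1 国内研究\n2.2 国外研究\n三、总结")
+# returns:
+#   (True, "一、绪论\\n1.1 研究背景\\n1.2 研究意义\\n\\n二、相关研究\\n2.1 国内研究\\n2.2 国外研究\\n\\n三、总结")
+# A TOC is rejected as (False, "") when it does not start with a first-level
+# heading or a non-final chapter has fewer than 2 or more than 5 sub-headings.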
"生成6点本篇论文应完成的主要内容#"}, + "title_mulu_references_prompt_data.txt": {"num_token": 1, "prompt": "生成参考文献#"}, + "title_mulu_small_title_prompt_shuffle_data_new.txt": {"num_token": -1, "prompt": "生成论文小标题内容#"}, + "title_mulu_zhaiyao_data.txt": {"num_token": -1, "prompt": "生成论文摘要#"}, + "zhaiyao_chinese_keyword_prompt_data.txt": {"num_token": -1, "prompt": "生成关键字#"}, + "zhaiyao_fanyi_prompt_data.txt": {"num_token": -1, "prompt": "翻译摘要#"}, + "chinese_keyword_en_prompt_data.txt": {"num_token": -1, "prompt": "翻译关键词#"}, + "title_hexin_beijing_prompt_data.txt": {"num_token": -1, "prompt": "生成论文来源的背景#"}, + "title_hexin_jianjie_prompt_data.txt": {"num_token": -1, "prompt": "生成研究内容#"}, + "title_hexin_mulu_prompt_data.txt": {"num_token": -1, "prompt": "生成目录#"}, + "title_hexin_yanjiubeijingyiyi_prompt_data.txt": {"num_token": -1, "prompt": "生成课题的研究背景和意义#"}, + "title_hexin_zongjie_prompt_data.txt": {"num_token": -1, "prompt": "生成论文简短总结#"}, + "title_hexin_zongshu_prompt_data.txt": {"num_token": -1, "prompt": "生成课题的国内外研究状况综述#"} +} + +re_file = { + "title_beijing_prompt_data.txt": "\n以“", + "title_jianjie_prompt_data.txt": "\n请帮我生成《", + "title_yanjiubeijingyiyi_prompt_data.txt": "\n请分别写出以《", + "title_zongjie_prompt_data.txt": "\n以“", + "title_zongshu_prompt_data.txt": "\n请写出以《", + "jianjie_task_book_prompt_data.txt": "\n\"请根据题目为《", + "title_mulu_references_prompt_data.txt": "\n\"论文题目是“", + "zhaiyao_chinese_keyword_prompt_data.txt": "\n\"请为“", + "zhaiyao_fanyi_prompt_data.txt": "\n\"请把“", + "chinese_keyword_en_prompt_data.txt": "\n\"请把“", + "title_mulu_zhaiyao_data.txt": "@@@@@@@@@@@@@@@@@@", + "title_mulu_small_title_prompt_shuffle_data_new.txt": "@@@@@@@@@@@@@@@@@@", + "title_hexin_beijing_prompt_data.txt": "@@@@@@@@@@@@@@@@@@@@@@@", + "title_hexin_jianjie_prompt_data.txt": "@@@@@@@@@@@@@@@@@@@@@@@", + "title_hexin_mulu_prompt_data.txt": "@@@@@@@@@@@@@@@@@@@@@@@", + "title_hexin_yanjiubeijingyiyi_prompt_data.txt": "@@@@@@@@@@@@@@@@@@@@@@@", + "title_hexin_zongjie_prompt_data.txt": "@@@@@@@@@@@@@@@@@@@@@@@", + "title_hexin_zongshu_prompt_data.txt": "@@@@@@@@@@@@@@@@@@@@@@@" +} + +split_teshu = [ + "title_mulu_zhaiyao_data.txt", + "title_mulu_small_title_prompt_shuffle_data_new.txt", + "title_hexin_beijing_prompt_data.txt", + "title_hexin_jianjie_prompt_data.txt", + "title_hexin_mulu_prompt_data.txt", + "title_hexin_yanjiubeijingyiyi_prompt_data.txt", + "title_hexin_zongjie_prompt_data.txt", + "title_hexin_zongshu_prompt_data.txt" + ] + +patten_mulu = { + "title_mulu_references_prompt_data.txt": "目录是“(.*)”,请为这篇论文生成15篇左右的参考文献", + "title_mulu_small_title_prompt_shuffle_data_new.txt": "目录是“(.*)”,请把其中的小标题", + "title_mulu_zhaiyao_data.txt": "目录是“(.*)”,生成论文摘要" +} + +shuminghao = { + "title_beijing_prompt_data.txt": [("以“","以《"),("”为论文题目","》为论文题目")], + "title_hexin_beijing_prompt_data.txt": [("以“","以《"),("”为论文题目","》为论文题目")], + "title_hexin_mulu_prompt_data.txt": [("论文题目为“","论文题目为《"),("”,以“","》,以“")], # 论文题目为“关于《金子美玲童谣全集》中起点文本向目标文本转换的研究”,以“ + "title_hexin_zongjie_prompt_data.txt":[("以“","以《"),("”为论文题目","》为论文题目")], # 以“面向海量微服务的高可用服务注册中心的研究与实现”为论文题目 + "title_mulu_small_title_prompt_shuffle_data_new.txt": [("论文题目是“", "论文题目是《"), ("”,目录是", "》,目录是")], # "论文题目是“八十年代审美非功利思潮研究”,目录是 + "title_mulu_zhaiyao_data.txt": [("论文题目是“", "论文题目是《"), ("”,目录是", "》,目录是")], # "论文题目是“网络媒体报道对房地产市场的影响研究”,目录是 + "title_zongjie_prompt_data.txt": [("以“", "以《"), ("”为论文题目", "》为论文题目")] # 以“网中人:论哈金《等待》中的伦理困境”为论文题目 +} + + +path_list = [] +file = "./data/paper_prompt_title_3" +for root, dirs, files in os.walk(file): + for file in files: + path = 
+
+
+text_list_new = []
+
+tongji = {}  # per-task record counts
+
+
+for path in path_list:
+    patten_mulu_bool = False
+    shuminghao_bool = False
+    new_data_list = []
+    patten_mulu_patten = ""
+    shuminghao_list = ""
+
+    # os.path.basename works on any OS; path.split("\\") only worked on Windows.
+    task_name = os.path.basename(path)
+    if task_name in re_file:
+        spilt_dan = re_file[task_name]
+    else:
+        continue
+
+    if task_name in patten_mulu:
+        patten_mulu_bool = True
+        patten_mulu_patten = patten_mulu[task_name]
+
+    if task_name in shuminghao:
+        shuminghao_bool = True
+        shuminghao_list = shuminghao[task_name]
+
+    train_data_amount_dict = lable_data_amount[task_name]
+    train_data_amount = train_data_amount_dict["num_token"]
+
+    with open(path, encoding="utf-8") as f:
+        text = f.read()
+    text_list = text.split(spilt_dan)
+    index = 1
+
+    if train_data_amount == -1:
+        train_data_amount = len(text_list) - 1
+    while True:
+        if index >= train_data_amount:
+            break
+        data_dan = text_list[index]
+        if "**************" in data_dan:
+            if task_name not in split_teshu:
+                # Re-attach the split marker, minus its leading "\n".
+                data_dan = spilt_dan[1:] + data_dan
+
+            if patten_mulu_bool:
+                content, summary = data_dan.split("**************")
+                result_biaoti_list = re.findall(patten_mulu_patten, content)
+                try:
+                    mulu = str(result_biaoti_list[0])
+                except IndexError:
+                    index += 1
+                    continue
+
+                bool_, mulu_new = ulit_mulu(mulu)
+                if bool_:
+                    content = content.replace(mulu, mulu_new)
+                    data_dan = "**************".join([content, summary])
+                    data_dan = data_dan.replace("\\n", "\n").replace("\n", "\\n")
+                else:
+                    index += 1
+                    continue
+            else:
+                data_dan = data_dan.replace("\\n", "\n").replace("\n", "\\n")
+
+            if shuminghao_bool:
+                content, summary = data_dan.split("**************")
+                for rep in shuminghao_list:
+                    content = content.replace(rep[0], rep[1])
+                data_dan = "**************".join([content, summary])
+
+            new_data_list.append(data_dan)
+            # Keep (record, prompt) pairs for the merged training file; the
+            # original left this append commented out, which made train_list
+            # below always empty.
+            text_list_new.append((data_dan, train_data_amount_dict["prompt"]))
+            index += 1
+
+            if task_name not in tongji:
+                tongji[task_name] = 1
+            else:
+                tongji[task_name] += 1
+        else:
+            # Malformed record without the "**************" separator: log it
+            # and skip ahead.
+            index += 4
+            print(data_dan)
+
+    print(task_name, "\n")
+    if new_data_list != []:
+        print(new_data_list[0])
+    os.makedirs("./data/训练数据集合", exist_ok=True)
+    with open(f"./data/训练数据集合/{task_name}", mode="w", encoding="utf-8") as f:
+        for i in new_data_list:
+            f.write(i)
+            f.write("\n")
+
+
+# Merge all tasks into one shuffled training set.
+train_list = []
+for text, prompt in text_list_new:
+    content, summary = text.split("**************")
+    train_list.append(
+        {"query": str(content).strip("\"").strip("\n").strip("\""), "response": str(summary), "prompt": prompt}
+    )
+
+random.shuffle(train_list)
+
+
+for i in tongji:
+    print(i, tongji[i])
+with open("./data/chatglm_paper_data_2_prompt.txt", mode="w", encoding="utf-8") as f:
+    for i in train_list:
+        f.write(json.dumps(i, ensure_ascii=False))
+        f.write("\n")
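+
+# Example output line (illustrative values) in chatglm_paper_data_2_prompt.txt,
+# one JSON object per line:
+#   {"query": "以《XXX系统的设计与实现》为论文题目,写一段题目来源的背景内容", "response": "近年来,随着...", "prompt": "生成论文来源的背景#"}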