import os
import json
import re
import math
import random

import numpy as np
from tqdm import tqdm

path_output = "chatgpt_data_v1"

# The Chinese literals below are part of the raw data format and must stay
# as-is; the comments give English glosses.
patten = "目录是“(.*)”,请把其中的"  # 'the table of contents is "...", please take the ... in it'
# numbered headings such as "一、绪论" (Arabic, Chinese, or Roman numerals)
pantten_biaoti = r'[1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][、.]\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
thanks = "致谢"  # acknowledgements
references = "参考文献"  # references
excursus = "附录"  # appendix
# numbered "总结" (summary) / "结论" (conclusion) headings
RE_CHINA_NUMS = "[一二三四五六七八九].?.?总结|[1-9].?.?总结|[一二三四五六七八九].?.?结论|[1-9].?.?结论"
# 'please complete the content of the sub-heading / main heading "..."'
RE_CHINA_TITLE = "请把其中的小标题“(.*?)”的内容补充完整|请把其中的大标题“(.*?)”的内容补充完整"
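
# Illustrative matches for the heading regexes (based on the observed data format):
#   re.findall(pantten_biaoti, "一、绪论 二、相关工作") -> ['一、绪论', '二、相关工作']
#   re.findall(RE_CHINA_NUMS, "五、总结") -> ['五、总结']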

# Length-bucket counters (declared here but never updated below).
data_tongji = {
    "0-600": 0,
    "600-1500": 0,
    "1500-": 0,
}

# Count of records dropped by the summary/conclusion drift filter below.
jishu = 0

def is_contains_chinese(strs):
    """Return True if the string contains at least one CJK character."""
    for _char in strs:
        if '\u4e00' <= _char <= '\u9fa5':
            return True
    return False

# pantten_second_biaoti = '[2二ⅡⅠ][、.]\s{0,}?[\u4e00-\u9fa5]+'
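
# Illustrative behaviour:
#   is_contains_chinese("abstract") -> False
#   is_contains_chinese("摘要") -> True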

# Maps each raw task file to its instruction prompt. English glosses of the
# prompts are inline; "num_token" is read below but never used as a cap.
lable_data_amount = {
    "title_beijing_prompt_data.txt": {"num_token": -1, "prompt": "生成论文来源的背景#"},  # generate the paper's background
    "title_jianjie_prompt_data.txt": {"num_token": -1, "prompt": "生成研究内容#"},  # generate the research content
    "title_mulu_prompt_data.txt": {"num_token": -1, "prompt": "生成目录#"},  # generate the table of contents
    "title_yanjiubeijingyiyi_prompt_data.txt": {"num_token": -1, "prompt": "生成课题的研究背景和意义#"},  # research background and significance
    "title_zhixie_prompt_data.txt": {"num_token": -1, "prompt": "生成致谢#"},  # generate the acknowledgements
    "title_zongjie_prompt_data.txt": {"num_token": -1, "prompt": "生成论文简短总结#"},  # generate a short summary
    "title_zongshu_prompt_data.txt": {"num_token": -1, "prompt": "生成课题的国内外研究状况综述#"},  # domestic and international literature review
    "sencend_task_book_prompt_data.txt": {"num_token": -1, "prompt": "生成6点本篇论文应完成的主要内容#"},  # six main tasks of the paper
    "sencend_references_prompt_data.txt": {"num_token": -1, "prompt": "生成参考文献#"},  # generate references
    "sencend_small_title_prompt_shuffle_data.txt": {"num_token": -1, "prompt": "生成论文小标题内容#"},  # generate sub-heading content
    "sencend_zhaiyao_prompt_data.txt": {"num_token": -1, "prompt": "生成论文摘要#"},  # generate the abstract
    "third_zhaiyao_chinese_keyword_prompt_data.txt": {"num_token": -1, "prompt": "生成关键字#"},  # generate keywords
    "third_zhaiyao_fanyi_prompt_data.txt": {"num_token": -1, "prompt": "翻译摘要#"},  # translate the abstract
    "fourth_chinese_keyword_en_prompt_data.txt": {"num_token": -1, "prompt": "翻译关键词#"},  # translate the keywords
    "title_hexin_beijing_prompt_data.txt": {"num_token": -1, "prompt": "生成论文来源的背景#"},
    "title_hexin_jianjie_prompt_data.txt": {"num_token": -1, "prompt": "生成研究内容#"},
    "title_hexin_mulu_prompt_data.txt": {"num_token": -1, "prompt": "生成目录#"},
    "title_hexin_yanjiubeijingyiyi_prompt_data.txt": {"num_token": -1, "prompt": "生成课题的研究背景和意义#"},
    "title_hexin_zongjie_prompt_data.txt": {"num_token": -1, "prompt": "生成论文简短总结#"},
    "title_hexin_zongshu_prompt_data.txt": {"num_token": -1, "prompt": "生成课题的国内外研究状况综述#"},
    "title_hexin_zhixie_prompt_data.txt": {"num_token": -1, "prompt": "生成致谢#"}
}
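
# The "hexin" ("核心", core) file family reuses the same prompts as the base
# tasks, so both map onto identical instructions in the output data.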

# Per-file patterns for tasks whose input embeds a table of contents ("目录").
patten_mulu = {
    "title_mulu_references_prompt_data.txt": "目录是“(.*)”,请为这篇论文生成15篇左右的参考文献",  # generate ~15 references
    "title_mulu_small_title_prompt_shuffle_data_new.txt": "目录是“(.*)”,请把其中的小标题",  # complete the sub-headings
    "title_mulu_zhaiyao_data.txt": "目录是“(.*)”,生成论文摘要"  # generate the abstract
}
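
# Note: the matched pattern is copied into patten_mulu_patten in the main loop
# but never applied; the lookup currently only flags the task via patten_mulu_bool.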

# Collect every raw task file under the prompt-data directories.
path_list = []
for sub_dir in [
    "paper_prompt_title_1",
    "paper_prompt_title_1_1",
    "paper_prompt_title_1_1_1",
    "paper_prompt_title_1_1_1_1",
    "paper_prompt_title_1_hexin",
]:
    file = "./data/{}/{}".format(path_output, sub_dir)
    for root, dirs, files in os.walk(file):
        for file_name in files:
            path_list.append(os.path.join(root, file_name))
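
# os.walk() yields nothing for a directory that does not exist, so a missing
# sub-directory simply contributes no files rather than raising an error.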

text_list_new = []
tongji = {}  # per-task record counts ("tongji" = statistics)

if __name__ == '__main__':

    data_list = []
    train_list = []
    new_data_list = []
    for path in path_list:
        patten_mulu_bool = False
        shuminghao_bool = False
        patten_mulu_patten = ""
        task_name = os.path.basename(path)  # portable across "/" and "\\" separators

        if task_name in patten_mulu:
            patten_mulu_bool = True
            patten_mulu_patten = patten_mulu[task_name]

        train_data_amount_dict = lable_data_amount[task_name]
        train_data_amount = train_data_amount_dict["num_token"]
        prompt = train_data_amount_dict["prompt"]

        with open(path, encoding="utf-8") as f:
            text = f.read()
        # Records in the raw files are separated by a run of 20 "@" characters.
        text_list = text.split("@" * 20)
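        # Within each record, a run of 20 "*" characters separates the model
        # input from the target output; real newlines are stored as the escaped
        # two-character sequence "\n" and are unescaped below.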

        for data_dan in text_list:
            if "*" * 20 in data_dan:
                # maxsplit=1 guards against stray separators inside the output.
                content, summary = data_dan.split("*" * 20, 1)
                # text_list_new.append(data_dan)
                # data_dan = data_dan.replace("\\n", "\n").replace("\n", "\\n")

                new_data_list.append(
                    {
                        "input": str(content).strip("\"").strip("\n").strip("\"").replace("\\n", "\n"),
                        "output": str(summary).replace("\\n", "\n").strip("\""),
                        "instruction": prompt
                    }
                )

                if task_name not in tongji:
                    tongji[task_name] = 1
                else:
                    tongji[task_name] += 1
            else:
                continue
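
    # First pass done: new_data_list holds raw (input, output, instruction)
    # triples. The second pass below cleans, filters, and formats them.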
    for data_dan in tqdm(new_data_list):
        zishu_query = len(data_dan["input"])      # input length in characters
        zishu_response = len(data_dan["output"])  # output length in characters

        prompt = str(data_dan["instruction"]).replace("\\n", "\n")
        query = data_dan["input"].replace("\\n", "\n")
        response = data_dan["output"].replace("\\n", "\n")

        if prompt == "翻译摘要#":  # "translate the abstract"
            zishu_summary = len(response.split(" "))  # word count for English text
        elif prompt == "翻译关键词#":  # "translate the keywords"
            zishu_summary = len(response.split(" "))
        else:
            # Every non-translation output should contain Chinese; drop records that don't.
            if not is_contains_chinese(response):
                print(data_dan)
                continue
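
        # The two translation tasks produce English output, which would fail
        # the Chinese-character check, so they are special-cased above.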

        # Normalise a wording slip in the raw data: "生成方向" ("generation
        # direction") should read "研究方向" ("research direction").
        if "生成方向" in query:
            query = query.replace("生成方向", "研究方向")
        if "生成方向" in response:
            response = response.replace("生成方向", "研究方向")

        if prompt == "生成论文小标题内容#":  # "generate sub-heading content"
            # RE_CHINA_TITLE has two groups, so findall yields
            # (small-title, big-title) tuples; join the pair to get the heading.
            query_re = re.findall(RE_CHINA_TITLE, query)
            title = "".join(query_re[0]) if query_re else ""
            if "总结" not in title and "结论" not in title:
                # The requested heading is neither a summary nor a conclusion;
                # if the response still contains such a heading, the generation
                # has drifted, so drop the record.
                if re.findall(RE_CHINA_NUMS, response):
                    jishu += 1
                    continue

        # Ensure trailing newlines, then wrap the query in the Q/A template.
        if prompt[-1] != "\n":
            prompt += "\n"
        if query[-1] != "\n":
            query += "\n"
        query = "问:" + query + "答:\n"
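        # "问:" / "答:" are the Chinese "Q:" / "A:" markers; each sample ends
        # with an empty answer slot for the model to fill in.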

        # Keep only reasonably sized samples (lengths in characters, not tokens).
        if len(query) < 700 and len(response) < 1400:
            data_list.append({
                "instruction": prompt,
                "input": query,
                "output": response
            })

    # Shuffle, then split 95/5 into train and dev.
    train_nums = int(len(data_list) * 0.95)
    dev_nums = int(len(data_list) * 0.05)
    random.shuffle(data_list)
    print(train_nums)
    print(dev_nums)
    train_list = data_list[:train_nums]
    dev_list = data_list[train_nums:]
    with open("./data/llama_t/chatglm_train_4_prompt_llama.json", mode="w", encoding="utf-8") as f:
        f.write(json.dumps(train_list, ensure_ascii=False, indent=2))

    with open("./data/llama_t/chatglm_dev_4_prompt_llama.json", mode="w", encoding="utf-8") as f:
        f.write(json.dumps(dev_list, ensure_ascii=False, indent=2))
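
    # A minimal sketch of how the emitted files could be consumed downstream
    # (assumed usage; training is outside this script):
    #   with open("./data/llama_t/chatglm_train_4_prompt_llama.json", encoding="utf-8") as f:
    #       train = json.load(f)
    #   # each record: {"instruction": "...#\n", "input": "问:...\n答:\n", "output": "..."}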