Data-processing code for generating ChatGPT fine-tuning data
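In outline (as read from the code below): the script walks the prompt files under ./data/chatgpt_data_v1/, splits each file into input/output records on runs of twenty "@" and twenty "*" characters, cleans and filters the samples, and writes a 95/5 train/dev split as JSON.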

import os
import json
import re
import random

from tqdm import tqdm
path_output = "chatgpt_data_v1"
patten = "目录是“(.*)”,请把其中的"  # "The catalogue is “…”, please take its …"; currently unused below
# Raw strings so the regex escapes are not mangled by Python's string parser.
pantten_biaoti = r'[1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][、.]\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'  # numbered-heading pattern; currently unused below
thanks = "致谢"          # "acknowledgements"
references = "参考文献"  # "references"
excursus = "附录"        # "appendix"
# Matches numbered "summary"/"conclusion" headings such as "三、总结" or "5. 结论".
RE_CHINA_NUMS = r"[一二三四五六七八九].?.?总结|[1-9].?.?总结|[一二三四五六七八九].?.?结论|[1-9].?.?结论"
# Captures the requested sub-title or main-title from the prompt text.
RE_CHINA_TITLE = r"请把其中的小标题“(.*?)”的内容补充完整|请把其中的大标题“(.*?)”的内容补充完整"
# Length-bucket counters for the generated samples (declared here but never
# updated in this script).
data_tongji = {
    "0-600": 0,
    "600-1500": 0,
    "1500-": 0,
}
jishu = 0  # counts samples dropped for containing a stray summary/conclusion heading
def is_contains_chinese(strs):
    """Return True if the string contains at least one CJK character."""
    for _char in strs:
        if '\u4e00' <= _char <= '\u9fa5':
            return True
    return False
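# e.g. is_contains_chinese("abstract") -> False, is_contains_chinese("摘要") -> True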
# pantten_second_biaoti = '[2二ⅡⅠ][、.]\s{0,}?[\u4e00-\u9fa5]+'
# Maps each source file to the instruction attached to its samples; prompts are
# Chinese task instructions, e.g. "生成目录#" = "generate the table of contents#".
# "num_token" is read below but not otherwise used in this script.
lable_data_amount = {
    "title_beijing_prompt_data.txt": {"num_token": -1, "prompt": "生成论文来源的背景#"},
    "title_jianjie_prompt_data.txt": {"num_token": -1, "prompt": "生成研究内容#"},
    "title_mulu_prompt_data.txt": {"num_token": -1, "prompt": "生成目录#"},
    "title_yanjiubeijingyiyi_prompt_data.txt": {"num_token": -1, "prompt": "生成课题的研究背景和意义#"},
    "title_zhixie_prompt_data.txt": {"num_token": -1, "prompt": "生成致谢#"},
    "title_zongjie_prompt_data.txt": {"num_token": -1, "prompt": "生成论文简短总结#"},
    "title_zongshu_prompt_data.txt": {"num_token": -1, "prompt": "生成课题的国内外研究状况综述#"},
    "sencend_task_book_prompt_data.txt": {"num_token": -1, "prompt": "生成6点本篇论文应完成的主要内容#"},
    "sencend_references_prompt_data.txt": {"num_token": -1, "prompt": "生成参考文献#"},
    "sencend_small_title_prompt_shuffle_data.txt": {"num_token": -1, "prompt": "生成论文小标题内容#"},
    "sencend_zhaiyao_prompt_data.txt": {"num_token": -1, "prompt": "生成论文摘要#"},
    "third_zhaiyao_chinese_keyword_prompt_data.txt": {"num_token": -1, "prompt": "生成关键字#"},
    "third_zhaiyao_fanyi_prompt_data.txt": {"num_token": -1, "prompt": "翻译摘要#"},
    "fourth_chinese_keyword_en_prompt_data.txt": {"num_token": -1, "prompt": "翻译关键词#"},
    "title_hexin_beijing_prompt_data.txt": {"num_token": -1, "prompt": "生成论文来源的背景#"},
    "title_hexin_jianjie_prompt_data.txt": {"num_token": -1, "prompt": "生成研究内容#"},
    "title_hexin_mulu_prompt_data.txt": {"num_token": -1, "prompt": "生成目录#"},
    "title_hexin_yanjiubeijingyiyi_prompt_data.txt": {"num_token": -1, "prompt": "生成课题的研究背景和意义#"},
    "title_hexin_zongjie_prompt_data.txt": {"num_token": -1, "prompt": "生成论文简短总结#"},
    "title_hexin_zongshu_prompt_data.txt": {"num_token": -1, "prompt": "生成课题的国内外研究状况综述#"},
    "title_hexin_zhixie_prompt_data.txt": {"num_token": -1, "prompt": "生成致谢#"}
}
# Regexes that locate the catalogue ("目录") inside the prompts of the
# catalogue-based tasks; looked up below but not otherwise applied here.
patten_mulu = {
    "title_mulu_references_prompt_data.txt": "目录是“(.*)”,请为这篇论文生成15篇左右的参考文献",
    "title_mulu_small_title_prompt_shuffle_data_new.txt": "目录是“(.*)”,请把其中的小标题",
    "title_mulu_zhaiyao_data.txt": "目录是“(.*)”,生成论文摘要"
}
path_list = []
# Collect every prompt file under the five source directories.
for sub_dir in [
    "paper_prompt_title_1",
    "paper_prompt_title_1_1",
    "paper_prompt_title_1_1_1",
    "paper_prompt_title_1_1_1_1",
    "paper_prompt_title_1_hexin",
]:
    file_dir = "./data/{}/{}".format(path_output, sub_dir)
    for root, dirs, files in os.walk(file_dir):
        for file_name in files:
            path_list.append(os.path.join(root, file_name))
text_list_new = []  # currently unused
tongji = {}  # per-task sample counts
if __name__ == '__main__':
    data_list = []
    train_list = []
    new_data_list = []
    for path in path_list:
        # These flags are populated but not consumed later in this script.
        patten_mulu_bool = False
        shuminghao_bool = False
        patten_mulu_patten = ""
        task_name = os.path.basename(path)  # portable replacement for splitting on "\\"
        if task_name in patten_mulu:
            patten_mulu_bool = True
            patten_mulu_patten = patten_mulu[task_name]
        if task_name not in lable_data_amount:
            continue  # no instruction mapping for this file
        train_data_amount_dict = lable_data_amount[task_name]
        train_data_amount = train_data_amount_dict["num_token"]  # read but unused
        prompt = train_data_amount_dict["prompt"]
        with open(path, encoding="utf-8") as f:
            text = f.read()
        # Records are separated by twenty "@"; within a record, twenty "*"
        # separate the input (content) from the output (summary).
        text_list = text.split("@" * 20)
        for data_dan in text_list:
            if "*" * 20 not in data_dan:
                continue
            content, summary = data_dan.split("*" * 20, 1)
            new_data_list.append(
                {
                    "input": str(content).strip("\"").strip("\n").strip("\"").replace("\\n", "\n"),
                    "output": str(summary).replace("\\n", "\n").strip("\""),
                    "instruction": prompt
                }
            )
            tongji[task_name] = tongji.get(task_name, 0) + 1
    for data_dan in tqdm(new_data_list):
        prompt = str(data_dan["instruction"]).replace("\\n", "\n")
        query = data_dan["input"].replace("\\n", "\n")
        response = data_dan["output"].replace("\\n", "\n")
        # Translation tasks legitimately produce English; every other task
        # must contain Chinese characters in its output.
        if prompt not in ("翻译摘要#", "翻译关键词#"):
            if not is_contains_chinese(response):
                print(data_dan)
                continue
        # "生成方向" ("generation direction") is a template artifact; the
        # intended phrase is "研究方向" ("research direction").
        query = query.replace("生成方向", "研究方向")
        response = response.replace("生成方向", "研究方向")
        if prompt == "生成论文小标题内容#":
            # findall returns tuples here because RE_CHINA_TITLE has two
            # groups; join them so substring checks work on plain text.
            query_re = re.findall(RE_CHINA_TITLE, query)
            title = "".join(query_re[0]) if query_re else ""
            if "总结" not in title and "结论" not in title:
                # The requested title is not itself a summary/conclusion, so
                # drop bodies that sneak in their own summary heading.
                if re.findall(RE_CHINA_NUMS, response):
                    jishu += 1
                    continue
        if not prompt.endswith("\n"):
            prompt += "\n"
        if not query.endswith("\n"):
            query += "\n"
        query = "问:" + query + "答:\n"
        # Keep only samples short enough for the target context window.
        if len(query) < 700 and len(response) < 1400:
            data_list.append({
                "instruction": prompt,
                "input": query,
                "output": response
            })
    train_nums = int(len(data_list) * 0.95)
    dev_nums = int(len(data_list) * 0.05)
    random.shuffle(data_list)
    print(train_nums)
    print(dev_nums)
    train_list = data_list[:train_nums]
    dev_list = data_list[train_nums:]
    os.makedirs("./data/llama_t", exist_ok=True)  # ensure the output directory exists
    with open("./data/llama_t/chatglm_train_4_prompt_llama.json", mode="w", encoding="utf-8") as f:
        f.write(json.dumps(train_list, ensure_ascii=False, indent=2))
    with open("./data/llama_t/chatglm_dev_4_prompt_llama.json", mode="w", encoding="utf-8") as f:
        f.write(json.dumps(dev_list, ensure_ascii=False, indent=2))
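    # Optional sanity check (a minimal sketch; assumes the two files above were
    # just written): reload the train split and confirm every record carries
    # the three fields the training code expects.
    with open("./data/llama_t/chatglm_train_4_prompt_llama.json", encoding="utf-8") as f:
        reloaded = json.load(f)
    assert all({"instruction", "input", "output"} <= set(record) for record in reloaded)
    print("train:", len(train_list), "dev:", len(dev_list), "dropped:", jishu)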