数据处理代码,为了生成chatgpt数据
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

75 lines
2.7 KiB

2 years ago
import json
import re
# Build JSON-lines fine-tuning data from generated thesis-writing prompts.
#
# Each record in the input file starts with `"论文题目是` ("the thesis title is")
# and is expected to contain the title in “…”, the table of contents after 目录是
# (line breaks stored as a literal backslash-n), a request to complete a 小标题
# (sub-heading) or 大标题 (top-level heading) with a target length 字数, and then,
# after a `**************` separator, the generated section text.
# Every usable record becomes one {"content": prompt, "summary": text} line.

pantten_mulu = "目录是“(.*?)”,请把其中"  # table of contents (目录)
pantten_title = "“(.*?)”,目录是"  # thesis title
pantten_small_title = "请把其中的小标题“(.*?)”的内容补充完整"  # sub-heading to complete
pantten_big_title = "请把其中的大标题“(.*?)”的内容补充完整"  # top-level heading to complete
pantten_zishu = "的内容补充完整,补充内容字数在(.*?)字左右"  # target length (字数)

RECORD_SEPARATOR = "\"论文题目是"
CONTENT_SEPARATOR = "**************"
MAX_LENGTH = 2048  # records whose prompt + text reach this many characters are dropped

# Templates that rebuild the prompt. Each keeps the heading kind (小标题 / 大标题)
# of the regex that matched, so top-level-heading records are not relabelled.
SMALL_TITLE_PROMPT = "论文题目是“{}”,目录是“{}”,请把其中的小标题“{}”的内容补充完整,补充内容字数在{}字左右"
BIG_TITLE_PROMPT = "论文题目是“{}”,目录是“{}”,请把其中的大标题“{}”的内容补充完整,补充内容字数在{}字左右"


def _first_group(pattern, text):
    """Return group 1 of the first match of *pattern* in *text*, or None if it never matches."""
    match = re.search(pattern, text)
    return None if match is None else match.group(1)


def split_records(content):
    """Split raw file text into records, dropping whatever precedes the first record."""
    return [chunk.strip("\n") for chunk in content.split(RECORD_SEPARATOR)[1:]]


def parse_record(record, max_length=MAX_LENGTH):
    """Convert one record into a training pair.

    Returns ``{"content": prompt, "summary": text}``, or ``None`` when the
    record lacks a non-empty TOC, title or target length, matches neither
    heading pattern, has no ``**************`` separator, or when
    ``len(prompt) + len(text) >= max_length``.
    """
    mulu_raw = _first_group(pantten_mulu, record)
    if mulu_raw is None:
        print(record)  # surface records with no TOC at all, as the original script did
        return None
    if mulu_raw == "":
        return None
    # The TOC keeps line breaks as a literal backslash + "n"; join its entries with "@".
    mulu = "@".join(entry.strip() for entry in mulu_raw.split("\\n") if entry != "")

    title = _first_group(pantten_title, record)
    if not title:
        return None

    # Sub-headings take precedence when both patterns match. An empty heading is
    # still accepted; only a missing match skips the record.
    heading = _first_group(pantten_small_title, record)
    template = SMALL_TITLE_PROMPT
    if heading is None:
        heading = _first_group(pantten_big_title, record)
        template = BIG_TITLE_PROMPT
        if heading is None:
            return None

    zishu = _first_group(pantten_zishu, record)
    if not zishu:
        return None

    sections = record.split(CONTENT_SEPARATOR)
    if len(sections) < 2:
        return None
    neirong = sections[1]  # text between the first and second separator

    prompt = template.format(title, mulu, heading, zishu)
    if len(prompt) + len(neirong) >= max_length:
        return None
    return {"content": prompt, "summary": neirong}


def main(input_path="data/prompt_small_gen.txt", output_path="data/small_title_train.json"):
    """Read *input_path*, print the record count, and write one JSON object per line to *output_path*."""
    with open(input_path, encoding="utf-8") as f:
        records = split_records(f.read())
    print(len(records))
    train = [pair for pair in map(parse_record, records) if pair is not None]
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(pair, ensure_ascii=False) + "\n" for pair in train)


if __name__ == "__main__":
    main()