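# Build prompt data for paper-section generation: read (title, table of
# contents) records from a text dump, classify TOC lines into chapter and
# subsection headings, spread a 12000-word budget over the chapters with a
# normal distribution, and write one completion prompt per subsection as
# line-delimited JSON.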
import json
import re
import math
import random

import numpy as np
from tqdm import tqdm

# pantten_second_biaoti = '[2二ⅡⅠ][、.]\s{0,}?[\u4e00-\u9fa5]+'
# Regex that recognises numbered first-level headings such as "一、绪论" or "3.方法".
pantten_biaoti = r'[1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][、.]\s{0,}?[\u4e00-\u9fa5a-zA-Z]+'
# Prompt templates, kept in Chinese because they are the generated training data
# (see the English gloss below).
first_title_prompt = "论文题目是“{}”,目录是“{}”,请把其中的大标题“{}”的内容补充完整,补充内容字数在{}字左右"
small_title_prompt = "论文题目是“{}”,目录是“{}”,请把其中的小标题“{}”的内容补充完整,补充内容字数在{}字左右"
# TOC entries that should not get generated content.
thanks = "致谢"          # acknowledgements
references = "参考文献"  # references
excursus = "附录"        # appendix
# Parameters of the normal distribution used to spread the word budget.
u = 3.5                   # mean μ
sig = math.sqrt(6.0)      # standard deviation σ
zong_gradient = 6         # x-range step budget over which chapters are spread
paper_word_count = 12000  # total word budget per paper
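
# Both templates above ask, in Chinese, to write out the named heading of the
# paper "{title}" (with table of contents "{toc}") in roughly {n} characters;
# first_title_prompt is used for chapter headings, small_title_prompt for
# subsection headings.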

path = "../data/paper_prompt_title_3/title_mulu_prompt_data.txt"
with open(path, encoding="utf-8") as f:
    text = f.read()


def normal_distribution(x):
    # Probability density of N(u, sig^2) at x; used as the relative weight of
    # each chapter when the word budget is distributed.
    y = np.exp(-(x - u) ** 2 / (2 * sig ** 2)) / (math.sqrt(2 * math.pi) * sig)
    return y

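# The source dump separates records with a line of '+' characters; inside a
# record, the paper title and its table of contents (mulu) are separated by a
# line of '*' characters (see the split inside the loop below).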
text_list = text.split("\n++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n")

ner_lable = []   # only used by the commented-out NER export at the bottom
text_zong = []   # only used by the commented-out NER export at the bottom

train_list = []  # final list of formatted prompts

for text_dan in tqdm(text_list):
    # print(text_dan)
    try:
        title, mulu = text_dan.split("**********************************************")
    except ValueError:
        # Skip records that do not split into exactly (title, table of contents).
        continue
    title = str(title).strip("\n")
    mulu = str(mulu).strip("\n")
    paper_text = "题目:{}@目录:".format(title)
    table_of_contents = []
    nerlable_list = []

    # mulu_base64 = base64.b64encode(mulu.encode('utf-8'))
    # mulu_path = os.path.join(uuid_path, "mulu.txt")
    # with open(mulu_path, 'wb', encoding='utf8') as f2:
    #     f2.write(mulu_base64)
    mulu_list = str(mulu).split("\n")
    mulu_list = [i.strip() for i in mulu_list if i != ""]
    mulu_str = "@".join(mulu_list)

    # Classify each TOC line: lines matching the numbering pattern are treated
    # as first-level headings ("一级标题"), everything else as second-level
    # headings ("二级标题").
    mulu_list_bool = []
    for i in mulu_list:
        result_biaoti_list = re.findall(pantten_biaoti, i)
        if result_biaoti_list != []:
            mulu_list_bool.append((i, "一级标题"))
        else:
            mulu_list_bool.append((i, "二级标题"))

    mulu_list_bool_part = mulu_list_bool[:3]

    # Discard TOCs whose structure looks wrong: there must be at least three
    # entries, the first must be a first-level heading, and the first three
    # must not all be first-level headings.
    if len(mulu_list_bool_part) < 3:
        continue
    if mulu_list_bool_part[0][1] != "一级标题":
        continue
    if mulu_list_bool_part[0][1] == mulu_list_bool_part[1][1] == mulu_list_bool_part[2][1] == "一级标题":
        continue

    # Drop acknowledgements, references and appendix entries from the tail of
    # the TOC, since no body text is generated for them.
    thanks_references_bool_table = mulu_list_bool[-5:]

    for i in thanks_references_bool_table:
        try:
            if references in i[0]:
                mulu_list_bool.remove(i)
            if thanks in i[0]:
                mulu_list_bool.remove(i)
            if excursus in i[0]:
                mulu_list_bool.remove(i)
        except ValueError:
            print(thanks_references_bool_table)
            continue

    # Group the TOC into chapters: each first-level heading starts a new entry
    # (prefixed with "@@"), and subsequent second-level headings attach to it.
    for i in mulu_list_bool:
        if i[1] == "一级标题":
            paper_dan = {
                "title": "@@" + i[0],
                "small_title": [],
                "word_count": 0
            }
            table_of_contents.append(paper_dan)
        else:
            table_of_contents[-1]["small_title"].append(i[0])

    # Spread the total word budget over the chapters with a normal
    # distribution: chapters sit at evenly spaced points (step =
    # zong_gradient / number of chapters), each point is weighted by the
    # Gaussian density, and the weights are scaled so that they sum to
    # paper_word_count.
    x_list = [0]
    y_list = [normal_distribution(0)]

    gradient = zong_gradient / len(table_of_contents)
    for i in range(len(table_of_contents) - 1):
        x_gradient = x_list[-1] + gradient
        x_list.append(x_gradient)
        y_list.append(normal_distribution(x_list[-1]))

    dan_gradient = paper_word_count / sum(y_list)

    for i in range(len(y_list)):
        table_of_contents[i]["word_count"] = dan_gradient * y_list[i]

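    # Illustration (not computed by the script): with six chapters the points
    # are x = 0, 1, ..., 5, and the Gaussian with u = 3.5, sig = sqrt(6) gives
    # budgets of roughly 950 / 1560 / 2180 / 2570 / 2570 / 2180 words out of
    # 12000, so early chapters get the least and the later middle chapters the
    # most.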
    # print(table_of_contents)
    # print(len(table_of_contents))

    # Flatten into a single list of [heading, word_count]: chapter headings
    # (still "@@"-prefixed) get 0 words, and each chapter's budget is split
    # evenly among its subsection headings.
    table_of_contents_new = []
    for dabiaoti_index in range(len(table_of_contents)):
        dabiaoti_dict = table_of_contents[dabiaoti_index]
        table_of_contents_new.append([dabiaoti_dict["title"], 0])
        for xiaobiaoti in dabiaoti_dict["small_title"]:
            table_of_contents_new.append(
                [xiaobiaoti, int(dabiaoti_dict["word_count"] / len(dabiaoti_dict["small_title"]))])

    # Build one generation task per TOC entry. A chapter heading that has no
    # subsections (first or last entry) is completed directly with
    # first_title_prompt and a fixed budget of 800 words; everything else uses
    # small_title_prompt with its allocated budget.
    small_task_list = []
    content_index = 0
    while True:
        if content_index == len(table_of_contents_new):
            break
        subtitle, word_count = table_of_contents_new[content_index]
        prompt = small_title_prompt

        if content_index == 0 and table_of_contents_new[1][0][:2] == "@@" and subtitle[:2] == "@@":
            subtitle, prompt, word_count = subtitle[2:], first_title_prompt, 800

        if content_index == len(table_of_contents_new) - 1 and subtitle[:2] == "@@":
            subtitle, prompt, word_count = subtitle[2:], first_title_prompt, 800

        paper_content = [
            content_index,
            title,
            mulu,
            subtitle,
            prompt,
            word_count
        ]

        small_task_list.append(paper_content)
        content_index += 1

    # Keep only subsection tasks: skip chapter headings ("@@"), entries whose
    # word budget exceeds 1280, and prompts longer than 768 characters.
    for i in small_task_list:
        if i[3][:2] == "@@":
            continue
        elif i[5] > 1280:
            continue
        else:
            paper_prompt = i[4].format(i[1], i[2], i[3], i[5])
            if len(paper_prompt) < 768:
                train_list.append(paper_prompt)
            else:
                continue

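# Each entry kept in train_list is a fully formatted Chinese prompt, e.g.
# (hypothetical title and heading):
# 论文题目是“面向小样本的文本分类研究”,目录是“一、绪论@1.1 研究背景@...”,请把其中的小标题“1.1 研究背景”的内容补充完整,补充内容字数在1200字左右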

random.shuffle(train_list)

# Write every prompt, plus a shuffled sample capped at 100000 entries, one
# JSON-encoded string per line.
train_list_shuffle = train_list[:100000]

with open("../data/title_to_/prompt.txt", mode="w", encoding="utf-8") as f:
    for i in train_list:
        f.write(json.dumps(i, ensure_ascii=False))
        f.write("\n")

with open("../data/title_to_/prompt_shuffle.txt", mode="w", encoding="utf-8") as f:
    for i in train_list_shuffle:
        f.write(json.dumps(i, ensure_ascii=False))
        f.write("\n")

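# A minimal sketch (kept as a comment so the script's behaviour is unchanged)
# of how the line-delimited JSON output can be read back, assuming the same
# relative path:
#
#     with open("../data/title_to_/prompt_shuffle.txt", encoding="utf-8") as f:
#         prompts = [json.loads(line) for line in f if line.strip()]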

# Legacy NER-labelling export, kept for reference:
# for lable in table_of_contents:
#     text_len = len(paper_text)
#     dan_nerlable = [text_len, text_len + len(lable[0]), lable[1]]
#     nerlable_list.append(dan_nerlable)
#     paper_text += lable[0]
#     paper_text += "@"
#
# paper_dan = {"text": paper_text, "label": nerlable_list}
#
# ner_lable.append(str(table_of_contents))
# text_zong.append(paper_dan)
#
# with open("../data/train.txt", mode="w", encoding="utf-8") as f:
#     for i in text_zong:
#         f.write(json.dumps(i, ensure_ascii=False))
#         f.write("\n")
#
# with open("../data/train_lable.txt", mode="w") as f:
#     for i in ner_lable:
#         f.write(json.dumps(i, ensure_ascii=False))
#         f.write("\n")