import json
import re
import math
import numpy as np
from tqdm import tqdm

# pantten_second_biaoti = '[2二ⅡⅠ][、.]\s{0,}?[\u4e00-\u9fa5]+'

# A table-of-contents line is treated as a first-level heading when this pattern
# matches anywhere in it: a numeral (Arabic 1-9, Chinese 一-九 or Roman Ⅰ-Ⅸ),
# then "、" or ".", optional whitespace and at least one CJK/ASCII letter
# (e.g. "一、绪论"; "1.1 研究背景" does not match).  Raw string so the "\s"
# escape reaches the regex engine instead of being an invalid Python escape;
# `re` itself interprets the "\uXXXX" escapes.
pantten_biaoti = r'[1-9一二三四五六七八九ⅠⅡⅢⅣⅤⅥⅦⅧⅨ][、.]\s*[\u4e00-\u9fa5a-zA-Z]+'

# Prompt templates, filled with (paper title, table of contents, heading, word count).
first_title_prompt = "论文题目是“{}”,目录是“{}”,请把其中的大标题“{}”的内容补充完整,补充内容字数在{}字左右"
small_title_prompt = "论文题目是“{}”,目录是“{}”,请把其中的小标题“{}”的内容补充完整,补充内容字数在{}字左右"
# Template filled with (paper title, table of contents).
# NOTE(review): "有有" looks like a typo; left unchanged so generated prompts stay identical.
references_prompt = "论文题目是“{}”,目录是“{}”,请为这篇论文生成15篇左右的参考文献,要求其中有有中文参考文献不低于12篇,英文参考文献不低于2篇"

# Back-matter keywords (acknowledgements, references, appendix); table-of-contents
# entries containing them are dropped from the tail of an outline.
thanks = "致谢"
references = "参考文献"
excursus = "附录"

# Bell curve used to distribute the paper's word budget across chapters.
u = 3.5  # mean μ
sig = math.sqrt(6.0)  # standard deviation σ
# Total x-range over which the chapters are spaced on the bell curve.
zong_gradient = 6
# Total word budget shared by all chapters.
paper_word_count = 12000
# Captures the paper title from a "...”生成目录,要求只有一级标题和二级标题," prompt.
pantten_title = "(.*?)”生成目录,要求只有一级标题和二级标题,"






# Input corpus, read in full as UTF-8; the script later splits it into
# per-paper records.
path = "./data/paper_prompt_title_3/title_mulu_prompt_data.txt"
with open(path, mode="r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()


def normal_distribution(x, mu=None, sigma=None):
    """Return the normal (Gaussian) probability density at *x*.

    Args:
        x: Scalar or NumPy array of evaluation points.
        mu: Mean of the distribution; defaults to the module-level ``u``.
        sigma: Standard deviation; defaults to the module-level ``sig``.
            Must be positive.

    Returns:
        ``exp(-(x - mu)**2 / (2 * sigma**2)) / (sqrt(2*pi) * sigma)`` as a
        NumPy scalar or array (same shape as *x*).

    Raises:
        ValueError: If *sigma* is not positive.
    """
    # Resolve defaults at call time so the module-level constants stay the
    # single source of truth for existing one-argument callers.
    if mu is None:
        mu = u
    if sigma is None:
        sigma = sig
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma!r}")
    return np.exp(-(x - mu) ** 2 / (2 * sigma ** 2)) / (math.sqrt(2 * math.pi) * sigma)

# Split the corpus into per-paper records on the marker "为论文题目“".
text_list = text.split("为论文题目“")

# Only referenced by the commented-out NER export at the end of the file.
ner_lable = []
text_zong = []

train_list = []             # section-writing prompts
train_references_list = []  # reference-list prompts

# Heading-level labels attached to each table-of-contents line.
_FIRST_LEVEL = "一级标题"
_SECOND_LEVEL = "二级标题"


def _split_record(text_dan):
    """Return ``(title, mulu)`` for one raw record, or ``None`` if it is unusable.

    The record must contain exactly one ``"**************"`` separator, with the
    title prompt before it and the table of contents (``mulu``) after it.  The
    title is the first ``pantten_title`` capture in the prompt; when none is
    found the prompt is printed for inspection.  Both values are stripped of
    surrounding newlines.
    """
    parts = text_dan.split("**************")
    if len(parts) != 2:
        return None
    title_prompt, mulu = parts
    titles = re.findall(pantten_title, title_prompt)
    if not titles:
        print(title_prompt)
        return None
    return str(titles[0]).strip("\n"), str(mulu).strip("\n")


def _classify_headings(mulu):
    """Return ``[(line, level), ...]`` for every non-blank line of *mulu*.

    Lines are stripped, and whitespace-only lines are dropped so that they do
    not become empty second-level headings.  A line is first-level when
    ``pantten_biaoti`` matches anywhere in it; otherwise it is second-level.
    """
    headings = []
    for raw_line in mulu.split("\n"):
        line = raw_line.strip()
        if not line:
            continue
        level = _FIRST_LEVEL if re.search(pantten_biaoti, line) else _SECOND_LEVEL
        headings.append((line, level))
    return headings


def _drop_back_matter(headings):
    """Remove references/acknowledgements/appendix entries from the tail.

    Only the last five entries are inspected.  An entry is dropped when its
    text contains ``references``, ``thanks`` or ``excursus``.  Earlier entries
    are kept even if they contain these words.  Returns a new list.
    """
    keywords = (references, thanks, excursus)
    head, tail = headings[:-5], headings[-5:]
    return head + [h for h in tail if not any(k in h[0] for k in keywords)]


def _build_outline(headings):
    """Group *headings* into chapters.

    Returns ``[{"title": "@@" + chapter, "small_title": [...], "word_count": 0}]``.
    *headings* must start with a first-level entry, because each second-level
    entry is attached to the preceding chapter.
    """
    table_of_contents = []
    for line, level in headings:
        if level == _FIRST_LEVEL:
            table_of_contents.append({"title": "@@" + line, "small_title": [], "word_count": 0})
        else:
            table_of_contents[-1]["small_title"].append(line)
    return table_of_contents


def _allocate_word_counts(table_of_contents):
    """Split ``paper_word_count`` across the chapters, in place.

    Chapters sit at x = 0, g, 2g, ... with ``g = zong_gradient / n``.  Each
    chapter's ``word_count`` is proportional to ``normal_distribution(x)``,
    scaled so that the counts sum to ``paper_word_count`` (up to float
    rounding).  *table_of_contents* must be non-empty.
    """
    gradient = zong_gradient / len(table_of_contents)
    # x positions are accumulated step by step (not multiplied) to keep the
    # exact floating-point values this script has always produced.
    x_list = [0]
    y_list = [normal_distribution(0)]
    for _ in range(len(table_of_contents) - 1):
        x_list.append(x_list[-1] + gradient)
        y_list.append(normal_distribution(x_list[-1]))
    scale = paper_word_count / sum(y_list)
    for chapter, density in zip(table_of_contents, y_list):
        chapter["word_count"] = scale * density


def _flatten_outline(table_of_contents):
    """Return ``[[heading, word_count], ...]`` in document order.

    Chapter titles (still prefixed with ``"@@"``) get 0.  Each sub-heading gets
    an equal, truncated integer share of its chapter's word count.
    """
    flat = []
    for chapter in table_of_contents:
        flat.append([chapter["title"], 0])
        subtitles = chapter["small_title"]
        for subtitle in subtitles:
            flat.append([subtitle, int(chapter["word_count"] / len(subtitles))])
    return flat


def _section_prompts(title, mulu, flat_outline):
    """Yield the section-writing prompts for one paper.

    Sub-headings use ``small_title_prompt`` with their allocated word count.
    A chapter title becomes a prompt (``first_title_prompt``, 800 words) only
    in two cases: it is the last entry, or it is the first entry and the next
    entry is also a chapter title.  All other chapter titles are skipped.
    Entries above 1280 words, and prompts of 768 characters or more, are also
    skipped.
    """
    last_index = len(flat_outline) - 1
    for index, (subtitle, word_count) in enumerate(flat_outline):
        prompt = small_title_prompt
        # `index == last_index` is tested first so a single-entry outline never
        # reads flat_outline[1].
        if subtitle[:2] == "@@" and (
            index == last_index
            or (index == 0 and flat_outline[1][0][:2] == "@@")
        ):
            subtitle, prompt, word_count = subtitle[2:], first_title_prompt, 800
        if subtitle[:2] == "@@" or word_count > 1280:
            continue
        paper_prompt = prompt.format(title, mulu, subtitle, word_count)
        if len(paper_prompt) < 768:
            yield paper_prompt


for text_dan in tqdm(text_list):
    record = _split_record(text_dan)
    if record is None:
        continue
    title, mulu = record

    # One reference-list prompt per record with a parseable title, emitted
    # before the table of contents is validated.
    train_references_list.append(references_prompt.format(title, mulu))

    headings = _classify_headings(mulu)
    # Keep only outlines that open with a chapter heading and have at least one
    # sub-heading among the first three entries.  Empty or very short outlines
    # are rejected here instead of raising IndexError.
    if not headings or headings[0][1] != _FIRST_LEVEL:
        continue
    if all(level == _FIRST_LEVEL for _, level in headings[:3]):
        continue

    headings = _drop_back_matter(headings)
    # Back-matter removal can take out the opening chapter of a short outline.
    if not headings or headings[0][1] != _FIRST_LEVEL:
        continue

    table_of_contents = _build_outline(headings)
    _allocate_word_counts(table_of_contents)
    train_list.extend(_section_prompts(title, mulu, _flatten_outline(table_of_contents)))

import os
import random


def _write_jsonl(out_path, items):
    """Write *items* to *out_path*, one ``json.dumps`` document per line.

    The file is UTF-8 and non-ASCII characters are written unescaped
    (``ensure_ascii=False``).  Any existing file is overwritten.
    """
    with open(out_path, mode="w", encoding="utf-8") as out_file:
        for item in items:
            out_file.write(json.dumps(item, ensure_ascii=False))
            out_file.write("\n")


# Shuffle the section prompts in place and keep the first 10,000 as a subset.
# train_list itself is written after the shuffle, so small_title_prompt.txt is
# in shuffled order as well.
random.shuffle(train_list)
train_list_shuffle = train_list[:10000]

# The output directory differs from the input one; create it so the writes
# below do not fail with FileNotFoundError on a fresh checkout.
os.makedirs("./data/title_mulu_to_", exist_ok=True)

_write_jsonl("./data/title_mulu_to_/references_prompt.txt", train_references_list)
_write_jsonl("./data/title_mulu_to_/small_title_prompt.txt", train_list)
_write_jsonl("./data/title_mulu_to_/small_title_prompt_shuffle.txt", train_list_shuffle)


#     for lable in table_of_contents:
#         text_len = len(paper_text)
#         dan_nerlable = [text_len, text_len + len(lable[0]), lable[1]]
#         nerlable_list.append(dan_nerlable)
#         paper_text += lable[0]
#         paper_text += "@"
#
#     paper_dan = {"text": paper_text, "label": nerlable_list}
#
#     ner_lable.append(str(table_of_contents))
#     text_zong.append(paper_dan)
#
# with open("../data/train.txt", mode="w", encoding="utf-8") as f:
#     for i in text_zong:
#         f.write(json.dumps(i, ensure_ascii=False))
#         f.write("\n")
#
#
# with open("../data/train_lable.txt", mode="w") as f:
#     for i in ner_lable:
#         f.write(json.dumps(i, ensure_ascii=False))
#         f.write("\n")