import os
import re
import random
import json
from tqdm import tqdm

# Numbered "summary"/"conclusion" headings such as "一、总结" or "3.结论": a
# Chinese or Arabic numeral, up to two arbitrary characters, then 总结/结论.
RE_CHINA_NUMS = (
    "[一二三四五六七八九].?.?总结"
    "|[1-9].?.?总结"
    "|[一二三四五六七八九].?.?结论"
    "|[1-9].?.?结论"
)
# Extracts the quoted sub-title (小标题) or main title (大标题) from a query.
# There is one capture group per alternative.
RE_CHINA_TITLE = (
    "请把其中的小标题“(.*?)”的内容补充完整"
    "|请把其中的大标题“(.*?)”的内容补充完整"
)

# Length-bucket counters (bucket label -> count). No live code visible in this
# script reads them; they are kept for the commented-out statistics code.
data_tongji = dict.fromkeys(("0-600", "600-1500", "1500-"), 0)
# Samples collected for the longest bucket (likewise unused by live code here).
data_tongji_prompt = []

def is_contains_chinese(strs):
    """Return True if *strs* contains at least one CJK Unified Ideograph.

    Checks the basic CJK Unified Ideographs block, U+4E00..U+9FFF. The old
    upper bound U+9FA5 was the end of the original Unicode 1.1 repertoire and
    missed ideographs added to this block later. Extension blocks (U+3400..,
    U+20000..) are still not checked.

    Args:
        strs: Any iterable of single characters, normally a ``str``.

    Returns:
        True as soon as one ideograph is found, otherwise False (also for an
        empty input).
    """
    return any('\u4e00' <= ch <= '\u9fff' for ch in strs)

# Accepted samples in {"instruction", "input", "output"} form.
data_list = []

# Count of "生成论文小标题内容#" samples where the requested title is neither a
# summary nor a conclusion section, yet the response contains a numbered
# "N.总结"/"N.结论" heading. Offending responses are printed for inspection.
jishu = 0

# Prompts whose responses are translations. They are exempt from the
# "response must contain Chinese" filter below.
_TRANSLATION_PROMPTS = ("翻译摘要#", "翻译关键词#")

with open("data/chatglm_paper_data_2_prompt.txt", encoding="utf-8") as f:
    for line in tqdm(f):
        # NOTE(review): eval() executes arbitrary code from the data file. If
        # every line is a plain Python dict literal, ast.literal_eval is a safe
        # drop-in replacement; confirm the file format before switching.
        data_dan = eval(line)

        # Turn literal backslash-n sequences into real newlines.
        prompt = str(data_dan["prompt"]).replace("\\n", "\n")
        query = data_dan["query"].replace("\\n", "\n")
        response = data_dan["response"].replace("\\n", "\n")

        # Non-translation samples must have a Chinese response. Log and drop
        # the rest.
        if prompt not in _TRANSLATION_PROMPTS and not is_contains_chinese(response):
            print(data_dan)
            continue

        # Normalise the wording "生成方向" to "研究方向".
        query = query.replace("生成方向", "研究方向")
        response = response.replace("生成方向", "研究方向")

        if prompt == "生成论文小标题内容#":
            title_match = re.search(RE_CHINA_TITLE, query)
            # If no title can be extracted, the section type is unknown, so skip
            # the diagnostic instead of crashing.
            if title_match is not None:
                # RE_CHINA_TITLE has one group per alternative and only one of
                # them participates, so joining the groups gives the title text.
                title = "".join(title_match.groups(default=""))
                is_summary_section = "总结" in title or "结论" in title
                if not is_summary_section and re.search(RE_CHINA_NUMS, response):
                    print(response)
                    print("==========================================================================================")
                    jishu += 1

        # Make sure prompt and query end with a newline. endswith() also copes
        # with empty strings, where indexing [-1] would raise IndexError.
        if not prompt.endswith("\n"):
            prompt += "\n"
        if not query.endswith("\n"):
            query += "\n"
        query = "问:" + query + "答:\n"

        # Keep only samples short enough for training.
        if len(query) < 700 and len(response) < 1400:
            data_list.append({
                "instruction": prompt,
                "input": query,
                "output": response,
            })


# Fraction of samples that go to the training split; the rest go to dev.
TRAIN_RATIO = 0.9
# Seed for the shuffle. None keeps the previous non-deterministic behaviour;
# set an int to make the train/dev split reproducible across runs.
SPLIT_SEED = None

random.Random(SPLIT_SEED).shuffle(data_list)
train_nums = int(len(data_list) * TRAIN_RATIO)
print(train_nums)
train_list = data_list[:train_nums]
dev_list = data_list[train_nums:]

# Write each split as a pretty-printed JSON array, keeping non-ASCII text as-is.
for out_path, split in (
    ("./data/chatglm_train_3_prompt_llama.json", train_list),
    ("./data/chatglm_dev_3_prompt_llama.json", dev_list),
):
    with open(out_path, mode="w", encoding="utf-8") as f:
        json.dump(split, f, ensure_ascii=False, indent=2)

# Number of flagged sub-title samples counted while loading.
print(jishu)