"""Build instruction-tuning train/dev JSON files from raw prompt records.

Each non-blank line of the input file is a Python dict literal with the keys
``prompt``, ``query`` and ``response``.  Records are cleaned, filtered by
language and length, converted to ``instruction`` / ``input`` / ``output``
dicts, and written as a 90 % / 10 % train/dev split.
"""
import os
import re
import random
import json

try:
    from tqdm import tqdm
except ImportError:  # tqdm only draws a progress bar; keep working without it
    def tqdm(iterable, *args, **kwargs):
        """Fallback used when tqdm is not installed: return *iterable* as is."""
        return iterable

# Numbered "summary" / "conclusion" headings inside a response.
RE_CHINA_NUMS = "[一二三四五六七八九].?.?总结|[1-9].?.?总结|[一二三四五六七八九].?.?结论|[1-9].?.?结论"
# Section title the query asks the model to fill in (sub- or main heading).
RE_CHINA_TITLE = "请把其中的小标题“(.*?)”的内容补充完整|请把其中的大标题“(.*?)”的内容补充完整"

# Length statistics containers; not updated by the current pipeline.
data_tongji = {
    "0-600": 0,
    "600-1500": 0,
    "1500-": 0,
}
data_tongji_prompt = []

INPUT_PATH = "data/chatglm_paper_data_2_prompt.txt"
TRAIN_PATH = "./data/chatglm_train_3_prompt_llama.json"
DEV_PATH = "./data/chatglm_dev_3_prompt_llama.json"

# Prompts whose responses are not required to contain Chinese characters.
_TRANSLATION_PROMPTS = ("翻译摘要#", "翻译关键词#")
_SECTION_PROMPT = "生成论文小标题内容#"
_MAX_QUERY_CHARS = 700
_MAX_RESPONSE_CHARS = 1400
_TRAIN_FRACTION = 0.9


def is_contains_chinese(strs):
    """Return True if *strs* contains a character in U+4E00..U+9FA5."""
    return any('\u4e00' <= _char <= '\u9fa5' for _char in strs)


def _parse_record(line):
    """Turn one input line into a record dict."""
    # SECURITY: eval() executes arbitrary code found in the input file.  It is
    # kept because the on-disk format is a Python dict literal; only run this
    # on trusted data, or switch to ast.literal_eval once the data is
    # confirmed to contain plain literals only.
    return eval(line)  # noqa: S307


def _unescape_newlines(text):
    """Replace literal backslash-n sequences with real newlines."""
    return text.replace("\\n", "\n")


def _requested_title(query):
    """Return the section title quoted in *query*, or None if there is none."""
    match = re.search(RE_CHINA_TITLE, query)
    if match is None:
        return None
    # Exactly one of the two alternation groups participates in a match.
    return match.group(1) if match.group(1) is not None else match.group(2)


def _has_stray_conclusion(query, response):
    """Return True if *response* has a summary heading the query did not ask for."""
    title = _requested_title(query)
    if title is None:
        # No recognisable title: nothing to compare the response against.
        return False
    if "总结" in title or "结论" in title:
        return False
    return re.search(RE_CHINA_NUMS, response) is not None


def _format_example(prompt, query, response):
    """Return the training example dict, or None if it exceeds the length limits."""
    if not prompt.endswith("\n"):
        prompt += "\n"
    if not query.endswith("\n"):
        query += "\n"
    query = "问:" + query + "答:\n"
    # The limits apply to the wrapped query, as in the original pipeline.
    if len(query) < _MAX_QUERY_CHARS and len(response) < _MAX_RESPONSE_CHARS:
        return {
            "instruction": prompt,
            "input": query,
            "output": response,
        }
    return None


def _write_json(path, data):
    """Write *data* to *path* as indented UTF-8 JSON."""
    with open(path, mode="w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def main(input_path=INPUT_PATH, train_path=TRAIN_PATH, dev_path=DEV_PATH,
         shuffle=False):
    """Convert the raw records and write the train/dev JSON files.

    Args:
        input_path: File with one record dict literal per line.
        train_path: Destination for the first 90 % of the kept examples.
        dev_path: Destination for the remaining examples.
        shuffle: Shuffle the kept examples before splitting.  Off by default,
            so the split follows input order.

    Returns:
        ``(train_list, dev_list, jishu)`` where ``jishu`` counts section
        responses that contain a summary/conclusion heading although the
        requested title mentions neither.

    Raises:
        KeyError: If a record lacks ``prompt``, ``query`` or ``response``.
    """
    data_list = []
    jishu = 0
    with open(input_path, encoding="utf-8") as f:
        for line in tqdm(f):
            if not line.strip():
                continue
            data_dan = _parse_record(line)
            prompt = _unescape_newlines(str(data_dan["prompt"]))
            query = _unescape_newlines(data_dan["query"])
            response = _unescape_newlines(data_dan["response"])

            # Non-translation responses without any Chinese are dropped.
            if prompt not in _TRANSLATION_PROMPTS and not is_contains_chinese(response):
                print(data_dan)
                continue

            query = query.replace("生成方向", "研究方向")
            response = response.replace("生成方向", "研究方向")

            # Diagnostic only: flagged records are still kept.
            if prompt == _SECTION_PROMPT and _has_stray_conclusion(query, response):
                print(response)
                print("==========================================================================================")
                jishu += 1

            example = _format_example(prompt, query, response)
            if example is not None:
                data_list.append(example)

    if shuffle:
        random.shuffle(data_list)

    train_nums = int(len(data_list) * _TRAIN_FRACTION)
    print(train_nums)
    train_list = data_list[:train_nums]
    dev_list = data_list[train_nums:]

    _write_json(train_path, train_list)
    _write_json(dev_path, dev_list)

    print(jishu)
    return train_list, dev_list, jishu


if __name__ == "__main__":
    main()