You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
3.4 KiB
108 lines
3.4 KiB
![]()
2 years ago
|
import os
|
||
|
import re
|
||
|
import random
|
||
|
import json
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
# Pattern for a numbered "总结" (summary) / "结论" (conclusion) heading: one
# Chinese numeral or ASCII digit 1-9, up to two arbitrary characters, then the
# keyword. Each alternative is written on its own line; the adjacent literals
# are concatenated by the compiler into a single pattern string.
RE_CHINA_NUMS = (
    "[一二三四五六七八九].?.?总结"
    "|[1-9].?.?总结"
    "|[一二三四五六七八九].?.?结论"
    "|[1-9].?.?结论"
)

# Pattern that captures the title quoted inside a "please complete the content
# of sub-title / main-title ..." request. It has two capture groups (one per
# alternative), so only one of them is non-empty for any given match.
RE_CHINA_TITLE = (
    "请把其中的小标题“(.*?)”的内容补充完整"
    "|请把其中的大标题“(.*?)”的内容补充完整"
)

# Counters keyed by range label (presumably text-length buckets -- TODO
# confirm). All start at zero; no active code visible in this script updates
# them.
data_tongji = dict.fromkeys(("0-600", "600-1500", "1500-"), 0)

# Companion list for per-sample statistics; it starts empty.
data_tongji_prompt = []
|
||
|
|
||
|
def is_contains_chinese(strs):
    """Return True if any character of *strs* lies in U+4E00..U+9FA5.

    The range is inclusive on both ends and covers the commonly used part of
    the CJK Unified Ideographs block. An empty input yields False.
    """
    return any('\u4e00' <= ch <= '\u9fa5' for ch in strs)
|
||
|
|
||
|
# Converted samples, each a dict with "instruction", "input" and "output".
data_list = []

# Number of sub-section responses that contain a numbered summary/conclusion
# heading although the requested title is not a summary/conclusion itself.
jishu = 0

# Prompts whose responses are translations; they are exempt from the
# "response must contain Chinese" filter.
_TRANSLATION_PROMPTS = ("翻译摘要#", "翻译关键词#")

# Samples are kept only when both lengths (in characters) stay below these
# limits. The query length is measured after the "问:"/"答:" wrapper is added.
_MAX_QUERY_CHARS = 700
_MAX_RESPONSE_CHARS = 1400

with open("data/chatglm_paper_data_2_prompt.txt", encoding="utf-8") as f:
    for line in tqdm(f):
        # A blank line (e.g. a trailing newline at end of file) carries no
        # record and would make eval() raise SyntaxError.
        if not line.strip():
            continue

        # SECURITY NOTE(review): eval() executes whatever expression the line
        # contains, so this must only be run on trusted files. If every line
        # is a plain dict literal, ast.literal_eval is a safe replacement.
        data_dan = eval(line)

        # The stored text uses a literal backslash-n; turn it into newlines.
        prompt = str(data_dan["prompt"]).replace("\\n", "\n")
        query = data_dan["query"].replace("\\n", "\n")
        response = data_dan["response"].replace("\\n", "\n")

        # Non-translation responses must contain at least one Chinese
        # character; otherwise the record is printed and dropped.
        if prompt not in _TRANSLATION_PROMPTS and not is_contains_chinese(response):
            print(data_dan)
            continue

        # str.replace is a no-op when the substring is absent, so no
        # membership test is needed first.
        query = query.replace("生成方向", "研究方向")
        response = response.replace("生成方向", "研究方向")

        if prompt == "生成论文小标题内容#":
            # RE_CHINA_TITLE has one capture group per alternative, so exactly
            # one group is set on a match. Pick that one as the title string
            # and test substrings on it (testing against the findall tuple
            # would compare whole elements, not substrings).
            title_match = re.search(RE_CHINA_TITLE, query)
            if title_match is not None:
                title = title_match.group(1) or title_match.group(2) or ""
                # Only titles that mention neither "总结" nor "结论" are
                # checked: a summary heading in their response is unexpected.
                if "总结" not in title and "结论" not in title:
                    if re.search(RE_CHINA_NUMS, response):
                        print(response)
                        print("==========================================================================================")
                        jishu += 1

        # endswith() is safe on empty strings, unlike indexing with [-1].
        if not prompt.endswith("\n"):
            prompt += "\n"
        if not query.endswith("\n"):
            query += "\n"
        query = "问:" + query + "答:\n"

        if len(query) < _MAX_QUERY_CHARS and len(response) < _MAX_RESPONSE_CHARS:
            data_list.append({
                "instruction": prompt,
                "input": query,
                "output": response,
            })
|
||
|
|
||
|
|
||
|
# for i in data_tongji_prompt:
|
||
|
# print(i)
|
||
|
#
|
||
|
|
||
|
# random.shuffle(data_list)
|
||
|
#
|
||
|
# Split the collected samples into a train set (first 90% after shuffling,
# rounded down) and a dev set (everything that remains), then write both as
# pretty-printed UTF-8 JSON.
train_nums = int(len(data_list) * 0.9)

# The shuffle is unseeded, so the split differs from run to run.
random.shuffle(data_list)
print(train_nums)
train_list = data_list[:train_nums]
# The dev set is the remainder, so no sample is lost to rounding.
dev_list = data_list[train_nums:]

with open("./data/chatglm_train_3_prompt_llama.json", mode="w", encoding="utf-8") as f:
    json.dump(train_list, f, ensure_ascii=False, indent=2)

with open("./data/chatglm_dev_3_prompt_llama.json", mode="w", encoding="utf-8") as f:
    json.dump(dev_list, f, ensure_ascii=False, indent=2)

# Report how many suspicious sub-section responses were seen while loading.
print(jishu)
|