# 数据处理代码,用于生成 ChatGPT 数据
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#
# 91 lines
# 2.5 KiB
#
# 2 years ago
import ast
import json
import os
import random

from tqdm import tqdm
# Counters keyed by length-range label; every bucket starts at zero.
data_tongji = dict.fromkeys(("0-600", "600-1500", "1500-"), 0)

# Accumulator list that starts out empty.
data_tongji_prompt = list()
def is_contains_chinese(strs: str) -> bool:
    """Return True if *strs* contains at least one Chinese character.

    A character counts as Chinese when it falls in the inclusive code-point
    range U+4E00..U+9FA5.

    Args:
        strs: Text to scan, examined one character at a time.

    Returns:
        True as soon as a character in the range is found; False otherwise,
        including for an empty string.
    """
    # any() stops at the first match, just like an early ``return True``.
    return any("\u4e00" <= _char <= "\u9fa5" for _char in strs)
# Build the training samples: parse one record per input line, drop records
# that fail the language or length checks, and keep the rest as JSON strings.

# Length limits, in characters; both comparisons below are strict.
MAX_QUERY_CHARS = 700
MAX_RESPONSE_CHARS = 1400

# Prompts whose responses are exempt from the "must contain Chinese" check
# (presumably because a translation is not expected to be Chinese -- confirm).
TRANSLATION_PROMPTS = ("翻译摘要#", "翻译关键词#")

data_list = []
with open("data/chatglm_paper_data_2_prompt.txt", encoding="utf-8") as f:
    for line in tqdm(f):
        # A blank line (e.g. a trailing newline at end of file) is not a
        # record; skip it instead of failing in the parser.
        if not line.strip():
            continue

        # SECURITY NOTE: each line is expected to hold a Python dict literal.
        # It is parsed with ast.literal_eval rather than eval() so that the
        # data file can only supply literals and can never execute code.
        data_dan = ast.literal_eval(line.strip())
        query = data_dan["query"]
        response = data_dan["response"]
        prompt = data_dan["prompt"]

        # For every non-translation prompt, a response without a single
        # Chinese character is treated as bad data: print it and skip it.
        if (
            prompt not in TRANSLATION_PROMPTS
            and not is_contains_chinese(response)
        ):
            print(data_dan)
            continue

        # Normalise wording.  Both phrases have the same length, so the
        # length checks below give the same result before or after this.
        query = query.replace("生成方向", "研究方向")
        response = response.replace("生成方向", "研究方向")

        if len(query) < MAX_QUERY_CHARS and len(response) < MAX_RESPONSE_CHARS:
            data_dan_dict = {
                "text": "Bob: " + query + "\n\nAlice: " + response,
            }
            data_list.append(json.dumps(data_dan_dict, ensure_ascii=False))
# Split the collected samples roughly 80/20 into train and dev sets and write
# each set to its own JSONL file (one sample per line).
total = len(data_list)
train_nums = int(total * 0.8)
# Kept for reference only: the dev slice below takes everything after the
# first train_nums samples rather than using this count.
dev_nums = int(total * 0.2)

# Shuffle in place so the split is random, then report the train size.
random.shuffle(data_list)
print(train_nums)
train_list, dev_list = data_list[:train_nums], data_list[train_nums:]

for out_path, rows in (
    ("./data/chatglm_train_3_chatrwkv.jsonl", train_list),
    ("./data/chatglm_dev_3_chatrwkv.jsonl", dev_list),
):
    with open(out_path, mode="w", encoding="utf-8") as f:
        f.writelines(row + "\n" for row in rows)