@@ -430,15 +430,16 @@ def preprocess_data(
             yield dialog
 
     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `X1 X2 X3 ...` (without [BOS] and [EOS])
+        # build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
         text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
+        block_size = data_args.max_source_length - 1
         # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // data_args.max_source_length) * data_args.max_source_length
+        total_length = (total_length // block_size) * block_size
         # split by chunks of max_source_length
-        result = [concatenated_ids[i: i + data_args.max_source_length] for i in
-                  range(0, total_length, data_args.max_source_length)]
+        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
+                  for i in range(0, total_length, block_size)]
         return {
             "input_ids": result,
             "labels": result.copy()
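For reference, a minimal self-contained sketch of the new grouping behaviour; `group_pretrain_texts` and its arguments are illustrative stand-ins for the `tokenizer`, `data_args`, and `examples` objects used in the actual function, not part of the repository's API:

```python
from itertools import chain

def group_pretrain_texts(text_ids, bos_token_id, max_source_length):
    # Concatenate all tokenized documents into one flat token stream.
    concatenated_ids = list(chain(*text_ids))
    # Reserve one position per chunk for the [BOS] token.
    block_size = max_source_length - 1
    # Drop the trailing remainder so every chunk holds exactly block_size tokens.
    total_length = (len(concatenated_ids) // block_size) * block_size
    # Prepend [BOS] to each chunk, giving sequences of max_source_length tokens.
    return [[bos_token_id] + concatenated_ids[i: i + block_size]
            for i in range(0, total_length, block_size)]

# Toy example: 10 tokens, max_source_length=4 -> block_size=3 -> 3 chunks, 1 token dropped.
chunks = group_pretrain_texts([[1, 2, 3, 4], [5, 6, 7, 8, 9, 10]],
                              bos_token_id=0, max_source_length=4)
print(chunks)  # [[0, 1, 2, 3], [0, 4, 5, 6], [0, 7, 8, 9]]
```

Setting `block_size = max_source_length - 1` leaves room for the prepended [BOS], so each resulting block is still exactly `max_source_length` tokens long.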