
add BOS token in pre-training

Branch: main
Author: hiyouga, 2 years ago
Parent commit: d668f8b501

1 changed file, 9 changes: src/utils/common.py

src/utils/common.py

@@ -430,15 +430,16 @@ def preprocess_data(
         yield dialog
 
     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `X1 X2 X3 ...` (without [BOS] and [EOS])
+        # build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
         text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
+        block_size = data_args.max_source_length - 1
         # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // data_args.max_source_length) * data_args.max_source_length
+        total_length = (total_length // block_size) * block_size
         # split by chunks of max_source_length
-        result = [concatenated_ids[i: i + data_args.max_source_length] for i in
-                  range(0, total_length, data_args.max_source_length)]
+        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
+                  for i in range(0, total_length, block_size)]
         return {
             "input_ids": result,
             "labels": result.copy()
