From d668f8b501c367276ef4be372f2eb1753a1b7e86 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Thu, 15 Jun 2023 01:46:17 +0800
Subject: [PATCH] add BOS token in pre-training

---
 src/utils/common.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/utils/common.py b/src/utils/common.py
index e63c165..38ae830 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -430,15 +430,16 @@ def preprocess_data(
                 yield dialog
 
     def preprocess_pretrain_dataset(examples):
-        # build grouped texts with format `X1 X2 X3 ...` (without [BOS] and [EOS])
+        # build grouped texts with format `[BOS] X1 X2 X3 ...` (without [EOS])
         text_ids = tokenizer(examples["prompt"], add_special_tokens=False)["input_ids"]
         concatenated_ids = list(chain(*text_ids))
         total_length = len(concatenated_ids)
+        block_size = data_args.max_source_length - 1
         # we drop the small remainder, and if the total_length < block_size, we exclude this batch
-        total_length = (total_length // data_args.max_source_length) * data_args.max_source_length
+        total_length = (total_length // block_size) * block_size
         # split by chunks of max_source_length
-        result = [concatenated_ids[i: i + data_args.max_source_length] for i in
-                  range(0, total_length, data_args.max_source_length)]
+        result = [[tokenizer.bos_token_id] + concatenated_ids[i: i + block_size]
+                  for i in range(0, total_length, block_size)]
         return {
             "input_ids": result,
             "labels": result.copy()
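
Note (illustration only, not part of the patch): below is a minimal standalone sketch of the grouping logic this commit introduces, assuming a tokenizer-style bos_token_id and a hypothetical max_source_length of 8; the helper name group_with_bos is made up for this example and does not exist in the repository.

from itertools import chain

def group_with_bos(text_ids, bos_token_id, max_source_length):
    # Reserve one position per block so that [BOS] + block stays within max_source_length.
    block_size = max_source_length - 1
    concatenated_ids = list(chain(*text_ids))
    # Drop the small remainder; if the total length is below block_size, nothing is produced.
    total_length = (len(concatenated_ids) // block_size) * block_size
    return [[bos_token_id] + concatenated_ids[i: i + block_size]
            for i in range(0, total_length, block_size)]

# 15 content tokens packed into blocks of length 8 (1 BOS + 7 content tokens);
# the leftover token 24 is dropped, mirroring the remainder handling in the patch.
blocks = group_with_bos(
    [[10, 11, 12, 13, 14], [15, 16, 17, 18, 19, 20], [21, 22, 23, 24]],
    bos_token_id=1,
    max_source_length=8,
)
print(blocks)  # [[1, 10, 11, 12, 13, 14, 15, 16], [1, 17, 18, 19, 20, 21, 22, 23]]

Reserving one position per block (block_size = max_source_length - 1) keeps each packed example at exactly max_source_length tokens once the BOS id is prepended.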