|
@ -36,7 +36,8 @@ from trl import AutoModelForCausalLMWithValueHead |
|
|
from .config import ( |
|
|
from .config import ( |
|
|
ModelArguments, |
|
|
ModelArguments, |
|
|
DataTrainingArguments, |
|
|
DataTrainingArguments, |
|
|
FinetuningArguments |
|
|
FinetuningArguments, |
|
|
|
|
|
GeneratingArguments |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
from .template import Template |
|
|
from .template import Template |
|
@ -54,7 +55,8 @@ check_min_version("4.29.1") |
|
|
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") |
|
|
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") |
|
|
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") |
|
|
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") |
|
|
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") |
|
|
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") |
|
|
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") |
|
|
require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__) |
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
@ -91,12 +93,10 @@ def _init_adapter( |
|
|
|
|
|
|
|
|
if model_args.checkpoint_dir is not None: |
|
|
if model_args.checkpoint_dir is not None: |
|
|
if finetuning_args.finetuning_type != "lora": |
|
|
if finetuning_args.finetuning_type != "lora": |
|
|
assert is_mergeable and len( |
|
|
assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." |
|
|
model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." |
|
|
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods |
|
|
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods |
|
|
|
|
|
else: |
|
|
else: |
|
|
assert is_mergeable or len( |
|
|
assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint." |
|
|
model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint." |
|
|
|
|
|
|
|
|
|
|
|
if finetuning_args.finetuning_type == "lora": |
|
|
if finetuning_args.finetuning_type == "lora": |
|
|
logger.info("Fine-tuning method: LoRA") |
|
|
logger.info("Fine-tuning method: LoRA") |
|
@ -106,8 +106,7 @@ def _init_adapter( |
|
|
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \ |
|
|
assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \ |
|
|
"The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead." |
|
|
"The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead." |
|
|
|
|
|
|
|
|
if (is_trainable and model_args.resume_lora_training) or ( |
|
|
if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights |
|
|
not is_mergeable): # continually train on the lora weights |
|
|
|
|
|
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] |
|
|
checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] |
|
|
else: |
|
|
else: |
|
|
checkpoints_to_merge = model_args.checkpoint_dir |
|
|
checkpoints_to_merge = model_args.checkpoint_dir |
|
@ -119,10 +118,10 @@ def _init_adapter( |
|
|
if len(checkpoints_to_merge) > 0: |
|
|
if len(checkpoints_to_merge) > 0: |
|
|
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) |
|
|
logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) |
|
|
|
|
|
|
|
|
if lastest_checkpoint is not None: # resume lora training or quantized inference |
|
|
if lastest_checkpoint is not None: # resume lora training or quantized inference |
|
|
model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable) |
|
|
model = PeftModel.from_pretrained(model, lastest_checkpoint, is_trainable=is_trainable) |
|
|
|
|
|
|
|
|
if is_trainable and lastest_checkpoint is None: # create new lora weights while training |
|
|
if is_trainable and lastest_checkpoint is None: # create new lora weights while training |
|
|
lora_config = LoraConfig( |
|
|
lora_config = LoraConfig( |
|
|
task_type=TaskType.CAUSAL_LM, |
|
|
task_type=TaskType.CAUSAL_LM, |
|
|
inference_mode=False, |
|
|
inference_mode=False, |
|
@ -170,7 +169,7 @@ def load_pretrained( |
|
|
padding_side="left", |
|
|
padding_side="left", |
|
|
**config_kwargs |
|
|
**config_kwargs |
|
|
) |
|
|
) |
|
|
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token |
|
|
tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token |
|
|
|
|
|
|
|
|
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) |
|
|
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) |
|
|
is_mergeable = True |
|
|
is_mergeable = True |
|
@ -186,11 +185,9 @@ def load_pretrained( |
|
|
) |
|
|
) |
|
|
elif model_args.quantization_bit == 4: |
|
|
elif model_args.quantization_bit == 4: |
|
|
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") |
|
|
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") |
|
|
require_version("transformers>=4.30.0.dev0", |
|
|
require_version("transformers>=4.30.1", "To fix: pip install transformers>=4.30.1") |
|
|
"To fix: pip install git+https://github.com/huggingface/transformers.git") |
|
|
require_version("accelerate>=0.20.3", "To fix: pip install accelerate>=0.20.3") |
|
|
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git") |
|
|
require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git") |
|
|
require_version("accelerate>=0.20.0.dev0", |
|
|
|
|
|
"To fix: pip install git+https://github.com/huggingface/accelerate.git") |
|
|
|
|
|
config_kwargs["load_in_4bit"] = True |
|
|
config_kwargs["load_in_4bit"] = True |
|
|
config_kwargs["quantization_config"] = BitsAndBytesConfig( |
|
|
config_kwargs["quantization_config"] = BitsAndBytesConfig( |
|
|
load_in_4bit=True, |
|
|
load_in_4bit=True, |
|
@ -201,10 +198,10 @@ def load_pretrained( |
|
|
else: |
|
|
else: |
|
|
raise NotImplementedError |
|
|
raise NotImplementedError |
|
|
is_mergeable = False |
|
|
is_mergeable = False |
|
|
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)} |
|
|
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} |
|
|
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) |
|
|
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) |
|
|
|
|
|
|
|
|
if not is_trainable: |
|
|
if not is_trainable: # `device_map=auto` should be used for inference only |
|
|
config_kwargs["device_map"] = "auto" |
|
|
config_kwargs["device_map"] = "auto" |
|
|
|
|
|
|
|
|
# Load and prepare pretrained models (without valuehead). |
|
|
# Load and prepare pretrained models (without valuehead). |
|
@ -218,24 +215,26 @@ def load_pretrained( |
|
|
model = prepare_model_for_training(model) if is_trainable else model |
|
|
model = prepare_model_for_training(model) if is_trainable else model |
|
|
model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) |
|
|
model = _init_adapter(model, model_args, finetuning_args, is_trainable, is_mergeable) |
|
|
|
|
|
|
|
|
if stage == "rm" or stage == "ppo": # add value head |
|
|
if stage == "rm" or stage == "ppo": # add value head |
|
|
model = AutoModelForCausalLMWithValueHead.from_pretrained(model) |
|
|
model = AutoModelForCausalLMWithValueHead.from_pretrained(model) |
|
|
|
|
|
|
|
|
if stage == "ppo": # load reward model |
|
|
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model |
|
|
|
|
|
load_valuehead_params(model, model_args.checkpoint_dir[0]) |
|
|
|
|
|
model.v_head.load_state_dict({ |
|
|
|
|
|
"summary.weight": getattr(model, "reward_head_weight"), |
|
|
|
|
|
"summary.bias": getattr(model, "reward_head_bias") |
|
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
if stage == "ppo": # load reward model |
|
|
assert is_trainable, "PPO stage cannot be performed at evaluation." |
|
|
assert is_trainable, "PPO stage cannot be performed at evaluation." |
|
|
assert model_args.reward_model is not None, "Reward model is necessary for PPO training." |
|
|
assert model_args.reward_model is not None, "Reward model is necessary for PPO training." |
|
|
logger.info("Load reward model from {}".format(model_args.reward_model)) |
|
|
logger.info("Load reward model from {}".format(model_args.reward_model)) |
|
|
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) |
|
|
model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) |
|
|
load_valuehead_params(model, model_args.reward_model) |
|
|
load_valuehead_params(model, model_args.reward_model) |
|
|
|
|
|
|
|
|
# Set the parameter _is_int8_training_enabled for the AutoModelForCausalLMWithValueHead model |
|
|
|
|
|
# To meet the compliance requirements of the transformers library |
|
|
|
|
|
if model_args.quantization_bit is not None: |
|
|
|
|
|
model._is_int8_training_enabled = True |
|
|
|
|
|
|
|
|
|
|
|
if not is_trainable: |
|
|
if not is_trainable: |
|
|
model.requires_grad_(False) # fix all model params |
|
|
model.requires_grad_(False) # fix all model params |
|
|
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16 |
|
|
model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16 |
|
|
|
|
|
|
|
|
print_trainable_params(model) |
|
|
print_trainable_params(model) |
|
|
|
|
|
|
|
@ -245,11 +244,11 @@ def load_pretrained( |
|
|
def prepare_args( |
|
|
def prepare_args( |
|
|
stage: Literal["pt", "sft", "rm", "ppo"] |
|
|
stage: Literal["pt", "sft", "rm", "ppo"] |
|
|
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]: |
|
|
) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]: |
|
|
|
|
|
|
|
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments)) |
|
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments)) |
|
|
|
|
|
|
|
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. |
|
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. |
|
|
model_args, data_args, training_args, finetuning_args = parser.parse_json_file( |
|
|
model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) |
|
|
json_file=os.path.abspath(sys.argv[1])) |
|
|
|
|
|
else: |
|
|
else: |
|
|
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses() |
|
|
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses() |
|
|
|
|
|
|
|
@ -290,7 +289,7 @@ def prepare_args( |
|
|
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.") |
|
|
logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.") |
|
|
training_args.ddp_find_unused_parameters = False |
|
|
training_args.ddp_find_unused_parameters = False |
|
|
|
|
|
|
|
|
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning |
|
|
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning |
|
|
|
|
|
|
|
|
if model_args.quantization_bit is not None: |
|
|
if model_args.quantization_bit is not None: |
|
|
if training_args.fp16: |
|
|
if training_args.fp16: |
|
@ -313,13 +312,14 @@ def prepare_args( |
|
|
return model_args, data_args, training_args, finetuning_args |
|
|
return model_args, data_args, training_args, finetuning_args |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]: |
|
|
def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments]: |
|
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments)) |
|
|
|
|
|
|
|
|
|
|
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. |
|
|
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments, GeneratingArguments)) |
|
|
model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) |
|
|
|
|
|
|
|
|
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. |
|
|
|
|
|
model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) |
|
|
else: |
|
|
else: |
|
|
model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses() |
|
|
model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses() |
|
|
|
|
|
|
|
|
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": |
|
|
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": |
|
|
raise ValueError("Quantization is only compatible with the LoRA method.") |
|
|
raise ValueError("Quantization is only compatible with the LoRA method.") |
|
@ -327,13 +327,14 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, Finetun |
|
|
if data_args.prompt_template == "alpaca": |
|
|
if data_args.prompt_template == "alpaca": |
|
|
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.") |
|
|
logger.warning("Please specify `prompt_template` if you are using other pre-trained models.") |
|
|
|
|
|
|
|
|
return model_args, data_args, finetuning_args |
|
|
return model_args, data_args, finetuning_args, generating_args |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_data( |
|
|
def prepare_data( |
|
|
model_args: ModelArguments, |
|
|
model_args: ModelArguments, |
|
|
data_args: DataTrainingArguments |
|
|
data_args: DataTrainingArguments |
|
|
) -> Dataset: |
|
|
) -> Dataset: |
|
|
|
|
|
|
|
|
def checksum(file_path, hash): |
|
|
def checksum(file_path, hash): |
|
|
with open(file_path, "rb") as datafile: |
|
|
with open(file_path, "rb") as datafile: |
|
|
binary_data = datafile.read() |
|
|
binary_data = datafile.read() |
|
@ -342,7 +343,7 @@ def prepare_data( |
|
|
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path)) |
|
|
logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path)) |
|
|
|
|
|
|
|
|
max_samples = data_args.max_samples |
|
|
max_samples = data_args.max_samples |
|
|
all_datasets: List[Dataset] = [] # support multiple datasets |
|
|
all_datasets: List[Dataset] = [] # support multiple datasets |
|
|
|
|
|
|
|
|
for dataset_attr in data_args.dataset_list: |
|
|
for dataset_attr in data_args.dataset_list: |
|
|
|
|
|
|
|
@ -358,10 +359,12 @@ def prepare_data( |
|
|
elif dataset_attr.load_from == "file": |
|
|
elif dataset_attr.load_from == "file": |
|
|
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) |
|
|
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) |
|
|
extension = dataset_attr.file_name.split(".")[-1] |
|
|
extension = dataset_attr.file_name.split(".")[-1] |
|
|
|
|
|
|
|
|
if dataset_attr.file_sha1 is not None: |
|
|
if dataset_attr.file_sha1 is not None: |
|
|
checksum(data_file, dataset_attr.file_sha1) |
|
|
checksum(data_file, dataset_attr.file_sha1) |
|
|
else: |
|
|
else: |
|
|
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") |
|
|
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") |
|
|
|
|
|
|
|
|
raw_datasets = load_dataset( |
|
|
raw_datasets = load_dataset( |
|
|
extension if extension in ["csv", "json"] else "text", |
|
|
extension if extension in ["csv", "json"] else "text", |
|
|
data_files=data_file, |
|
|
data_files=data_file, |
|
@ -383,11 +386,11 @@ def prepare_data( |
|
|
("query_column", "query"), |
|
|
("query_column", "query"), |
|
|
("response_column", "response"), |
|
|
("response_column", "response"), |
|
|
("history_column", "history") |
|
|
("history_column", "history") |
|
|
]: # every dataset will have 4 columns same as each other |
|
|
]: # every dataset will have 4 columns same as each other |
|
|
if getattr(dataset_attr, column_name) != target_name: |
|
|
if getattr(dataset_attr, column_name) != target_name: |
|
|
if getattr(dataset_attr, column_name): |
|
|
if getattr(dataset_attr, column_name): |
|
|
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name) |
|
|
dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name) |
|
|
else: # None or empty string |
|
|
else: # None or empty string |
|
|
dataset = dataset.add_column(target_name, dummy_data) |
|
|
dataset = dataset.add_column(target_name, dummy_data) |
|
|
all_datasets.append(dataset) |
|
|
all_datasets.append(dataset) |
|
|
|
|
|
|
|
@ -406,6 +409,7 @@ def preprocess_data( |
|
|
training_args: Seq2SeqTrainingArguments, |
|
|
training_args: Seq2SeqTrainingArguments, |
|
|
stage: Literal["pt", "sft", "rm", "ppo"] |
|
|
stage: Literal["pt", "sft", "rm", "ppo"] |
|
|
) -> Dataset: |
|
|
) -> Dataset: |
|
|
|
|
|
|
|
|
column_names = list(dataset.column_names) |
|
|
column_names = list(dataset.column_names) |
|
|
prefix = data_args.source_prefix if data_args.source_prefix is not None else "" |
|
|
prefix = data_args.source_prefix if data_args.source_prefix is not None else "" |
|
|
prompt_template = Template(data_args.prompt_template) |
|
|
prompt_template = Template(data_args.prompt_template) |
|
@ -442,9 +446,9 @@ def preprocess_data( |
|
|
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) |
|
|
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) |
|
|
target_ids = tokenizer.encode(text=answer, add_special_tokens=False) |
|
|
target_ids = tokenizer.encode(text=answer, add_special_tokens=False) |
|
|
|
|
|
|
|
|
if len(source_ids) > data_args.max_source_length - 1: # bos token |
|
|
if len(source_ids) > data_args.max_source_length - 1: # bos token |
|
|
source_ids = source_ids[:data_args.max_source_length - 1] |
|
|
source_ids = source_ids[:data_args.max_source_length - 1] |
|
|
if len(target_ids) > data_args.max_target_length - 1: # eos token |
|
|
if len(target_ids) > data_args.max_target_length - 1: # eos token |
|
|
target_ids = target_ids[:data_args.max_target_length - 1] |
|
|
target_ids = target_ids[:data_args.max_target_length - 1] |
|
|
|
|
|
|
|
|
input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id] |
|
|
input_ids = source_ids + [tokenizer.bos_token_id] + target_ids + [tokenizer.eos_token_id] |
|
@ -461,9 +465,9 @@ def preprocess_data( |
|
|
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) |
|
|
source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) |
|
|
target_ids = tokenizer.encode(text=answer, add_special_tokens=False) |
|
|
target_ids = tokenizer.encode(text=answer, add_special_tokens=False) |
|
|
|
|
|
|
|
|
if len(source_ids) > data_args.max_source_length - 1: # bos token |
|
|
if len(source_ids) > data_args.max_source_length - 1: # bos token |
|
|
source_ids = source_ids[:data_args.max_source_length - 1] |
|
|
source_ids = source_ids[:data_args.max_source_length - 1] |
|
|
if len(target_ids) > data_args.max_target_length - 1: # bos token |
|
|
if len(target_ids) > data_args.max_target_length - 1: # bos token |
|
|
target_ids = target_ids[:data_args.max_target_length - 1] |
|
|
target_ids = target_ids[:data_args.max_target_length - 1] |
|
|
|
|
|
|
|
|
input_ids = source_ids + [tokenizer.bos_token_id] |
|
|
input_ids = source_ids + [tokenizer.bos_token_id] |
|
@ -481,11 +485,11 @@ def preprocess_data( |
|
|
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) |
|
|
accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) |
|
|
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) |
|
|
reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) |
|
|
|
|
|
|
|
|
if len(source_ids) > data_args.max_source_length - 1: # bos token |
|
|
if len(source_ids) > data_args.max_source_length - 1: # bos token |
|
|
source_ids = source_ids[:data_args.max_source_length - 1] |
|
|
source_ids = source_ids[:data_args.max_source_length - 1] |
|
|
if len(accept_ids) > data_args.max_target_length - 1: # eos token |
|
|
if len(accept_ids) > data_args.max_target_length - 1: # eos token |
|
|
accept_ids = accept_ids[:data_args.max_target_length - 1] |
|
|
accept_ids = accept_ids[:data_args.max_target_length - 1] |
|
|
if len(reject_ids) > data_args.max_target_length - 1: # eos token |
|
|
if len(reject_ids) > data_args.max_target_length - 1: # eos token |
|
|
reject_ids = reject_ids[:data_args.max_target_length - 1] |
|
|
reject_ids = reject_ids[:data_args.max_target_length - 1] |
|
|
|
|
|
|
|
|
accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id] |
|
|
accept_ids = source_ids + [tokenizer.bos_token_id] + accept_ids + [tokenizer.eos_token_id] |
|
|