diff --git a/src/utils/common.py b/src/utils/common.py
index a676b66..2f77cba 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -101,10 +101,10 @@ def _init_adapter(
         logger.info("Fine-tuning method: LoRA")
         lastest_checkpoint = None
 
-        assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
-            "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
-
         if model_args.checkpoint_dir is not None:
+            assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
+                "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
+
             if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
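
For context, the patch moves the LoRA-config assertion inside the `model_args.checkpoint_dir is not None` guard, so starting a fresh LoRA run with no checkpoint directory no longer evaluates `os.path.join(None, ...)`. Below is a minimal sketch of the resulting control flow; the `check_lora_checkpoint` helper name and the `CONFIG_NAME` value are assumptions for illustration, not code from the repository.

    import os

    # Assumption: CONFIG_NAME refers to the PEFT adapter config filename.
    CONFIG_NAME = "adapter_config.json"

    def check_lora_checkpoint(checkpoint_dir):
        """Post-patch behaviour: only validate the checkpoint when one was given."""
        if checkpoint_dir is not None:
            # The assertion now runs only inside the guard, so checkpoint_dir[0]
            # is never indexed when no checkpoint directory was supplied.
            assert os.path.exists(os.path.join(checkpoint_dir[0], CONFIG_NAME)), \
                "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

    check_lora_checkpoint(None)                  # no-op: no checkpoint to validate
    check_lora_checkpoint(["saves/lora-ckpt"])   # asserts the adapter config exists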