From ce71cc8b6db5d13b87b7d0302f4176c5c76ac4b2 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 29 May 2023 09:42:29 +0800
Subject: [PATCH] tiny fix

---
 src/train_ppo.py    |  2 +-
 src/train_rm.py     |  2 +-
 src/train_sft.py    |  2 +-
 src/utils/common.py |  2 +-
 src/utils/config.py | 16 ++++++++--------
 5 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/src/train_ppo.py b/src/train_ppo.py
index f6e57c0..c4bd8bc 100644
--- a/src/train_ppo.py
+++ b/src/train_ppo.py
@@ -69,7 +69,7 @@ def main():
     ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
     ppo_trainer.save_model()
     ppo_trainer.save_state() # must be after save_model
-    if ppo_trainer.is_world_process_zero() and finetuning_args.plot_loss:
+    if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
         plot_loss(training_args, keys=["loss", "reward"])
 
 
diff --git a/src/train_rm.py b/src/train_rm.py
index ecd7f71..812a643 100644
--- a/src/train_rm.py
+++ b/src/train_rm.py
@@ -55,7 +55,7 @@ def main():
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         trainer.save_model()
-        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+        if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args, keys=["loss", "eval_loss"])
 
     # Evaluation
diff --git a/src/train_sft.py b/src/train_sft.py
index 3bc0f85..16c91f0 100644
--- a/src/train_sft.py
+++ b/src/train_sft.py
@@ -71,7 +71,7 @@ def main():
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
         trainer.save_model()
-        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+        if trainer.is_world_process_zero() and model_args.plot_loss:
             plot_loss(training_args, keys=["loss", "eval_loss"])
 
     # Evaluation
diff --git a/src/utils/common.py b/src/utils/common.py
index 2798124..a7b1447 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -91,7 +91,7 @@ def init_adapter(
         lastest_checkpoint = None
 
         if model_args.checkpoint_dir is not None:
-            if is_trainable and finetuning_args.resume_lora_training: # continually train on the lora weights
+            if is_trainable and model_args.resume_lora_training: # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir
diff --git a/src/utils/config.py b/src/utils/config.py
index fe35a6e..b7012a6 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -51,6 +51,14 @@ class ModelArguments:
         default=None,
         metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
     )
+    resume_lora_training: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
+    )
+    plot_loss: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
+    )
 
     def __post_init__(self):
         if self.checkpoint_dir is not None: # support merging lora weights
@@ -173,14 +181,6 @@ class FinetuningArguments:
         default="q_proj,v_proj",
         metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
     )
-    resume_lora_training: Optional[bool] = field(
-        default=True,
-        metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
-    )
-    plot_loss: Optional[bool] = field(
-        default=False,
-        metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
-    )
 
     def __post_init__(self):
         if isinstance(self.lora_target, str):
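
For reference, a minimal sketch (not part of the patch) of how the relocated options surface at the updated call sites, assuming transformers' HfArgumentParser and a trimmed-down ModelArguments that keeps only the two moved fields; after parsing, the training scripts read them as model_args.resume_lora_training and model_args.plot_loss, matching the hunks above:

    # sketch.py -- illustrative only; file name and trimmed dataclass are assumptions
    from dataclasses import dataclass, field
    from typing import Optional

    from transformers import HfArgumentParser


    @dataclass
    class ModelArguments:
        # Only the two fields moved from FinetuningArguments in this patch.
        resume_lora_training: Optional[bool] = field(
            default=True,
            metadata={"help": "Whether to resume training from the last LoRA weights "
                              "or create new weights after merging them."}
        )
        plot_loss: Optional[bool] = field(
            default=False,
            metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
        )


    if __name__ == "__main__":
        # e.g. python sketch.py --plot_loss true --resume_lora_training false
        (model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses()
        if model_args.plot_loss:
            # stands in for plot_loss(training_args, keys=["loss", "eval_loss"])
            print("plotting enabled:", model_args.resume_lora_training, model_args.plot_loss)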