From 875e8e23498f6933d657ad154b53611310327e3e Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Tue, 13 Jun 2023 11:13:06 +0800
Subject: [PATCH] fix loading valuehead

---
 src/utils/common.py | 15 ++++++++-------
 src/utils/other.py  | 14 ++++++++++----
 2 files changed, 18 insertions(+), 11 deletions(-)

diff --git a/src/utils/common.py b/src/utils/common.py
index b17e35d..6cd43aa 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -94,7 +94,7 @@ def _init_adapter(
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
             assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-            load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
+            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
         else:
             assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
 
@@ -217,18 +217,19 @@ def load_pretrained(
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
 
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
-            load_valuehead_params(model, model_args.checkpoint_dir[0])
-            model.v_head.load_state_dict({
-                "summary.weight": getattr(model, "reward_head_weight"),
-                "summary.bias": getattr(model, "reward_head_bias")
-            })
+            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
+            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
+                model.v_head.load_state_dict({
+                    "summary.weight": getattr(model, "reward_head_weight"),
+                    "summary.bias": getattr(model, "reward_head_bias")
+                })
 
         if stage == "ppo": # load reward model
             assert is_trainable, "PPO stage cannot be performed at evaluation."
             assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
             logger.info("Load reward model from {}".format(model_args.reward_model))
             model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
-            load_valuehead_params(model, model_args.reward_model)
+            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
 
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
diff --git a/src/utils/other.py b/src/utils/other.py
index 5675e3f..838b617 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -126,21 +126,27 @@ def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get sta
     return filtered_state_dict
 
 
-def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
+def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
     weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
-    assert os.path.exists(weights_file), f"Provided path ({checkpoint_dir}) does not contain the pretrained weights."
+    if not os.path.exists(weights_file):
+        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
+        return False
     model_state_dict = torch.load(weights_file, map_location="cpu")
     model.load_state_dict(model_state_dict, strict=False) # skip missing keys
+    return True
 
 
-def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
+def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
     valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
-    assert os.path.exists(valuehead_file), f"Provided path ({checkpoint_dir}) does not contain the valuehead weights."
+    if not os.path.exists(valuehead_file):
+        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
+        return False
     valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
     model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
     model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
     model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
+    return True
 
 
 def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
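Note: below is a minimal sketch of how the boolean-returning loaders introduced by this patch are meant to be consumed by a caller. The helper name restore_reward_model, the import path, and the model/checkpoint_dirs arguments are hypothetical illustrations; only load_trainable_params and load_valuehead_params come from the patch itself.

    from utils.other import load_trainable_params, load_valuehead_params

    def restore_reward_model(model, checkpoint_dirs):
        # Backbone weights are mandatory: fail loudly when the checkpoint is
        # unusable, mirroring the assert this patch adds in _init_adapter.
        assert load_trainable_params(model, checkpoint_dirs[0]), \
            "Model checkpoint is not correctly loaded."

        # The valuehead is optional: apply it only when the last checkpoint
        # actually contains one, mirroring the new logic in load_pretrained.
        # A missing valuehead now logs a warning instead of raising.
        if load_valuehead_params(model, checkpoint_dirs[-1]):
            model.v_head.load_state_dict({
                "summary.weight": getattr(model, "reward_head_weight"),
                "summary.bias": getattr(model, "reward_head_bias"),
            })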