
fix loading valuehead

Branch: main
Committed by hiyouga, 2 years ago
Commit: 875e8e2349
  1. src/utils/common.py (15 changed lines)
  2. src/utils/other.py (14 changed lines)

src/utils/common.py (15 changed lines)

@@ -94,7 +94,7 @@ def _init_adapter(
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
             assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-            load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
+            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
         else:
             assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."

@@ -217,18 +217,19 @@ def load_pretrained(
     model = AutoModelForCausalLMWithValueHead.from_pretrained(model)

     if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
-        load_valuehead_params(model, model_args.checkpoint_dir[0])
-        model.v_head.load_state_dict({
-            "summary.weight": getattr(model, "reward_head_weight"),
-            "summary.bias": getattr(model, "reward_head_bias")
-        })
+        logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
+        if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
+            model.v_head.load_state_dict({
+                "summary.weight": getattr(model, "reward_head_weight"),
+                "summary.bias": getattr(model, "reward_head_bias")
+            })

     if stage == "ppo": # load reward model
         assert is_trainable, "PPO stage cannot be performed at evaluation."
         assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
         logger.info("Load reward model from {}".format(model_args.reward_model))
         model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
-        load_valuehead_params(model, model_args.reward_model)
+        assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."

     if not is_trainable:
         model.requires_grad_(False) # fix all model params
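
Note on the new contract: both loaders now return a bool instead of raising on a missing file, so each call site picks its own failure policy. Where the weights are mandatory (non-peft checkpoints, the PPO reward model) the result is wrapped in an assert; on the "rm" evaluation path a plain if skips the v_head restore instead of crashing. A minimal sketch of the two call patterns (the checkpoint paths are made-up examples, not repo fixtures):

# Mandatory load: a False return becomes a hard failure at the call site.
assert load_trainable_params(model, "ckpt/sft"), "Model checkpoint is not correctly loaded."

# Optional load: a missing valuehead file only logs a warning inside the
# loader, and the v_head restore below is skipped rather than crashing.
if load_valuehead_params(model, "ckpt/rm"):
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "reward_head_weight"),
        "summary.bias": getattr(model, "reward_head_bias")
    })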

src/utils/other.py (14 changed lines)

@@ -126,21 +126,27 @@ def get_state_dict(model: torch.nn.Module) -> Dict[str, torch.Tensor]: # get state dict containing trainable parameters
     return filtered_state_dict


-def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
+def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
     weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
-    assert os.path.exists(weights_file), f"Provided path ({checkpoint_dir}) does not contain the pretrained weights."
+    if not os.path.exists(weights_file):
+        logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
+        return False
     model_state_dict = torch.load(weights_file, map_location="cpu")
     model.load_state_dict(model_state_dict, strict=False) # skip missing keys
+    return True


-def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> None:
+def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
     valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
-    assert os.path.exists(valuehead_file), f"Provided path ({checkpoint_dir}) does not contain the valuehead weights."
+    if not os.path.exists(valuehead_file):
+        logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
+        return False
     valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
     model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
     model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
     model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
     model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
+    return True


 def smooth(scalars: List[float], weight: Optional[float] = 0.9) -> List[float]:
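
For context on the four register_buffer calls: load_valuehead_params stashes the checkpoint's value head on the model as reward_head_weight/reward_head_bias, alongside zero-initialized default_head_* placeholders, so downstream code can switch v_head between the reward head and a default head without re-reading the file. A sketch of how such a swap could look during PPO; the helper name and the stash-before-swap step are illustrative assumptions, not code from this commit:

from typing import Literal

def swap_valuehead(model, target: Literal["default", "reward"]) -> None:
    # Hypothetical helper: flips v_head between the loaded reward head
    # (for scoring responses) and the default head (for value estimation).
    if target == "reward":  # stash the current default head before overwriting
        state = model.v_head.state_dict()
        setattr(model, "default_head_weight", state["summary.weight"])
        setattr(model, "default_head_bias", state["summary.bias"])
    model.v_head.load_state_dict({
        "summary.weight": getattr(model, "{}_head_weight".format(target)),
        "summary.bias": getattr(model, "{}_head_bias".format(target))
    })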
