|
|
@@ -94,7 +94,7 @@ def _init_adapter(
 
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
             assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-            load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
+            assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
         else:
             assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
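
The rewritten line only makes sense if load_trainable_params reports success as a boolean instead of failing silently. A minimal sketch of that contract, assuming the trainable weights are stored as a single torch state dict inside the checkpoint directory (the file name below is illustrative, not the project's actual constant):

import os
import torch

def load_trainable_params(model, checkpoint_dir: str) -> bool:
    # Illustrative file name; the real project defines its own constant for this.
    weights_file = os.path.join(checkpoint_dir, "trainable_params.bin")
    if not os.path.exists(weights_file):
        return False  # let the caller assert or warn instead of crashing later
    state_dict = torch.load(weights_file, map_location="cpu")
    # Only the trainable subset of parameters is saved, so strict=False.
    model.load_state_dict(state_dict, strict=False)
    return True

With that contract, wrapping the call in an assert turns a missing or malformed checkpoint into an immediate, readable error rather than a silent no-op.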
@@ -217,18 +217,19 @@ def load_pretrained(
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
 
         if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
-            load_valuehead_params(model, model_args.checkpoint_dir[0])
-            model.v_head.load_state_dict({
-                "summary.weight": getattr(model, "reward_head_weight"),
-                "summary.bias": getattr(model, "reward_head_bias")
-            })
+            logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
+            if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
+                model.v_head.load_state_dict({
+                    "summary.weight": getattr(model, "reward_head_weight"),
+                    "summary.bias": getattr(model, "reward_head_bias")
+                })
 
         if stage == "ppo": # load reward model
             assert is_trainable, "PPO stage cannot be performed at evaluation."
             assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
             logger.info("Load reward model from {}".format(model_args.reward_model))
             model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
-            load_valuehead_params(model, model_args.reward_model)
+            assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
 
    if not is_trainable:
        model.requires_grad_(False) # fix all model params
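
Both hunks also rely on load_valuehead_params returning a boolean and exposing the loaded tensors on the model as reward_head_weight and reward_head_bias, which the caller then copies into model.v_head. A minimal sketch of that behaviour, assuming the value head is saved as a separate torch state dict (the file name is illustrative):

import os
import torch

def load_valuehead_params(model, checkpoint_dir: str) -> bool:
    # Illustrative file name; the real project defines its own constant for this.
    valuehead_file = os.path.join(checkpoint_dir, "value_head.bin")
    if not os.path.exists(valuehead_file):
        return False  # lets the caller skip, warn, or assert instead of raising here
    valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
    # Expose the tensors on the model so the caller can copy them into model.v_head.
    model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
    model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
    return True

Under that contract, the "rm" branch can skip gracefully when the last checkpoint carries no value head, while the PPO branch asserts, since a missing reward head would make training meaningless.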