|
|
@ -41,7 +41,7 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def |
|
|
|
|
|
|
|
def cast_layernorm_dtype( |
|
|
|
model: AutoModelForCausalLMWithValueHead, |
|
|
|
layer_norm_names: List[str] = ["layernorm"], # for chatglm setting |
|
|
|
layer_norm_names: List[str] = ["norm", "ln_f"], # for LLaMA and BLOOM setting |
|
|
|
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None |
|
|
|
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: |
|
|
|
|
|
|
|