|
|
@@ -199,7 +199,7 @@ def load_pretrained( |
|
|
|
model = AutoModelForCausalLM.from_pretrained( |
|
|
|
model_args.model_name_or_path, |
|
|
|
config=config, |
|
|
|
- torch_dtype=torch.bfloat16 if finetuning_args.compute_dtype == torch.bfloat16 else torch.float16, |
|
|
|
+ torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16, |
|
|
|
low_cpu_mem_usage=True, |
|
|
|
**config_kwargs |
|
|
|
) |
|
|
|