diff --git a/src/utils/common.py b/src/utils/common.py
index 9082840..b33fa15 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -199,7 +199,7 @@ def load_pretrained(
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
-        torch_dtype=torch.bfloat16 if finetuning_args.compute_dtype == torch.bfloat16 else torch.float16,
+        torch_dtype=torch.bfloat16 if model_args.compute_dtype == torch.bfloat16 else torch.float16,
         low_cpu_mem_usage=True,
         **config_kwargs
     )