@@ -146,7 +146,7 @@ def load_pretrained(
         finetuning_args = FinetuningArguments(finetuning_type="none")
 
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
-        "RM and PPO training can only be performed with LoRA method."
+        "RM and PPO training can only be performed with the LoRA method."
 
     config_kwargs = {
         "trust_remote_code": True,
@@ -183,7 +183,7 @@ def load_pretrained(
             config_kwargs["load_in_4bit"] = True
             config_kwargs["quantization_config"] = BitsAndBytesConfig(
                 load_in_4bit=True,
-                bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
+                bnb_4bit_compute_dtype=model_args.compute_dtype,
                 bnb_4bit_use_double_quant=model_args.double_quantization,
                 bnb_4bit_quant_type=model_args.quantization_type
             )
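For context, every `bnb_4bit_*` setting in this hunk now comes from `model_args`. A minimal standalone sketch of the same 4-bit load path, with placeholder values standing in for those dataclass fields (the model id and the concrete dtype below are assumptions, not taken from this diff):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Placeholder values; in the repo they come from model_args.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,   # model_args.compute_dtype in the diff
    bnb_4bit_use_double_quant=True,         # model_args.double_quantization
    bnb_4bit_quant_type="nf4",              # model_args.quantization_type
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",                  # hypothetical model id
    quantization_config=quant_config,
    trust_remote_code=True,
)
```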
@@ -261,6 +261,9 @@ def prepare_args(
     if training_args.do_predict and (not training_args.predict_with_generate):
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 
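The new guard rejects non-LoRA fine-tuning whenever a quantization bit is set; the same check is added to `prepare_infer_args` in the last hunk. A small self-contained illustration of the rule, using simplified stand-ins for the argument dataclasses:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:                   # simplified stand-in for ModelArguments
    quantization_bit: Optional[int] = None

@dataclass
class FinetuneArgs:                # simplified stand-in for FinetuningArguments
    finetuning_type: str = "lora"

def check_quantization(model_args: ModelArgs, finetuning_args: FinetuneArgs) -> None:
    # Quantized base weights are frozen, so only adapter-style (LoRA)
    # fine-tuning is allowed once quantization_bit is set.
    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")

check_quantization(ModelArgs(quantization_bit=4), FinetuneArgs("lora"))   # ok
# check_quantization(ModelArgs(4), FinetuneArgs("full"))                  # raises ValueError
```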
@@ -275,11 +278,11 @@ def prepare_args(
 
     if model_args.quantization_bit is not None:
         if training_args.fp16:
-            finetuning_args.compute_dtype = torch.float16
+            model_args.compute_dtype = torch.float16
         elif training_args.bf16:
-            finetuning_args.compute_dtype = torch.bfloat16
+            model_args.compute_dtype = torch.bfloat16
         else:
-            finetuning_args.compute_dtype = torch.float32
+            model_args.compute_dtype = torch.float32
 
     # Log on each process the small summary:
     logger.info(
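These three replacements move `compute_dtype` from `finetuning_args` onto `model_args`, so the value chosen here is the same one `load_pretrained` later passes to `BitsAndBytesConfig` in the second hunk. The selection itself is just a mapping from the training precision flags (`fp16` / `bf16`) to a torch dtype; a hedged sketch of that logic in isolation:

```python
import torch

def infer_compute_dtype(fp16: bool, bf16: bool) -> torch.dtype:
    # Mirrors the branch in prepare_args: the precision flags on the
    # training arguments decide the compute dtype for quantized matmuls.
    if fp16:
        return torch.float16
    if bf16:
        return torch.bfloat16
    return torch.float32

assert infer_compute_dtype(fp16=True, bf16=False) is torch.float16
assert infer_compute_dtype(fp16=False, bf16=False) is torch.float32
```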
@@ -303,6 +306,9 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
     else:
         model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
 
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
+
     return model_args, data_args, finetuning_args