
tiny fix

main
hiyouga committed 2 years ago
commit eac9921e5c
2 changed files:
  src/utils/common.py (16 changes)
  src/utils/config.py (8 changes)

src/utils/common.py (16 changes)

@@ -146,7 +146,7 @@ def load_pretrained(
         finetuning_args = FinetuningArguments(finetuning_type="none")
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
-        "RM and PPO training can only be performed with LoRA method."
+        "RM and PPO training can only be performed with the LoRA method."
     config_kwargs = {
         "trust_remote_code": True,
@@ -183,7 +183,7 @@ def load_pretrained(
         config_kwargs["load_in_4bit"] = True
         config_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
-            bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
+            bnb_4bit_compute_dtype=model_args.compute_dtype,
             bnb_4bit_use_double_quant=model_args.double_quantization,
             bnb_4bit_quant_type=model_args.quantization_type
         )
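For context, a minimal sketch of what this hunk amounts to after the rename: the 4-bit BitsAndBytesConfig now reads its compute dtype from the model arguments rather than the fine-tuning arguments. The ModelArguments stub below is a hypothetical stand-in that only carries the fields referenced here; the real class lives in src/utils/config.py.

```python
import torch
from dataclasses import dataclass
from typing import Optional
from transformers import BitsAndBytesConfig

@dataclass
class ModelArguments:  # hypothetical stub mirroring only the fields used below
    compute_dtype: Optional[torch.dtype] = None
    double_quantization: bool = True
    quantization_type: str = "nf4"

model_args = ModelArguments(compute_dtype=torch.float16)

# After this commit the compute dtype comes from model_args, not finetuning_args.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=model_args.compute_dtype,
    bnb_4bit_use_double_quant=model_args.double_quantization,
    bnb_4bit_quant_type=model_args.quantization_type,
)
```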
@@ -261,6 +261,9 @@ def prepare_args(
     if training_args.do_predict and (not training_args.predict_with_generate):
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -275,11 +278,11 @@ def prepare_args(
     if model_args.quantization_bit is not None:
         if training_args.fp16:
-            finetuning_args.compute_dtype = torch.float16
+            model_args.compute_dtype = torch.float16
         elif training_args.bf16:
-            finetuning_args.compute_dtype = torch.bfloat16
+            model_args.compute_dtype = torch.bfloat16
         else:
-            finetuning_args.compute_dtype = torch.float32
+            model_args.compute_dtype = torch.float32
     # Log on each process the small summary:
     logger.info(
@@ -303,6 +306,9 @@ def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
     else:
         model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
     return model_args, data_args, finetuning_args
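Taken together, the prepare_args and prepare_infer_args hunks amount to roughly the following flow. This is only a sketch with hypothetical stand-in argument objects, reproducing just the checks touched by this commit.

```python
import torch

def apply_quantization_settings(model_args, training_args, finetuning_args):
    # Quantized loading is only supported together with LoRA adapters.
    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
        raise ValueError("Quantization is only compatible with the LoRA method.")

    # The bitsandbytes compute dtype now lives on model_args (it used to be on finetuning_args)
    # and is derived from the fp16/bf16 training flags.
    if model_args.quantization_bit is not None:
        if training_args.fp16:
            model_args.compute_dtype = torch.float16
        elif training_args.bf16:
            model_args.compute_dtype = torch.bfloat16
        else:
            model_args.compute_dtype = torch.float32
```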

src/utils/config.py (8 changes)

@@ -62,6 +62,10 @@ class ModelArguments:
         default=True,
         metadata={"help": "Compress the quantization statistics through double quantization."}
     )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+    )
     checkpoint_dir: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."}
@@ -208,10 +212,6 @@ class FinetuningArguments:
                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
     )
-    compute_dtype: Optional[torch.dtype] = field(
-        default=None,
-        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
-    )
 
     def __post_init__(self):
         if isinstance(self.lora_target, str):
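In other words, the compute_dtype field simply moves from FinetuningArguments to ModelArguments, next to the other quantization options; its help text indicates it is not meant to be passed on the command line but is filled in programmatically (as prepare_args does above). A minimal sketch of the relocated field, with unrelated fields omitted:

```python
import torch
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ModelArguments:  # sketch; only the quantization-related fields shown in this diff
    double_quantization: Optional[bool] = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    compute_dtype: Optional[torch.dtype] = field(
        default=None,
        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
    )
```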
