@@ -11,7 +11,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     HfArgumentParser,
-    Seq2SeqTrainingArguments
+    Seq2SeqTrainingArguments,
+    BitsAndBytesConfig
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
@@ -167,12 +168,27 @@ def load_pretrained(
     # Quantization configurations (using bitsandbytes library).
     if model_args.quantization_bit is not None:
-        assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
-        require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
-        #require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
-        #require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-        #require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
-        config_kwargs["load_in_8bit"] = True
+        if model_args.quantization_bit == 8:
+            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+            config_kwargs["load_in_8bit"] = True
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_8bit=True,
+                llm_int8_threshold=6.0
+            )
+        elif model_args.quantization_bit == 4:
+            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+            require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
+            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
+            require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
+            config_kwargs["load_in_4bit"] = True
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
+                bnb_4bit_use_double_quant=model_args.double_quantization,
+                bnb_4bit_quant_type=model_args.quantization_type
+            )
+        else:
+            raise NotImplementedError
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@@ -183,7 +199,7 @@ def load_pretrained(
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
-        torch_dtype=torch.float16, # the model weights are float16 type
+        torch_dtype=torch.bfloat16 if finetuning_args.compute_dtype == torch.bfloat16 else torch.float16,
         low_cpu_mem_usage=True,
         **config_kwargs
     )
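One caveat the patch does not cover: bfloat16 compute is only available on sufficiently recent GPUs, so a fallback such as the hypothetical helper below could keep fp16 on older hardware:

    import torch

    # Hypothetical helper, not part of this patch: prefer bf16 only when the GPU supports it.
    def pick_compute_dtype(prefer_bf16: bool) -> torch.dtype:
        if prefer_bf16 and torch.cuda.is_available() and torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16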
@@ -237,13 +253,13 @@ def prepare_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     if stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True in PT, RM and PPO stages.")
+        raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")

     if training_args.do_train and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True while training.")

     if training_args.do_predict and (not training_args.predict_with_generate):
-        raise ValueError("Please enable `predict_with_generate` for saving model predictions.")
+        raise ValueError("Please enable `predict_with_generate` to save model predictions.")

     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -257,6 +273,14 @@ def prepare_args(
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning

+    if model_args.quantization_bit is not None:
+        if training_args.fp16:
+            finetuning_args.compute_dtype = torch.float16
+        elif training_args.bf16:
+            finetuning_args.compute_dtype = torch.bfloat16
+        else:
+            finetuning_args.compute_dtype = torch.float32
+
     # Log on each process the small summary:
     logger.info(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"