@@ -172,16 +172,13 @@ def load_pretrained(
         #require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
         #require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
         #require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
-
-        from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
-        cuda = get_cuda_lib_handle()
-        cc = get_compute_capability(cuda)
-        assert is_cublasLt_compatible(cc), "The current GPU(s) is incompatible with quantization."
         config_kwargs["load_in_8bit"] = True
-        config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
+    if model_args.quantization_bit is not None or (not is_trainable): # automatically load in CUDA
+        config_kwargs["device_map"] = "auto"
+
     # Load and prepare pretrained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
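A minimal sketch of the behaviour this hunk leaves in place, assuming config_kwargs is the dict later unpacked into AutoModelForCausalLM.from_pretrained: the bitsandbytes cublasLt compatibility check is dropped rather than relocated, and device_map="auto" is now set whenever the model is quantized or loaded for inference instead of only inside the 8-bit branch. resolve_config_kwargs below is a hypothetical helper written only to illustrate the new condition; only quantization_bit, is_trainable, and the config_kwargs keys come from the diff.

    # Illustrative only: a hypothetical helper mirroring the patched logic.
    from typing import Any, Dict, Optional

    def resolve_config_kwargs(quantization_bit: Optional[int], is_trainable: bool) -> Dict[str, Any]:
        config_kwargs: Dict[str, Any] = {}
        if quantization_bit == 8:
            config_kwargs["load_in_8bit"] = True
        # After the patch, device_map="auto" applies to any quantized model
        # and to any model loaded for inference (not is_trainable).
        if quantization_bit is not None or (not is_trainable):
            config_kwargs["device_map"] = "auto"
        return config_kwargs

    print(resolve_config_kwargs(8, True))      # {'load_in_8bit': True, 'device_map': 'auto'}
    print(resolve_config_kwargs(None, False))  # {'device_map': 'auto'}
    print(resolve_config_kwargs(None, True))   # {}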