|
|
@ -46,7 +46,8 @@ from .other import ( |
|
|
|
) |
|
|
|
|
|
|
|
check_min_version("4.29.1") |
|
|
|
require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0") |
|
|
|
require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") |
|
|
|
require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0") |
|
|
|
require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") |
|
|
|
require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") |
|
|
|
|
|
|
@ -84,8 +85,7 @@ def init_adapter( |
|
|
|
param.data = param.data.to(torch.float32) |
|
|
|
|
|
|
|
if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None: |
|
|
|
if len(model_args.checkpoint_dir) > 1: |
|
|
|
logger.warning("Only LoRA tuning accepts multiple checkpoints.") |
|
|
|
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." |
|
|
|
load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods |
|
|
|
|
|
|
|
if finetuning_args.finetuning_type == "lora": |
|
|
@ -154,8 +154,7 @@ def load_pretrained( |
|
|
|
config_kwargs = {} |
|
|
|
if model_args.quantization_bit is not None: |
|
|
|
assert model_args.quantization_bit == 8, "We only accept 8-bit quantization." |
|
|
|
|
|
|
|
require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.") |
|
|
|
require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1") |
|
|
|
from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible |
|
|
|
cuda = get_cuda_lib_handle() |
|
|
|
cc = get_compute_capability(cuda) |
|
|
@ -179,7 +178,6 @@ def load_pretrained( |
|
|
|
|
|
|
|
if not is_trainable: |
|
|
|
model.requires_grad_(False) # fix all model params |
|
|
|
model = model.half() # cast all params to float16 for inference |
|
|
|
|
|
|
|
if stage == "rm" or stage == "ppo": # add value head |
|
|
|
model = AutoModelForCausalLMWithValueHead.from_pretrained(model) |
|
|
|