|
@@ -143,15 +143,24 @@ def load_pretrained(
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."
 
+    config_kwargs = {
+        "trust_remote_code": True,
+        "cache_dir": model_args.cache_dir,
+        "revision": model_args.model_revision,
+        "use_auth_token": True if model_args.use_auth_token else None,
+    }
+
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
-        padding_side="left"
+        padding_side="left",
+        **config_kwargs
     )
     tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
 
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+
     # Quantization configurations (using bitsandbytes library).
-    config_kwargs = {}
     if model_args.quantization_bit is not None:
         assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
         require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
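Note: the hunk above threads one dict of Hub options (trust_remote_code, cache_dir, revision, use_auth_token) through every from_pretrained call instead of passing them ad hoc, so all three loaders stay in sync. A minimal standalone sketch of the same pattern; the "gpt2" id and the option values are illustrative assumptions, not this project's defaults:

    # Sketch only: one kwargs dict splatted into each loader keeps the
    # call sites consistent; adding an option later touches a single place.
    from transformers import AutoConfig, AutoTokenizer

    hub_kwargs = {
        "trust_remote_code": True,  # allow repos that ship custom modeling code
        "cache_dir": None,          # None falls back to the default HF cache
        "revision": "main",         # pin a branch, tag, or commit hash
        "use_auth_token": None,     # set a token string for gated/private repos
    }

    tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left", **hub_kwargs)
    config = AutoConfig.from_pretrained("gpt2", **hub_kwargs)
    print(type(tokenizer).__name__, config.model_type)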
|
@@ -162,23 +171,19 @@ def load_pretrained(
 
         config_kwargs["load_in_8bit"] = True
         config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
-        logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
+        logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
-
     # Load and prepare pretrained models (without valuehead).
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
         torch_dtype=torch.float16, # the model weights are float16 type
+        low_cpu_mem_usage=True,
         **config_kwargs
     )
     model = prepare_model_for_training(model) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
 
-    if not is_trainable:
-        model.requires_grad_(False) # fix all model params
-
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
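Note: this hunk drops the bare AutoConfig call (its **config_kwargs replacement moved into the previous hunk), adds low_cpu_mem_usage=True to reduce peak host memory while weights load, and removes the early parameter freeze (re-added after the value head in the next hunk). A sketch of how the loading kwargs combine, gating the 8-bit path on CUDA since load_in_8bit needs bitsandbytes and a GPU; the "gpt2" id is again an illustrative assumption:

    # Sketch only: mirrors the kwargs the diff assembles, not the exact code.
    import torch
    from transformers import AutoModelForCausalLM

    load_kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
    if torch.cuda.is_available():
        load_kwargs["load_in_8bit"] = True   # requires the bitsandbytes package
        load_kwargs["device_map"] = "auto"   # only set together with load_in_8bit

    model = AutoModelForCausalLM.from_pretrained("gpt2", **load_kwargs)
    print(sum(p.numel() for p in model.parameters()), "parameters")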
|
@@ -194,6 +199,9 @@ def load_pretrained(
     if model_args.quantization_bit is not None:
         model._is_int8_training_enabled = True
 
+    if not is_trainable:
+        model.requires_grad_(False) # fix all model params
+
     print_trainable_params(model)
 
     return model, tokenizer
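Note: the freeze removed in the second hunk reappears here, after the value head has been attached, so requires_grad_(False) also covers the head's freshly created parameters when the model is loaded for inference. A toy demonstration of why that ordering matters, using plain torch modules rather than the project's classes:

    # Toy sketch: a module attached after an early freeze stays trainable.
    import torch

    model = torch.nn.Linear(4, 4)
    model.requires_grad_(False)          # freeze early...
    model.head = torch.nn.Linear(4, 1)   # ...then attach a head (like a value head)
    print(any(p.requires_grad for p in model.parameters()))  # True: head trainable

    model.requires_grad_(False)          # freezing after attachment covers everything
    print(any(p.requires_grad for p in model.parameters()))  # False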
|
|