|
|
@ -4,15 +4,14 @@ import torch |
|
|
|
from tqdm import tqdm |
|
|
|
from typing import Callable, Dict, List, Literal, Optional, Tuple |
|
|
|
|
|
|
|
from transformers import Seq2SeqTrainingArguments |
|
|
|
from transformers.trainer import TrainerState |
|
|
|
from transformers import Seq2SeqTrainingArguments, TrainerState |
|
|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
|
|
|
|
|
from trl import PPOTrainer, AutoModelForCausalLMWithValueHead |
|
|
|
from trl.core import LengthSampler |
|
|
|
from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits |
|
|
|
|
|
|
|
from .peft_trainer import PeftTrainer |
|
|
|
from .peft_trainer import PeftTrainer, LogCallback |
|
|
|
|
|
|
|
from .config import FinetuningArguments |
|
|
|
|
|
|
@ -40,15 +39,41 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def |
|
|
|
}) |
|
|
|
|
|
|
|
|
|
|
|
def cast_layernorm_dtype( |
|
|
|
model: AutoModelForCausalLMWithValueHead, |
|
|
|
layer_norm_names: List[str] = ["layernorm"], # for chatglm setting |
|
|
|
layer_norm_params: Optional[Dict[str, torch.Tensor]] = None |
|
|
|
) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: |
|
|
|
|
|
|
|
layer_norm_state_dict = {} |
|
|
|
|
|
|
|
for name, param in model.named_parameters(): |
|
|
|
if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): |
|
|
|
if layer_norm_params is not None: |
|
|
|
param.data = layer_norm_params[name] # restore float32 weights |
|
|
|
else: |
|
|
|
layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability |
|
|
|
param.data = param.data.to(torch.float16) |
|
|
|
|
|
|
|
return model, layer_norm_state_dict |
|
|
|
|
|
|
|
|
|
|
|
class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): |
|
|
|
r""" |
|
|
|
Inherits PPOTrainer. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs): |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
training_args: Seq2SeqTrainingArguments, |
|
|
|
finetuning_args: FinetuningArguments, |
|
|
|
callbacks: List[LogCallback], |
|
|
|
**kwargs |
|
|
|
): |
|
|
|
PPOTrainer.__init__(self, **kwargs) |
|
|
|
self.args = training_args |
|
|
|
self.finetuning_args = finetuning_args |
|
|
|
self.log_callback = callbacks[0] |
|
|
|
self.state = TrainerState() |
|
|
|
self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) |
|
|
|
|
|
|
@ -63,6 +88,11 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): |
|
|
|
num_train_epochs = self.args.num_train_epochs |
|
|
|
max_steps = math.ceil(num_train_epochs * num_steps_per_epoch) |
|
|
|
|
|
|
|
self.state.max_steps = max_steps |
|
|
|
self.state.num_train_epochs = num_train_epochs |
|
|
|
self.state.is_local_process_zero = self.is_local_process_zero() |
|
|
|
self.state.is_world_process_zero = self.is_world_process_zero() |
|
|
|
|
|
|
|
if self.is_world_process_zero(): |
|
|
|
logger.info("***** Running training *****") |
|
|
|
logger.info(f" Num examples = {num_examples}") |
|
|
@ -144,6 +174,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): |
|
|
|
print(logs) |
|
|
|
logs["step"] = step |
|
|
|
self.state.log_history.append(logs) |
|
|
|
self.log_callback.on_log(self.args, self.state, None) |
|
|
|
loss_meter.reset() |
|
|
|
reward_meter.reset() |
|
|
|
|
|
|
@ -154,8 +185,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): |
|
|
|
def generate( |
|
|
|
self, |
|
|
|
inputs: Dict[str, torch.Tensor], |
|
|
|
length_sampler: Callable = None, |
|
|
|
return_prompt: bool = True, |
|
|
|
length_sampler: Optional[Callable] = None, |
|
|
|
return_prompt: Optional[bool] = True, |
|
|
|
**generation_kwargs, |
|
|
|
) -> torch.Tensor: |
|
|
|
r""" |
|
|
@ -163,6 +194,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): |
|
|
|
|
|
|
|
Subclass and override to inject custom behavior. |
|
|
|
""" |
|
|
|
self.model, layer_norm_params = cast_layernorm_dtype(self.model) |
|
|
|
|
|
|
|
if length_sampler is not None: |
|
|
|
generation_kwargs["max_new_tokens"] = length_sampler() |
|
|
|
|
|
|
@ -175,6 +208,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): |
|
|
|
if unwrapped_model.pretrained_model.generation_config._from_model_config: |
|
|
|
unwrapped_model.pretrained_model.generation_config._from_model_config = False |
|
|
|
|
|
|
|
self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params) |
|
|
|
|
|
|
|
if not return_prompt and not self.is_encoder_decoder: |
|
|
|
return response[:, inputs["input_ids"].size(1):] |
|
|
|
return response |
|
|
|