| 
						
						
							
								
							
						
						
					 | 
					@ -4,15 +4,14 @@ import torch | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from tqdm import tqdm | 
					 | 
					 | 
					from tqdm import tqdm | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from typing import Callable, Dict, List, Literal, Optional, Tuple | 
					 | 
					 | 
					from typing import Callable, Dict, List, Literal, Optional, Tuple | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					from transformers import Seq2SeqTrainingArguments | 
					 | 
					 | 
					from transformers import Seq2SeqTrainingArguments, TrainerState | 
				
			
			
				
				
			
		
	
		
		
			
				
					 | 
					 | 
					from transformers.trainer import TrainerState | 
					 | 
					 | 
					 | 
				
			
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					from transformers.modeling_utils import PreTrainedModel | 
					 | 
					 | 
					from transformers.modeling_utils import PreTrainedModel | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from trl import PPOTrainer, AutoModelForCausalLMWithValueHead | 
					 | 
					 | 
					from trl import PPOTrainer, AutoModelForCausalLMWithValueHead | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from trl.core import LengthSampler | 
					 | 
					 | 
					from trl.core import LengthSampler | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits | 
					 | 
					 | 
					from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					from .peft_trainer import PeftTrainer | 
					 | 
					 | 
					from .peft_trainer import PeftTrainer, LogCallback | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					from .config import FinetuningArguments | 
					 | 
					 | 
					from .config import FinetuningArguments | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -40,15 +39,41 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    }) | 
					 | 
					 | 
					    }) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					def cast_layernorm_dtype( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        model: AutoModelForCausalLMWithValueHead, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        layer_norm_names: List[str] = ["layernorm"], # for chatglm setting | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        layer_norm_params: Optional[Dict[str, torch.Tensor]] = None | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    layer_norm_state_dict = {} | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    for name, param in model.named_parameters(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            if layer_norm_params is not None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                param.data = layer_norm_params[name] # restore float32 weights | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            else: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                param.data = param.data.to(torch.float16) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    return model, layer_norm_state_dict | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): | 
					 | 
					 | 
					class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    r""" | 
					 | 
					 | 
					    r""" | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    Inherits PPOTrainer. | 
					 | 
					 | 
					    Inherits PPOTrainer. | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    """ | 
					 | 
					 | 
					    """ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					    def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs): | 
					 | 
					 | 
					    def __init__( | 
				
			
			
				
				
			
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            self, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            training_args: Seq2SeqTrainingArguments, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            finetuning_args: FinetuningArguments, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            callbacks: List[LogCallback], | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					            **kwargs | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					    ): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        PPOTrainer.__init__(self, **kwargs) | 
					 | 
					 | 
					        PPOTrainer.__init__(self, **kwargs) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.args = training_args | 
					 | 
					 | 
					        self.args = training_args | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.finetuning_args = finetuning_args | 
					 | 
					 | 
					        self.finetuning_args = finetuning_args | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.log_callback = callbacks[0] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.state = TrainerState() | 
					 | 
					 | 
					        self.state = TrainerState() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) | 
					 | 
					 | 
					        self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -63,6 +88,11 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        num_train_epochs = self.args.num_train_epochs | 
					 | 
					 | 
					        num_train_epochs = self.args.num_train_epochs | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        max_steps = math.ceil(num_train_epochs * num_steps_per_epoch) | 
					 | 
					 | 
					        max_steps = math.ceil(num_train_epochs * num_steps_per_epoch) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.state.max_steps = max_steps | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.state.num_train_epochs = num_train_epochs | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.state.is_local_process_zero = self.is_local_process_zero() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.state.is_world_process_zero = self.is_world_process_zero() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        if self.is_world_process_zero(): | 
					 | 
					 | 
					        if self.is_world_process_zero(): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            logger.info("***** Running training *****") | 
					 | 
					 | 
					            logger.info("***** Running training *****") | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            logger.info(f"  Num examples = {num_examples}") | 
					 | 
					 | 
					            logger.info(f"  Num examples = {num_examples}") | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
							
								
							
						
						
					 | 
					@ -144,6 +174,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                print(logs) | 
					 | 
					 | 
					                print(logs) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                logs["step"] = step | 
					 | 
					 | 
					                logs["step"] = step | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                self.state.log_history.append(logs) | 
					 | 
					 | 
					                self.state.log_history.append(logs) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					                self.log_callback.on_log(self.args, self.state, None) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                loss_meter.reset() | 
					 | 
					 | 
					                loss_meter.reset() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					                reward_meter.reset() | 
					 | 
					 | 
					                reward_meter.reset() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -154,8 +185,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    def generate( | 
					 | 
					 | 
					    def generate( | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            self, | 
					 | 
					 | 
					            self, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            inputs: Dict[str, torch.Tensor], | 
					 | 
					 | 
					            inputs: Dict[str, torch.Tensor], | 
				
			
			
		
	
		
		
			
				
					
					 | 
					 | 
					            length_sampler: Callable = None, | 
					 | 
					 | 
					            length_sampler: Optional[Callable] = None, | 
				
			
			
				
				
			
		
	
		
		
			
				
					
					 | 
					 | 
					            return_prompt: bool = True, | 
					 | 
					 | 
					            return_prompt: Optional[bool] = True, | 
				
			
			
				
				
			
		
	
		
		
	
		
		
	
		
		
			
				
					 | 
					 | 
					            **generation_kwargs, | 
					 | 
					 | 
					            **generation_kwargs, | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					    ) -> torch.Tensor: | 
					 | 
					 | 
					    ) -> torch.Tensor: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        r""" | 
					 | 
					 | 
					        r""" | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -163,6 +194,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        Subclass and override to inject custom behavior. | 
					 | 
					 | 
					        Subclass and override to inject custom behavior. | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        """ | 
					 | 
					 | 
					        """ | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.model, layer_norm_params = cast_layernorm_dtype(self.model) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        if length_sampler is not None: | 
					 | 
					 | 
					        if length_sampler is not None: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            generation_kwargs["max_new_tokens"] = length_sampler() | 
					 | 
					 | 
					            generation_kwargs["max_new_tokens"] = length_sampler() | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
	
		
		
			
				
					| 
						
						
						
							
								
							
						
					 | 
					@ -175,6 +208,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        if unwrapped_model.pretrained_model.generation_config._from_model_config: | 
					 | 
					 | 
					        if unwrapped_model.pretrained_model.generation_config._from_model_config: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            unwrapped_model.pretrained_model.generation_config._from_model_config = False | 
					 | 
					 | 
					            unwrapped_model.pretrained_model.generation_config._from_model_config = False | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					
 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params) | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					 | 
					 | 
					 | 
					
 | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        if not return_prompt and not self.is_encoder_decoder: | 
					 | 
					 | 
					        if not return_prompt and not self.is_encoder_decoder: | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					            return response[:, inputs["input_ids"].size(1):] | 
					 | 
					 | 
					            return response[:, inputs["input_ids"].size(1):] | 
				
			
			
		
	
		
		
			
				
					 | 
					 | 
					        return response | 
					 | 
					 | 
					        return response | 
				
			
			
		
	
	
		
		
			
				
					| 
						
							
								
							
						
						
						
					 | 
					
  |