From 0c9fda01e3c61727c939efd9d9398f657a2d69b6 Mon Sep 17 00:00:00 2001
From: hiyouga <hiyouga@buaa.edu.cn>
Date: Sun, 28 May 2023 21:30:28 +0800
Subject: [PATCH] use fp16 model, add logcallback

---
 src/train_ppo.py          |  2 ++
 src/train_rm.py           |  2 ++
 src/train_sft.py          |  4 +++-
 src/utils/__init__.py     |  2 ++
 src/utils/common.py       | 12 +++++++++--
 src/utils/peft_trainer.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++-
 src/utils/ppo.py          | 47 +++++++++++++++++++++++++++++++++++------
 7 files changed, 112 insertions(+), 10 deletions(-)

diff --git a/src/train_ppo.py b/src/train_ppo.py
index 41f89a5..f6e57c0 100644
--- a/src/train_ppo.py
+++ b/src/train_ppo.py
@@ -17,6 +17,7 @@ from utils import (
     preprocess_data,
     DataCollatorForLLaMA,
     PPOTrainerForLLaMA,
+    LogCallback,
     plot_loss
 )
 
@@ -54,6 +55,7 @@ def main():
     ppo_trainer = PPOTrainerForLLaMA(
         training_args=training_args,
         finetuning_args=finetuning_args,
+        callbacks=[LogCallback()],
         config=ppo_config,
         model=model,
         ref_model=None,
diff --git a/src/train_rm.py b/src/train_rm.py
index dd544f3..ecd7f71 100644
--- a/src/train_rm.py
+++ b/src/train_rm.py
@@ -12,6 +12,7 @@ from utils import (
     preprocess_data,
     PairwiseDataCollatorForLLaMA,
     PairwiseTrainerForLLaMA,
+    LogCallback,
     plot_loss
 )
 
@@ -43,6 +44,7 @@ def main():
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
+        callbacks=[LogCallback()],
         **trainer_kwargs
     )
 
diff --git a/src/train_sft.py b/src/train_sft.py
index d34a8e4..3bc0f85 100644
--- a/src/train_sft.py
+++ b/src/train_sft.py
@@ -12,6 +12,7 @@ from utils import (
     DataCollatorForLLaMA,
     Seq2SeqTrainerForLLaMA,
     ComputeMetrics,
+    LogCallback,
     get_logits_processor,
     plot_loss
 )
@@ -49,6 +50,7 @@ def main():
         args=training_args,
         tokenizer=tokenizer,
         data_collator=data_collator,
+        callbacks=[LogCallback()],
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
         **trainer_kwargs
     )
@@ -57,7 +59,7 @@ def main():
     gen_kwargs = {
         "do_sample": True,
         "top_p": 0.7,
-        "max_length": data_args.max_source_length + data_args.max_target_length + 1,
+        "max_new_tokens": data_args.max_target_length + 1,
         "temperature": 0.95,
         "logits_processor": get_logits_processor()
     }
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index c19e82a..a9104cc 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -7,6 +7,8 @@ from .common import (
 
 from .data_collator import DataCollatorForLLaMA
 
+from .peft_trainer import LogCallback
+
 from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
 from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
 from .ppo import PPOTrainerForLLaMA
diff --git a/src/utils/common.py b/src/utils/common.py
index db6bfd2..2798124 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Tuple
 
 import transformers
 from transformers import (
+    LlamaConfig,
     LlamaForCausalLM,
     LlamaTokenizer,
     HfArgumentParser,
@@ -151,7 +152,7 @@ def load_pretrained(
         use_fast=model_args.use_fast_tokenizer,
         padding_side="left"
     )
-    tokenizer.pad_token_id = 0 # set as the <unk> token
+    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
 
     # Quantization configurations (using bitsandbytes library).
     config_kwargs = {}
@@ -168,8 +169,15 @@ def load_pretrained(
         config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
         logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
 
+    config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
+
     # Load and prepare pretrained models (without valuehead).
-    model = LlamaForCausalLM.from_pretrained(model_args.model_name_or_path, **config_kwargs)
+    model = LlamaForCausalLM.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+        torch_dtype=torch.float16, # the llama weights are float16 type
+        **config_kwargs
+    )
     model = prepare_model_for_training(model) if is_trainable else model
     model = init_adapter(model, model_args, finetuning_args, is_trainable)
 
diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py
index 57d54a8..0afe4fb 100644
--- a/src/utils/peft_trainer.py
+++ b/src/utils/peft_trainer.py
@@ -1,8 +1,18 @@
 import os
+import json
+import time
 import torch
 from typing import Dict, Optional
+from datetime import timedelta
+
+from transformers import (
+    Seq2SeqTrainer,
+    TrainerCallback,
+    TrainerControl,
+    TrainerState,
+    TrainingArguments
+)
 
-from transformers import Seq2SeqTrainer
 from transformers.trainer import TRAINING_ARGS_NAME
 from transformers.modeling_utils import unwrap_model
 
@@ -23,6 +33,44 @@ from .other import (
 logger = get_logger(__name__)
 
 
+class LogCallback(TrainerCallback):
+    r"""
+    TrainerCallback includes the state function during training, for more details refer to the TrainerCallback class.
+    The on_log function primarily collects process parameters during training, such as training loss, learning rate,
+    and training epochs, as well as progress parameters like the current percentage progress and estimated remaining
+    time. Every time a log is triggered, a new record is appended to the file "messages.log" for dynamic visualization
+    purposes.
+    """
+
+    def __init__(self):
+        self.start_time = time.time()
+
+    def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
+        r"""
+        Event called after logging the last logs.
+        """
+        cur_time = time.time()
+        cur_steps = state.log_history[-1].get("step")
+        elapsed_time = cur_time - self.start_time
+        avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
+        remaining_steps = state.max_steps - cur_steps
+        remaining_time = remaining_steps * avg_time_per_step
+        log_dict = {
+            "current_steps": cur_steps,
+            "total_steps": state.max_steps,
+            "loss": state.log_history[-1].get("loss", None),
+            "reward": state.log_history[-1].get("reward", None),
+            "learning_rate": state.log_history[-1].get("learning_rate", None),
+            "epoch": state.log_history[-1].get("epoch", None),
+            "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
+            "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
+            "remaining_time": str(timedelta(seconds=int(remaining_time)))
+        }
+        os.makedirs(args.output_dir, exist_ok=True)
+        with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a") as f:
+            f.write(json.dumps(log_dict) + "\n")
+
+
 class PeftTrainer(Seq2SeqTrainer):
     r"""
     Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
@@ -31,6 +79,9 @@ class PeftTrainer(Seq2SeqTrainer):
     def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
         super().__init__(**kwargs)
         self.finetuning_args = finetuning_args
+        if os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
+            logger.warning("Previous log file in this folder will be deleted.")
+            os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
 
     def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
         r"""
diff --git a/src/utils/ppo.py b/src/utils/ppo.py
index 85c6950..8a06887 100644
--- a/src/utils/ppo.py
+++ b/src/utils/ppo.py
@@ -4,15 +4,14 @@ import torch
 from tqdm import tqdm
 from typing import Callable, Dict, List, Literal, Optional, Tuple
 
-from transformers import Seq2SeqTrainingArguments
-from transformers.trainer import TrainerState
+from transformers import Seq2SeqTrainingArguments, TrainerState
 from transformers.modeling_utils import PreTrainedModel
 
 from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
 from trl.core import LengthSampler
 from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
 
-from .peft_trainer import PeftTrainer
+from .peft_trainer import PeftTrainer, LogCallback
 
 from .config import FinetuningArguments
 
@@ -40,15 +39,41 @@ def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["def
     })
 
 
+def cast_layernorm_dtype(
+        model: AutoModelForCausalLMWithValueHead,
+        layer_norm_names: List[str] = ["layernorm"], # for chatglm setting
+        layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
+) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
+
+    layer_norm_state_dict = {}
+
+    for name, param in model.named_parameters():
+        if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
+            if layer_norm_params is not None:
+                param.data = layer_norm_params[name] # restore float32 weights
+            else:
+                layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
+                param.data = param.data.to(torch.float16)
+
+    return model, layer_norm_state_dict
+
+
 class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
     r"""
     Inherits PPOTrainer.
     """
 
-    def __init__(self, training_args: Seq2SeqTrainingArguments, finetuning_args: FinetuningArguments, **kwargs):
+    def __init__(
+            self,
+            training_args: Seq2SeqTrainingArguments,
+            finetuning_args: FinetuningArguments,
+            callbacks: List[LogCallback],
+            **kwargs
+    ):
         PPOTrainer.__init__(self, **kwargs)
         self.args = training_args
         self.finetuning_args = finetuning_args
+        self.log_callback = callbacks[0]
         self.state = TrainerState()
         self.data_collator = self.accelerator.prepare(kwargs["data_collator"])
 
@@ -63,6 +88,11 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
         num_train_epochs = self.args.num_train_epochs
         max_steps = math.ceil(num_train_epochs * num_steps_per_epoch)
 
+        self.state.max_steps = max_steps
+        self.state.num_train_epochs = num_train_epochs
+        self.state.is_local_process_zero = self.is_local_process_zero()
+        self.state.is_world_process_zero = self.is_world_process_zero()
+
         if self.is_world_process_zero():
             logger.info("***** Running training *****")
             logger.info(f"  Num examples = {num_examples}")
@@ -144,6 +174,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
                 print(logs)
                 logs["step"] = step
                 self.state.log_history.append(logs)
+                self.log_callback.on_log(self.args, self.state, None)
                 loss_meter.reset()
                 reward_meter.reset()
 
@@ -154,8 +185,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
     def generate(
             self,
             inputs: Dict[str, torch.Tensor],
-            length_sampler: Callable = None,
-            return_prompt: bool = True,
+            length_sampler: Optional[Callable] = None,
+            return_prompt: Optional[bool] = True,
             **generation_kwargs,
     ) -> torch.Tensor:
         r"""
@@ -163,6 +194,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
 
         Subclass and override to inject custom behavior.
         """
+        self.model, layer_norm_params = cast_layernorm_dtype(self.model)
+
         if length_sampler is not None:
             generation_kwargs["max_new_tokens"] = length_sampler()
 
@@ -175,6 +208,8 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
         if unwrapped_model.pretrained_model.generation_config._from_model_config:
             unwrapped_model.pretrained_model.generation_config._from_model_config = False
 
+        self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
+
         if not return_prompt and not self.is_encoder_decoder:
             return response[:, inputs["input_ids"].size(1):]
         return response