From 0cee6ad67ffb06f0d7165a0284e39f510a2abc36 Mon Sep 17 00:00:00 2001
From: hiyouga <hiyouga@buaa.edu.cn>
Date: Thu, 15 Jun 2023 16:02:01 +0800
Subject: [PATCH] support Baichuan model
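
Adds Baichuan-7B to the list of supported models. Two code changes are
needed: the Baichuan tokenizer ships with pad_token_id = 64000, which is
reset to 0 (the <unk> token) as a temporary workaround, and
prepare_model_for_training now falls back to registering a forward hook
on the input embeddings when the model class does not provide
enable_input_require_grads, so that gradient checkpointing still
receives inputs that require grad.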

---
 README.md           | 1 +
 src/utils/common.py | 2 ++
 src/utils/other.py  | 8 +++++++-
 3 files changed, 10 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index b24b2f6..0b7dd9a 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@
 
 - [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
 - [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
+- [Baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B)
 
 ## Supported Training Approaches
 
diff --git a/src/utils/common.py b/src/utils/common.py
index 38ae830..094516e 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -170,6 +170,8 @@ def load_pretrained(
         **config_kwargs
     )
     tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
+    if tokenizer.pad_token_id == 64000: # the Baichuan tokenizer ships with an invalid pad_token_id
+        tokenizer.pad_token_id = 0 # temporary workaround: reset it to the <unk> token
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
diff --git a/src/utils/other.py b/src/utils/other.py
index 838b617..88bf081 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -83,7 +83,13 @@ def prepare_model_for_training(
             param.data = param.data.to(torch.float32)
 
     if use_gradient_checkpointing:
-        model.enable_input_require_grads()
+        if hasattr(model, "enable_input_require_grads"): # absent from some remote-code models, e.g. Baichuan
+            model.enable_input_require_grads()
+        else: # fallback: make the (frozen) input embeddings emit outputs that require grad
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
         model.gradient_checkpointing_enable()
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
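
For context, a minimal sketch (not part of the patch) of why the fallback hook
is needed. With reentrant gradient checkpointing, the checkpointed segment is a
single autograd Function: if none of its inputs require grad, its output is
detached from the graph and the backward pass never reaches the trainable
layers behind it. Freezing the input embeddings, as LoRA-style tuning does,
triggers exactly this case; the hook forces the embedding output to require
grad. ToyModel and its layer sizes below are illustrative, not from the patch.

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    class ToyModel(nn.Module):
        """Stand-in for a transformer with one checkpointed block."""
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(100, 16)
            self.layer = nn.Linear(16, 16)

        def get_input_embeddings(self):
            return self.embed

        def forward(self, input_ids):
            hidden = self.embed(input_ids)
            # recompute this segment during backward instead of storing activations;
            # the reentrant implementation needs an input with requires_grad=True
            hidden = checkpoint(self.layer, hidden, use_reentrant=True)
            return hidden.sum()

    model = ToyModel()
    model.embed.weight.requires_grad_(False)  # embeddings frozen, as under LoRA

    # Without this hook the checkpointed output has requires_grad=False, the
    # loss is detached from the graph, and loss.backward() raises an error.
    def make_inputs_require_grad(module, input, output):
        output.requires_grad_(True)

    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    loss = model(torch.randint(0, 100, (2, 8)))
    loss.backward()
    print(model.layer.weight.grad is not None)  # True: grads reach the trainable layer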