From 0cee6ad67ffb06f0d7165a0284e39f510a2abc36 Mon Sep 17 00:00:00 2001 From: hiyouga <hiyouga@buaa.edu.cn> Date: Thu, 15 Jun 2023 16:02:01 +0800 Subject: [PATCH] support baichuan model --- README.md | 1 + src/utils/common.py | 2 ++ src/utils/other.py | 8 +++++++- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index b24b2f6..0b7dd9a 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ - [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B) - [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B) +- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B) ## Supported Training Approaches diff --git a/src/utils/common.py b/src/utils/common.py index 38ae830..094516e 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -170,6 +170,8 @@ def load_pretrained( **config_kwargs ) tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token + if tokenizer.pad_token_id == 64000: + tokenizer.pad_token_id = 0 # for baichuan model (need fix) config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs) is_mergeable = True diff --git a/src/utils/other.py b/src/utils/other.py index 838b617..88bf081 100644 --- a/src/utils/other.py +++ b/src/utils/other.py @@ -83,7 +83,13 @@ def prepare_model_for_training( param.data = param.data.to(torch.float32) if use_gradient_checkpointing: - model.enable_input_require_grads() + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + model.gradient_checkpointing_enable() model.config.use_cache = False # turn off when gradient checkpointing is enabled