diff --git a/README.md b/README.md
index b24b2f6..0b7dd9a 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@
 
 - [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
 - [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
+- [baichuan](https://huggingface.co/baichuan-inc/baichuan-7B) (7B)
 
 ## Supported Training Approaches
 
diff --git a/src/utils/common.py b/src/utils/common.py
index 38ae830..094516e 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -170,6 +170,8 @@ def load_pretrained(
         **config_kwargs
     )
     tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
+    if tokenizer.pad_token_id == 64000:
+        tokenizer.pad_token_id = 0 # for baichuan model (need fix)
 
     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
diff --git a/src/utils/other.py b/src/utils/other.py
index 838b617..88bf081 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -83,7 +83,13 @@ def prepare_model_for_training(
             param.data = param.data.to(torch.float32)
 
     if use_gradient_checkpointing:
-        model.enable_input_require_grads()
+        if hasattr(model, "enable_input_require_grads"):
+            model.enable_input_require_grads()
+        else:
+            def make_inputs_require_grad(module, input, output):
+                output.requires_grad_(True)
+            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
         model.gradient_checkpointing_enable()
         model.config.use_cache = False # turn off when gradient checkpointing is enabled
 
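
A minimal standalone sketch of the two behaviors this patch touches, assuming the Hugging Face `transformers` API; the `baichuan-inc/baichuan-7B` model name is used only for illustration and this snippet is not part of the patch itself:

```python
# Sketch: remap baichuan's pad_token_id and enable gradient checkpointing
# with a fallback hook, mirroring the changes in common.py and other.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "baichuan-inc/baichuan-7B"  # illustrative checkpoint

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, trust_remote_code=True)
if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000:
    # The patch remaps baichuan's pad_token_id of 64000 to 0, the same id the
    # codebase already uses as padding when no pad token is defined.
    tokenizer.pad_token_id = 0

model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)

# Gradient checkpointing needs the embedding output to require grad when the
# embedding weights themselves are frozen (e.g. adapter/LoRA fine-tuning).
if hasattr(model, "enable_input_require_grads"):
    # Models with this helper can enable it directly.
    model.enable_input_require_grads()
else:
    # Fallback for custom models that lack the helper: a forward hook on the
    # input embeddings marks their output as requiring grad.
    def make_inputs_require_grad(module, inputs, output):
        output.requires_grad_(True)

    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

model.gradient_checkpointing_enable()
model.config.use_cache = False  # caching is incompatible with gradient checkpointing
```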