diff --git a/README.md b/README.md
index ea162cb..3579c81 100644
--- a/README.md
+++ b/README.md
@@ -5,9 +5,65 @@
 ![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Efficient-Tuning)
 ![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue)
 
+## Changelog
+
+[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
+
+## Supported Models
+
+- [LLaMA](https://github.com/facebookresearch/llama) (7B, 13B, 33B, 65B)
+- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M, 1.1B, 1.7B, 3B, 7.1B, 176B)
+
+## Supported Training Approach
+
+- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
+  - Full-parameter training
+  - Selected-parameter training
+  - [LoRA](https://arxiv.org/abs/2106.09685)
+- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
+  - Full-parameter training
+  - Selected-parameter training
+  - [LoRA](https://arxiv.org/abs/2106.09685)
+- [RLHF](https://arxiv.org/abs/2203.02155)
+  - [LoRA](https://arxiv.org/abs/2106.09685)
+
+## Provided Datasets
+
+- For pre-training:
+  - [Wiki Demo](data/wiki_demo.txt)
+- For supervised fine-tuning:
+  - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca)
+  - [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca)
+  - [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+  - [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN)
+  - [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN)
+  - [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN)
+  - [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M)
+  - [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M)
+  - [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M)
+  - [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset)
+  - [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M)
+  - [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k)
+  - [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT)
+  - [Web QA (Chinese)](https://huggingface.co/datasets/suolyer/webqa)
+  - [UltraChat](https://github.com/thunlp/UltraChat)
+- For reward model training:
+  - [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf)
+  - [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+  - [GPT-4 Generated Data (Chinese)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM)
+
+Please refer to [data/README.md](data/README.md) for details.
+
+Some datasets require confirmation before using them, so we recommend logging in with your HuggingFace account using these commands.
+
+```bash
+pip install --upgrade huggingface_hub
+huggingface-cli login
+```
+
 ## Requirement
 
-- Python 3.8+ and PyTorch 1.13.1
+- Python 3.8+ and PyTorch 1.13.1+
 - 🤗Transformers, Datasets, Accelerate, PEFT and TRL
 - protobuf, cpm_kernels and sentencepiece
 - jieba, rouge_chinese and nltk (used at evaluation)
@@ -36,10 +92,10 @@ pip install -r requirements.txt
 ### LLaMA Weights Preparation
 
 1. Download the weights of the LLaMA models.
-2. Convert them to HF format using this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py)
+2. Convert them to HF format using the following command.
 
-```python
-python convert_llama_weights_to_hf.py \
+```bash
+python -m transformers.models.llama.convert_llama_weights_to_hf \
     --input_dir path_to_llama_weights --model_size 7B --output_dir path_to_llama_model
 ```
 
@@ -177,7 +233,11 @@ python src/export_model.py \
 
 ## License
 
-This repository is licensed under the [Apache-2.0 License](LICENSE). Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) to use the LLaMA model.
+This repository is licensed under the [Apache-2.0 License](LICENSE).
+
+Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) to use the LLaMA models.
+
+Please follow the [RAIL License](https://huggingface.co/spaces/bigscience/license) to use the BLOOM & BLOOMZ models.
 
 ## Citation
 
diff --git a/src/cli_demo.py b/src/cli_demo.py
index 416d660..fd24e99 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -1,10 +1,10 @@
 # coding=utf-8
-# Implements stream chat in command line for LLaMA fine-tuned with PEFT.
+# Implements stream chat in command line for fine-tuned models.
 # Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint
 
 
 import torch
-from utils import ModelArguments, auto_configure_device_map, load_pretrained
+from utils import ModelArguments, load_pretrained
 from transformers import HfArgumentParser
 
 
@@ -12,10 +12,11 @@ def main():
 
     parser = HfArgumentParser(ModelArguments)
     model_args, = parser.parse_args_into_dataclasses()
+    model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args)
     if torch.cuda.device_count() > 1:
-        from accelerate import dispatch_model
-        device_map = auto_configure_device_map(torch.cuda.device_count())
+        from accelerate import dispatch_model, infer_auto_device_map
+        device_map = infer_auto_device_map(model)
         model = dispatch_model(model, device_map)
     else:
         model = model.cuda()
@@ -47,7 +48,7 @@ def main():
         return response, history
 
     history = []
-    print("欢迎使用 LLaMA-7B 模型,输入内容即可对话,clear清空对话历史,stop终止程序")
+    print("欢迎使用 {} 模型,输入内容即可对话,clear清空对话历史,stop终止程序".format(model_name))
     while True:
         try:
             query = input("\nInput: ")
@@ -65,7 +66,7 @@ def main():
             continue
 
         response, history = predict(query, history)
-        print("LLaMA-7B:", response)
+        print("{}:".format(model_name), response)
 
 
 if __name__ == "__main__":
diff --git a/src/export_model.py b/src/export_model.py
index b62f10f..9ba361c 100644
--- a/src/export_model.py
+++ b/src/export_model.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Exports the fine-tuned LLaMA model.
+# Exports the fine-tuned model.
 # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model
 
 
diff --git a/src/train_ppo.py b/src/train_ppo.py
index f5b3c09..672dd8a 100644
--- a/src/train_ppo.py
+++ b/src/train_ppo.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements parameter-efficient PPO training of fine-tuned LLaMA.
+# Implements parameter-efficient PPO training of fine-tuned models.
 # This code is inspired by:
 # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
 
@@ -15,8 +15,8 @@ from utils import (
     prepare_data,
     load_pretrained,
     preprocess_data,
-    DataCollatorForLLaMA,
-    PPOTrainerForLLaMA,
+    DynamicDataCollatorWithPadding,
+    PPOPeftTrainer,
     LogCallback,
     plot_loss
 )
@@ -29,7 +29,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo")
-    data_collator = DataCollatorForLLaMA(tokenizer, model.pretrained_model)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer, model.pretrained_model)
 
     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,
@@ -52,7 +52,7 @@ def main():
     )
 
     # Initialize our Trainer
-    ppo_trainer = PPOTrainerForLLaMA(
+    ppo_trainer = PPOPeftTrainer(
         training_args=training_args,
         finetuning_args=finetuning_args,
         callbacks=[LogCallback()],
diff --git a/src/train_pt.py b/src/train_pt.py
index ae4adcf..af88cb6 100644
--- a/src/train_pt.py
+++ b/src/train_pt.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements several parameter-efficient pre-training method for LLaMA.
+# Implements several parameter-efficient pre-training method.
 # This code is inspired by
 # https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py
 
@@ -10,7 +10,7 @@ from utils import (
     prepare_args,
     prepare_data,
     preprocess_data,
-    DataCollatorForLLaMA,
+    DynamicDataCollatorWithPadding,
     PeftTrainer,
     LogCallback,
     plot_loss
@@ -24,7 +24,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt")
-    data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss)
 
     # Split the dataset
     if training_args.do_train:
diff --git a/src/train_rm.py b/src/train_rm.py
index 5f5afee..8b51f7b 100644
--- a/src/train_rm.py
+++ b/src/train_rm.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements parameter-efficient training of a reward model based on LLaMA.
+# Implements parameter-efficient training of reward models.
 # This code is inspired by:
 # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
 # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
@@ -10,8 +10,8 @@ from utils import (
     prepare_data,
     load_pretrained,
     preprocess_data,
-    PairwiseDataCollatorForLLaMA,
-    PairwiseTrainerForLLaMA,
+    PairwiseDataCollatorWithPadding,
+    PairwisePeftTrainer,
     LogCallback,
     plot_loss
 )
@@ -23,7 +23,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm")
-    data_collator = PairwiseDataCollatorForLLaMA(tokenizer, model.pretrained_model)
+    data_collator = PairwiseDataCollatorWithPadding(tokenizer, model.pretrained_model)
 
     training_args.remove_unused_columns = False # important for pairwise dataset
 
@@ -38,7 +38,7 @@ def main():
         trainer_kwargs = {"eval_dataset": dataset}
 
     # Initialize our Trainer
-    trainer = PairwiseTrainerForLLaMA(
+    trainer = PairwisePeftTrainer(
         finetuning_args=finetuning_args,
         model=model,
         args=training_args,
diff --git a/src/train_sft.py b/src/train_sft.py
index 16c91f0..29c593f 100644
--- a/src/train_sft.py
+++ b/src/train_sft.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements several parameter-efficient supervised fine-tuning method for LLaMA.
+# Implements several parameter-efficient supervised fine-tuning method.
 # This code is inspired by
 # https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py
 
@@ -9,8 +9,8 @@ from utils import (
     prepare_args,
     prepare_data,
     preprocess_data,
-    DataCollatorForLLaMA,
-    Seq2SeqTrainerForLLaMA,
+    DynamicDataCollatorWithPadding,
+    Seq2SeqPeftTrainer,
     ComputeMetrics,
     LogCallback,
     get_logits_processor,
@@ -25,7 +25,7 @@ def main():
     dataset = prepare_data(model_args, data_args)
     model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft")
     dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft")
-    data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss)
+    data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss)
 
     # Override the decoding parameters of Seq2SeqTrainer
     training_args.generation_max_length = training_args.generation_max_length if \
@@ -44,7 +44,7 @@ def main():
         trainer_kwargs = {"eval_dataset": dataset}
 
     # Initialize our Trainer
-    trainer = Seq2SeqTrainerForLLaMA(
+    trainer = Seq2SeqPeftTrainer(
         finetuning_args=finetuning_args,
         model=model,
         args=training_args,
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index d25810c..680975a 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -5,13 +5,13 @@ from .common import (
     preprocess_data
 )
 
-from .data_collator import DataCollatorForLLaMA
+from .data_collator import DynamicDataCollatorWithPadding
 
 from .peft_trainer import PeftTrainer, LogCallback
 
-from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA
-from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA
-from .ppo import PPOTrainerForLLaMA
+from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
+from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
+from .ppo import PPOPeftTrainer
 
 from .config import ModelArguments
-from .other import auto_configure_device_map, get_logits_processor, plot_loss
+from .other import get_logits_processor, plot_loss
diff --git a/src/utils/common.py b/src/utils/common.py
index 4ba573e..9137e54 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -7,9 +7,9 @@ from typing import List, Literal, Optional, Tuple
 
 import transformers
 from transformers import (
-    LlamaConfig,
-    LlamaForCausalLM,
-    LlamaTokenizer,
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoTokenizer,
     HfArgumentParser,
     Seq2SeqTrainingArguments
 )
@@ -151,7 +151,7 @@ def load_pretrained(
     assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \
         "RM and PPO training can only be performed with LoRA method."
 
-    tokenizer = LlamaTokenizer.from_pretrained(
+    tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
         padding_side="left"
@@ -173,13 +173,13 @@ def load_pretrained(
         config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit
         logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
 
-    config = LlamaConfig.from_pretrained(model_args.model_name_or_path)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
 
     # Load and prepare pretrained models (without valuehead).
-    model = LlamaForCausalLM.from_pretrained(
+    model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
-        torch_dtype=torch.float16, # the llama weights are float16 type
+        torch_dtype=torch.float16, # the model weights are float16 type
         **config_kwargs
     )
     model = prepare_model_for_training(model) if is_trainable else model
@@ -245,7 +245,7 @@ def prepare_args(
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
 
     if training_args.do_train and (not training_args.fp16):
-        logger.warning("We recommend enable fp16 mixed precision training for LLaMA.")
+        logger.warning("We recommend enable fp16 mixed precision training.")
 
     if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
         logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
diff --git a/src/utils/config.py b/src/utils/config.py
index 579905a..b4a7355 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -12,6 +12,12 @@ class DatasetAttr:
     file_name: Optional[str] = None
     file_sha1: Optional[str] = None
 
+    def __repr__(self) -> str:
+        if self.dataset_name is not None:
+            return self.dataset_name
+        else:
+            return self.file_name
+
     def __post_init__(self):
         self.prompt_column = "instruction"
         self.query_column = "input"
@@ -161,9 +167,11 @@ class FinetuningArguments:
         default=3,
         metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
     )
-    name_module_trainable: Optional[Literal["mlp", "qkv"]] = field(
+    name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
         default="mlp",
-        metadata={"help": "Name of trainable modules for Freeze fine-tuning."}
+        metadata={"help": "Name of trainable modules for Freeze fine-tuning. \
+                  LLaMA choices: [\"mlp\", \"self_attn\"], \
+                  BLOOM choices: [\"mlp\", \"self_attention\"]"}
     )
     lora_rank: Optional[int] = field(
         default=8,
@@ -171,7 +179,7 @@ class FinetuningArguments:
     )
     lora_alpha: Optional[float] = field(
         default=32.0,
-        metadata={"help": "The scale factor for LoRA fine-tuning. (similar with the learning rate)"}
+        metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."}
     )
     lora_dropout: Optional[float] = field(
         default=0.1,
@@ -179,7 +187,9 @@ class FinetuningArguments:
     )
     lora_target: Optional[str] = field(
         default="q_proj,v_proj",
-        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."}
+        metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \
+                  LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"mlp\"], \
+                  BLOOM choices: [\"query_key_value\", \"dense\", \"mlp\"]"}
     )
 
     def __post_init__(self):
@@ -191,11 +201,7 @@ class FinetuningArguments:
         else: # fine-tuning the first n layers if num_layer_trainable < 0
             trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
 
-        if self.name_module_trainable == "mlp":
-            self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids]
-        elif self.name_module_trainable == "qkv":
-            self.trainable_layers = ["layers.{:d}.self_attn.{}".format(idx, proj) \
-                                     for proj in ["k_proj", "q_proj", "v_proj", "o_proj"] for idx in trainable_layer_ids]
+        self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids]
 
         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
 
diff --git a/src/utils/data_collator.py b/src/utils/data_collator.py
index 2932f2b..6a6de42 100644
--- a/src/utils/data_collator.py
+++ b/src/utils/data_collator.py
@@ -9,9 +9,9 @@ from transformers.tokenization_utils import PreTrainedTokenizer
 from .other import IGNORE_INDEX
 
 
-class DataCollatorForLLaMA(DataCollatorWithPadding):
+class DynamicDataCollatorWithPadding(DataCollatorWithPadding):
     r"""
-    Data collator for LLaMA. It is capable of dynamically padding for batched data.
+    Inherits DataCollatorWithPadding. It is capable of dynamically padding for batched data.
     """
     def __init__(
             self,
diff --git a/src/utils/other.py b/src/utils/other.py
index 8a30dd3..2008fb8 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -75,7 +75,7 @@ def prepare_model_for_training(
         model: PreTrainedModel,
         output_embedding_layer_name: Optional[str] = "lm_head",
         use_gradient_checkpointing: Optional[bool] = True,
-        layer_norm_names: Optional[List[str]] = ["norm"] # for LLaMA setting
+        layer_norm_names: Optional[List[str]] = ["norm", "ln_f"] # for LLaMA and BLOOM setting
 ) -> PreTrainedModel:
 
     for name, param in model.named_parameters():
@@ -143,29 +143,6 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -
     model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
 
 
-def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
-    r"""
-    Configures device map for LLaMA.
-
-    Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/dev_multi_gpu/utils.py#L8
-    """
-    num_layers = 28
-    layers_per_gpu = 30 / num_gpus
-    device_map = {"model.embed_tokens": 0, "model.norm": 0, "lm_head": 0}
-    added_layers = 2
-    target_gpu = 0
-
-    for i in range(num_layers):
-        if added_layers >= layers_per_gpu:
-            target_gpu += 1
-            added_layers = 0
-        assert target_gpu < num_gpus
-        device_map[f"model.layers.{i}"] = target_gpu
-        added_layers += 1
-
-    return device_map
-
-
 def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]:
     """
     EMA implementation according to TensorBoard.
diff --git a/src/utils/pairwise.py b/src/utils/pairwise.py
index 3c2aa21..317599b 100644
--- a/src/utils/pairwise.py
+++ b/src/utils/pairwise.py
@@ -1,7 +1,7 @@
 import torch
 from typing import Dict, Sequence, Union
 
-from .data_collator import DataCollatorForLLaMA
+from .data_collator import DynamicDataCollatorWithPadding
 
 from .peft_trainer import PeftTrainer
 
@@ -10,7 +10,7 @@ from .other import get_logger
 logger = get_logger(__name__)
 
 
-class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
+class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
     r"""
     Data collator for pairwise data.
     """
@@ -26,7 +26,7 @@ class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA):
         return super().__call__(features)
 
 
-class PairwiseTrainerForLLaMA(PeftTrainer):
+class PairwisePeftTrainer(PeftTrainer):
     r"""
     Inherits PeftTrainer to compute pairwise loss.
     """
diff --git a/src/utils/ppo.py b/src/utils/ppo.py
index 7a69c43..e279c19 100644
--- a/src/utils/ppo.py
+++ b/src/utils/ppo.py
@@ -58,7 +58,7 @@ def cast_layernorm_dtype(
     return model, layer_norm_state_dict
 
 
-class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
+class PPOPeftTrainer(PPOTrainer, PeftTrainer):
     r"""
     Inherits PPOTrainer.
     """
@@ -130,7 +130,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer):
                 unwrapped_model.gradient_checkpointing_disable()
                 unwrapped_model.config.use_cache = True
 
-                # Get response from LLaMA
+                # Get response from model
                 query_tensors: torch.Tensor = batch["input_ids"]
                 response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs)
 
diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py
index 4a48393..47bbcc1 100644
--- a/src/utils/seq2seq.py
+++ b/src/utils/seq2seq.py
@@ -22,7 +22,7 @@ logger = get_logger(__name__)
 @dataclass
 class ComputeMetrics:
     r"""
-    Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForLLaMA.
+    Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
 
     Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307
     """
@@ -62,7 +62,7 @@ class ComputeMetrics:
         return {k: float(np.mean(v)) for k, v in score_dict.items()}
 
 
-class Seq2SeqTrainerForLLaMA(PeftTrainer):
+class Seq2SeqPeftTrainer(PeftTrainer):
     r"""
     Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
     """
diff --git a/src/web_demo.py b/src/web_demo.py
index 5bbed0b..5129ea8 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Implements user interface in browser for LLaMA fine-tuned with PEFT.
+# Implements user interface in browser for fine-tuned models.
 # Usage: python web_demo.py --checkpoint_dir path_to_checkpoint
 
 
@@ -7,7 +7,7 @@ import torch
 import mdtex2html
 import gradio as gr
 
-from utils import ModelArguments, auto_configure_device_map, load_pretrained
+from utils import ModelArguments, load_pretrained
 from transformers import HfArgumentParser
 from transformers.utils.versions import require_version
 
@@ -17,8 +17,8 @@ parser = HfArgumentParser(ModelArguments)
 model_args, = parser.parse_args_into_dataclasses()
 model, tokenizer = load_pretrained(model_args)
 if torch.cuda.device_count() > 1:
-    from accelerate import dispatch_model
-    device_map = auto_configure_device_map(torch.cuda.device_count())
+    from accelerate import dispatch_model, infer_auto_device_map
+    device_map = infer_auto_device_map(model)
     model = dispatch_model(model, device_map)
 else:
     model = model.cuda()
@@ -111,7 +111,7 @@ def reset_state():
 
 
 with gr.Blocks() as demo:
-    gr.HTML("""<h1 align="center">ChatGLM-Efficient-Tuning</h1>""")
+    gr.HTML("""<h1 align="center">LLaMA-Efficient-Tuning</h1>""")
 
     chatbot = gr.Chatbot()
     with gr.Row():