From 740a5daf5634f70a61b41fa8a31ee4a587fa03f3 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Wed, 31 May 2023 16:54:06 +0800 Subject: [PATCH] support BLOOM models --- README.md | 70 ++++++++++++++++++++++++++++++++++++++++++---- src/cli_demo.py | 13 +++++---- src/export_model.py | 2 +- src/train_ppo.py | 10 +++---- src/train_pt.py | 6 ++-- src/train_rm.py | 10 +++---- src/train_sft.py | 10 +++---- src/utils/__init__.py | 10 +++---- src/utils/common.py | 16 +++++------ src/utils/config.py | 24 ++++++++++------ src/utils/data_collator.py | 4 +-- src/utils/other.py | 25 +---------------- src/utils/pairwise.py | 6 ++-- src/utils/ppo.py | 4 +-- src/utils/seq2seq.py | 4 +-- src/web_demo.py | 10 +++---- 16 files changed, 134 insertions(+), 90 deletions(-) diff --git a/README.md b/README.md index ea162cb..3579c81 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,65 @@ ![GitHub last commit](https://img.shields.io/github/last-commit/hiyouga/LLaMA-Efficient-Tuning) ![GitHub pull request](https://img.shields.io/badge/PRs-welcome-blue) +## Changelog + +[23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model. + +## Supported Models + +- [LLaMA](https://github.com/facebookresearch/llama) (7B, 13B, 33B, 65B) +- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M, 1.1B, 1.7B, 3B, 7.1B, 176B) + +## Supported Training Approach + +- [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf) + - Full-parameter training + - Selected-parameter training + - [LoRA](https://arxiv.org/abs/2106.09685) +- [Supervised fine-tuning](https://arxiv.org/abs/2109.01652) + - Full-parameter training + - Selected-parameter training + - [LoRA](https://arxiv.org/abs/2106.09685) +- [RLHF](https://arxiv.org/abs/2203.02155) + - [LoRA](https://arxiv.org/abs/2106.09685) + +## Provided Datasets + +- For pre-training: + - [Wiki Demo](data/wiki_demo.txt) +- For supervised fine-tuning: + - [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) + - [Stanford Alpaca (Chinese)](https://github.com/ymcui/Chinese-LLaMA-Alpaca) + - [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) + - [BELLE 2M](https://huggingface.co/datasets/BelleGroup/train_2M_CN) + - [BELLE 1M](https://huggingface.co/datasets/BelleGroup/train_1M_CN) + - [BELLE 0.5M](https://huggingface.co/datasets/BelleGroup/train_0.5M_CN) + - [BELLE Dialogue 0.4M](https://huggingface.co/datasets/BelleGroup/generated_chat_0.4M) + - [BELLE School Math 0.25M](https://huggingface.co/datasets/BelleGroup/school_math_0.25M) + - [BELLE Multiturn Chat 0.8M](https://huggingface.co/datasets/BelleGroup/multiturn_chat_0.8M) + - [Guanaco Dataset](https://huggingface.co/datasets/JosephusCheung/GuanacoDataset) + - [Firefly 1.1M](https://huggingface.co/datasets/YeungNLP/firefly-train-1.1M) + - [CodeAlpaca 20k](https://huggingface.co/datasets/sahil2801/CodeAlpaca-20k) + - [Alpaca CoT](https://huggingface.co/datasets/QingyiSi/Alpaca-CoT) + - [Web QA (Chinese)](https://huggingface.co/datasets/suolyer/webqa) + - [UltraChat](https://github.com/thunlp/UltraChat) +- For reward model training: + - [HH-RLHF](https://huggingface.co/datasets/Anthropic/hh-rlhf) + - [GPT-4 Generated Data](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) + - [GPT-4 Generated Data 
(Chinese)](https://github.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM) + +Please refer to [data/README.md](data/README.md) for details. + +Some datasets require confirmation before using them, so we recommend logging in with your HuggingFace account using these commands. + +```bash +pip install --upgrade huggingface_hub +huggingface-cli login +``` + ## Requirement -- Python 3.8+ and PyTorch 1.13.1 +- Python 3.8+ and PyTorch 1.13.1+ - 🤗Transformers, Datasets, Accelerate, PEFT and TRL - protobuf, cpm_kernels and sentencepiece - jieba, rouge_chinese and nltk (used at evaluation) @@ -36,10 +92,10 @@ pip install -r requirements.txt ### LLaMA Weights Preparation 1. Download the weights of the LLaMA models. -2. Convert them to HF format using this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py) +2. Convert them to HF format using the following command. -```python -python convert_llama_weights_to_hf.py \ +```bash +python -m transformers.models.llama.convert_llama_weights_to_hf \ --input_dir path_to_llama_weights --model_size 7B --output_dir path_to_llama_model ``` @@ -177,7 +233,11 @@ python src/export_model.py \ ## License -This repository is licensed under the [Apache-2.0 License](LICENSE). Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) to use the LLaMA model. +This repository is licensed under the [Apache-2.0 License](LICENSE). + +Please follow the [Model Card](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) to use the LLaMA models. + +Please follow the [RAIL License](https://huggingface.co/spaces/bigscience/license) to use the BLOOM & BLOOMZ models. ## Citation diff --git a/src/cli_demo.py b/src/cli_demo.py index 416d660..fd24e99 100644 --- a/src/cli_demo.py +++ b/src/cli_demo.py @@ -1,10 +1,10 @@ # coding=utf-8 -# Implements stream chat in command line for LLaMA fine-tuned with PEFT. +# Implements stream chat in command line for fine-tuned models. # Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint import torch -from utils import ModelArguments, auto_configure_device_map, load_pretrained +from utils import ModelArguments, load_pretrained from transformers import HfArgumentParser @@ -12,10 +12,11 @@ def main(): parser = HfArgumentParser(ModelArguments) model_args, = parser.parse_args_into_dataclasses() + model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA" model, tokenizer = load_pretrained(model_args) if torch.cuda.device_count() > 1: - from accelerate import dispatch_model - device_map = auto_configure_device_map(torch.cuda.device_count()) + from accelerate import dispatch_model, infer_auto_device_map + device_map = infer_auto_device_map(model) model = dispatch_model(model, device_map) else: model = model.cuda() @@ -47,7 +48,7 @@ def main(): return response, history history = [] - print("欢迎使用 LLaMA-7B 模型,输入内容即可对话,clear清空对话历史,stop终止程序") + print("欢迎使用 {} 模型,输入内容即可对话,clear清空对话历史,stop终止程序".format(model_name)) while True: try: query = input("\nInput: ") @@ -65,7 +66,7 @@ def main(): continue response, history = predict(query, history) - print("LLaMA-7B:", response) + print("{}:".format(model_name), response) if __name__ == "__main__": diff --git a/src/export_model.py b/src/export_model.py index b62f10f..9ba361c 100644 --- a/src/export_model.py +++ b/src/export_model.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Exports the fine-tuned LLaMA model. +# Exports the fine-tuned model. 
# Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model diff --git a/src/train_ppo.py b/src/train_ppo.py index f5b3c09..672dd8a 100644 --- a/src/train_ppo.py +++ b/src/train_ppo.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Implements parameter-efficient PPO training of fine-tuned LLaMA. +# Implements parameter-efficient PPO training of fine-tuned models. # This code is inspired by: # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py @@ -15,8 +15,8 @@ from utils import ( prepare_data, load_pretrained, preprocess_data, - DataCollatorForLLaMA, - PPOTrainerForLLaMA, + DynamicDataCollatorWithPadding, + PPOPeftTrainer, LogCallback, plot_loss ) @@ -29,7 +29,7 @@ def main(): dataset = prepare_data(model_args, data_args) model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="ppo") dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="ppo") - data_collator = DataCollatorForLLaMA(tokenizer, model.pretrained_model) + data_collator = DynamicDataCollatorWithPadding(tokenizer, model.pretrained_model) ppo_config = PPOConfig( model_name=model_args.model_name_or_path, @@ -52,7 +52,7 @@ def main(): ) # Initialize our Trainer - ppo_trainer = PPOTrainerForLLaMA( + ppo_trainer = PPOPeftTrainer( training_args=training_args, finetuning_args=finetuning_args, callbacks=[LogCallback()], diff --git a/src/train_pt.py b/src/train_pt.py index ae4adcf..af88cb6 100644 --- a/src/train_pt.py +++ b/src/train_pt.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Implements several parameter-efficient pre-training method for LLaMA. +# Implements several parameter-efficient pre-training method. # This code is inspired by # https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/language-modeling/run_clm.py @@ -10,7 +10,7 @@ from utils import ( prepare_args, prepare_data, preprocess_data, - DataCollatorForLLaMA, + DynamicDataCollatorWithPadding, PeftTrainer, LogCallback, plot_loss @@ -24,7 +24,7 @@ def main(): dataset = prepare_data(model_args, data_args) model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="pt") dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="pt") - data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss) + data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss) # Split the dataset if training_args.do_train: diff --git a/src/train_rm.py b/src/train_rm.py index 5f5afee..8b51f7b 100644 --- a/src/train_rm.py +++ b/src/train_rm.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Implements parameter-efficient training of a reward model based on LLaMA. +# Implements parameter-efficient training of reward models. 
# This code is inspired by: # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py @@ -10,8 +10,8 @@ from utils import ( prepare_data, load_pretrained, preprocess_data, - PairwiseDataCollatorForLLaMA, - PairwiseTrainerForLLaMA, + PairwiseDataCollatorWithPadding, + PairwisePeftTrainer, LogCallback, plot_loss ) @@ -23,7 +23,7 @@ def main(): dataset = prepare_data(model_args, data_args) model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="rm") dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="rm") - data_collator = PairwiseDataCollatorForLLaMA(tokenizer, model.pretrained_model) + data_collator = PairwiseDataCollatorWithPadding(tokenizer, model.pretrained_model) training_args.remove_unused_columns = False # important for pairwise dataset @@ -38,7 +38,7 @@ def main(): trainer_kwargs = {"eval_dataset": dataset} # Initialize our Trainer - trainer = PairwiseTrainerForLLaMA( + trainer = PairwisePeftTrainer( finetuning_args=finetuning_args, model=model, args=training_args, diff --git a/src/train_sft.py b/src/train_sft.py index 16c91f0..29c593f 100644 --- a/src/train_sft.py +++ b/src/train_sft.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Implements several parameter-efficient supervised fine-tuning method for LLaMA. +# Implements several parameter-efficient supervised fine-tuning method. # This code is inspired by # https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py @@ -9,8 +9,8 @@ from utils import ( prepare_args, prepare_data, preprocess_data, - DataCollatorForLLaMA, - Seq2SeqTrainerForLLaMA, + DynamicDataCollatorWithPadding, + Seq2SeqPeftTrainer, ComputeMetrics, LogCallback, get_logits_processor, @@ -25,7 +25,7 @@ def main(): dataset = prepare_data(model_args, data_args) model, tokenizer = load_pretrained(model_args, finetuning_args, training_args.do_train, stage="sft") dataset = preprocess_data(dataset, tokenizer, data_args, training_args, stage="sft") - data_collator = DataCollatorForLLaMA(tokenizer, model, data_args.ignore_pad_token_for_loss) + data_collator = DynamicDataCollatorWithPadding(tokenizer, model, data_args.ignore_pad_token_for_loss) # Override the decoding parameters of Seq2SeqTrainer training_args.generation_max_length = training_args.generation_max_length if \ @@ -44,7 +44,7 @@ def main(): trainer_kwargs = {"eval_dataset": dataset} # Initialize our Trainer - trainer = Seq2SeqTrainerForLLaMA( + trainer = Seq2SeqPeftTrainer( finetuning_args=finetuning_args, model=model, args=training_args, diff --git a/src/utils/__init__.py b/src/utils/__init__.py index d25810c..680975a 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -5,13 +5,13 @@ from .common import ( preprocess_data ) -from .data_collator import DataCollatorForLLaMA +from .data_collator import DynamicDataCollatorWithPadding from .peft_trainer import PeftTrainer, LogCallback -from .seq2seq import ComputeMetrics, Seq2SeqTrainerForLLaMA -from .pairwise import PairwiseDataCollatorForLLaMA, PairwiseTrainerForLLaMA -from .ppo import PPOTrainerForLLaMA +from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer +from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer +from .ppo import PPOPeftTrainer from .config import ModelArguments -from .other import auto_configure_device_map, get_logits_processor, plot_loss +from .other import 
get_logits_processor, plot_loss diff --git a/src/utils/common.py b/src/utils/common.py index 4ba573e..9137e54 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -7,9 +7,9 @@ from typing import List, Literal, Optional, Tuple import transformers from transformers import ( - LlamaConfig, - LlamaForCausalLM, - LlamaTokenizer, + AutoConfig, + AutoModelForCausalLM, + AutoTokenizer, HfArgumentParser, Seq2SeqTrainingArguments ) @@ -151,7 +151,7 @@ def load_pretrained( assert stage in ["pt", "sft"] or finetuning_args.finetuning_type == "lora", \ "RM and PPO training can only be performed with LoRA method." - tokenizer = LlamaTokenizer.from_pretrained( + tokenizer = AutoTokenizer.from_pretrained( model_args.model_name_or_path, use_fast=model_args.use_fast_tokenizer, padding_side="left" @@ -173,13 +173,13 @@ def load_pretrained( config_kwargs["device_map"] = "auto" # it should not be specified outside of load_in_8bit logger.info("Quantized model to {} bit.".format(model_args.quantization_bit)) - config = LlamaConfig.from_pretrained(model_args.model_name_or_path) + config = AutoConfig.from_pretrained(model_args.model_name_or_path) # Load and prepare pretrained models (without valuehead). - model = LlamaForCausalLM.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( model_args.model_name_or_path, config=config, - torch_dtype=torch.float16, # the llama weights are float16 type + torch_dtype=torch.float16, # the model weights are float16 type **config_kwargs ) model = prepare_model_for_training(model) if is_trainable else model @@ -245,7 +245,7 @@ def prepare_args( logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") if training_args.do_train and (not training_args.fp16): - logger.warning("We recommend enable fp16 mixed precision training for LLaMA.") + logger.warning("We recommend enable fp16 mixed precision training.") if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None: logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.") diff --git a/src/utils/config.py b/src/utils/config.py index 579905a..b4a7355 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -12,6 +12,12 @@ class DatasetAttr: file_name: Optional[str] = None file_sha1: Optional[str] = None + def __repr__(self) -> str: + if self.dataset_name is not None: + return self.dataset_name + else: + return self.file_name + def __post_init__(self): self.prompt_column = "instruction" self.query_column = "input" @@ -161,9 +167,11 @@ class FinetuningArguments: default=3, metadata={"help": "Number of trainable layers for Freeze fine-tuning."} ) - name_module_trainable: Optional[Literal["mlp", "qkv"]] = field( + name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field( default="mlp", - metadata={"help": "Name of trainable modules for Freeze fine-tuning."} + metadata={"help": "Name of trainable modules for Freeze fine-tuning. \ + LLaMA choices: [\"mlp\", \"self_attn\"], \ + BLOOM choices: [\"mlp\", \"self_attention\"]"} ) lora_rank: Optional[int] = field( default=8, @@ -171,7 +179,7 @@ class FinetuningArguments: ) lora_alpha: Optional[float] = field( default=32.0, - metadata={"help": "The scale factor for LoRA fine-tuning. 
(similar with the learning rate)"} + metadata={"help": "The scale factor for LoRA fine-tuning (similar with the learning rate)."} ) lora_dropout: Optional[float] = field( default=0.1, @@ -179,7 +187,9 @@ class FinetuningArguments: ) lora_target: Optional[str] = field( default="q_proj,v_proj", - metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules."} + metadata={"help": "Name(s) of target modules to apply LoRA. Use comma to separate multiple modules. \ + LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"mlp\"], \ + BLOOM choices: [\"query_key_value\", \"dense\", \"mlp\"]"} ) def __post_init__(self): @@ -191,11 +201,7 @@ class FinetuningArguments: else: # fine-tuning the first n layers if num_layer_trainable < 0 trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] - if self.name_module_trainable == "mlp": - self.trainable_layers = ["layers.{:d}.mlp".format(idx) for idx in trainable_layer_ids] - elif self.name_module_trainable == "qkv": - self.trainable_layers = ["layers.{:d}.self_attn.{}".format(idx, proj) \ - for proj in ["k_proj", "q_proj", "v_proj", "o_proj"] for idx in trainable_layer_ids] + self.trainable_layers = ["layers.{:d}.{}".format(idx, self.name_module_trainable) for idx in trainable_layer_ids] assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method." diff --git a/src/utils/data_collator.py b/src/utils/data_collator.py index 2932f2b..6a6de42 100644 --- a/src/utils/data_collator.py +++ b/src/utils/data_collator.py @@ -9,9 +9,9 @@ from transformers.tokenization_utils import PreTrainedTokenizer from .other import IGNORE_INDEX -class DataCollatorForLLaMA(DataCollatorWithPadding): +class DynamicDataCollatorWithPadding(DataCollatorWithPadding): r""" - Data collator for LLaMA. It is capable of dynamically padding for batched data. + Inherits DataCollatorWithPadding. It is capable of dynamically padding for batched data. """ def __init__( self, diff --git a/src/utils/other.py b/src/utils/other.py index 8a30dd3..2008fb8 100644 --- a/src/utils/other.py +++ b/src/utils/other.py @@ -75,7 +75,7 @@ def prepare_model_for_training( model: PreTrainedModel, output_embedding_layer_name: Optional[str] = "lm_head", use_gradient_checkpointing: Optional[bool] = True, - layer_norm_names: Optional[List[str]] = ["norm"] # for LLaMA setting + layer_norm_names: Optional[List[str]] = ["norm", "ln_f"] # for LLaMA and BLOOM setting ) -> PreTrainedModel: for name, param in model.named_parameters(): @@ -143,29 +143,6 @@ def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) - model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"])) -def auto_configure_device_map(num_gpus: int) -> Dict[str, int]: - r""" - Configures device map for LLaMA. - - Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/dev_multi_gpu/utils.py#L8 - """ - num_layers = 28 - layers_per_gpu = 30 / num_gpus - device_map = {"model.embed_tokens": 0, "model.norm": 0, "lm_head": 0} - added_layers = 2 - target_gpu = 0 - - for i in range(num_layers): - if added_layers >= layers_per_gpu: - target_gpu += 1 - added_layers = 0 - assert target_gpu < num_gpus - device_map[f"model.layers.{i}"] = target_gpu - added_layers += 1 - - return device_map - - def smooth(scalars: List[float], weight: Optional[float] = 0.95) -> List[float]: """ EMA implementation according to TensorBoard. 
diff --git a/src/utils/pairwise.py b/src/utils/pairwise.py index 3c2aa21..317599b 100644 --- a/src/utils/pairwise.py +++ b/src/utils/pairwise.py @@ -1,7 +1,7 @@ import torch from typing import Dict, Sequence, Union -from .data_collator import DataCollatorForLLaMA +from .data_collator import DynamicDataCollatorWithPadding from .peft_trainer import PeftTrainer @@ -10,7 +10,7 @@ from .other import get_logger logger = get_logger(__name__) -class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA): +class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding): r""" Data collator for pairwise data. """ @@ -26,7 +26,7 @@ class PairwiseDataCollatorForLLaMA(DataCollatorForLLaMA): return super().__call__(features) -class PairwiseTrainerForLLaMA(PeftTrainer): +class PairwisePeftTrainer(PeftTrainer): r""" Inherits PeftTrainer to compute pairwise loss. """ diff --git a/src/utils/ppo.py b/src/utils/ppo.py index 7a69c43..e279c19 100644 --- a/src/utils/ppo.py +++ b/src/utils/ppo.py @@ -58,7 +58,7 @@ def cast_layernorm_dtype( return model, layer_norm_state_dict -class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): +class PPOPeftTrainer(PPOTrainer, PeftTrainer): r""" Inherits PPOTrainer. """ @@ -130,7 +130,7 @@ class PPOTrainerForLLaMA(PPOTrainer, PeftTrainer): unwrapped_model.gradient_checkpointing_disable() unwrapped_model.config.use_cache = True - # Get response from LLaMA + # Get response from model query_tensors: torch.Tensor = batch["input_ids"] response_tensors = self.generate(batch, length_sampler=output_length_sampler, return_prompt=False, **gen_kwargs) diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py index 4a48393..47bbcc1 100644 --- a/src/utils/seq2seq.py +++ b/src/utils/seq2seq.py @@ -22,7 +22,7 @@ logger = get_logger(__name__) @dataclass class ComputeMetrics: r""" - Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForLLaMA. + Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/0c2806fea82683349194e21996dd6b3acc3c265b/ptuning/main.py#L307 """ @@ -62,7 +62,7 @@ class ComputeMetrics: return {k: float(np.mean(v)) for k, v in score_dict.items()} -class Seq2SeqTrainerForLLaMA(PeftTrainer): +class Seq2SeqPeftTrainer(PeftTrainer): r""" Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE. """ diff --git a/src/web_demo.py b/src/web_demo.py index 5bbed0b..5129ea8 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -1,5 +1,5 @@ # coding=utf-8 -# Implements user interface in browser for LLaMA fine-tuned with PEFT. +# Implements user interface in browser for fine-tuned models. # Usage: python web_demo.py --checkpoint_dir path_to_checkpoint @@ -7,7 +7,7 @@ import torch import mdtex2html import gradio as gr -from utils import ModelArguments, auto_configure_device_map, load_pretrained +from utils import ModelArguments, load_pretrained from transformers import HfArgumentParser from transformers.utils.versions import require_version @@ -17,8 +17,8 @@ parser = HfArgumentParser(ModelArguments) model_args, = parser.parse_args_into_dataclasses() model, tokenizer = load_pretrained(model_args) if torch.cuda.device_count() > 1: - from accelerate import dispatch_model - device_map = auto_configure_device_map(torch.cuda.device_count()) + from accelerate import dispatch_model, infer_auto_device_map + device_map = infer_auto_device_map(model) model = dispatch_model(model, device_map) else: model = model.cuda() @@ -111,7 +111,7 @@ def reset_state(): with gr.Blocks() as demo: - gr.HTML("""
<h1 align="center">ChatGLM-Efficient-Tuning</h1>""")
+    gr.HTML("""<h1 align="center">LLaMA-Efficient-Tuning</h1>""")
     chatbot = gr.Chatbot()
     with gr.Row():