support QLoRA

hiyouga · 2 years ago · main
commit 3b9eee8cd2

Files changed:
  1. README.md (11 changed lines)
  2. src/utils/common.py (44 changed lines)
  3. src/utils/config.py (5 changed lines)

--- a/README.md
+++ b/README.md
@@ -9,12 +9,14 @@
 ## Changelog
 
+[23/06/03] Now we support quantized training and inference (aka QLoRA). Try the `--quantization_bit 4/8` argument to work with quantized models. (experimental feature)
+
 [23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
 
 ## Supported Models
 
-- [LLaMA](https://github.com/facebookresearch/llama) (7B, 13B, 33B, 65B)
-- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M, 1.1B, 1.7B, 3B, 7.1B, 176B)
+- [LLaMA](https://github.com/facebookresearch/llama) (7B/13B/33B/65B)
+- [BLOOM](https://huggingface.co/bigscience/bloom) & [BLOOMZ](https://huggingface.co/bigscience/bloomz) (560M/1.1B/1.7B/3B/7.1B/176B)
 
 ## Supported Training Approaches
@@ -22,12 +24,15 @@
   - Full-parameter training
   - Partial-parameter training
   - [LoRA](https://arxiv.org/abs/2106.09685)
+  - [QLoRA](https://arxiv.org/abs/2305.14314)
 - [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
   - Full-parameter training
   - Partial-parameter training
   - [LoRA](https://arxiv.org/abs/2106.09685)
+  - [QLoRA](https://arxiv.org/abs/2305.14314)
 - [RLHF](https://arxiv.org/abs/2203.02155)
   - [LoRA](https://arxiv.org/abs/2106.09685)
+  - [QLoRA](https://arxiv.org/abs/2305.14314)
 
 ## Provided Datasets
@@ -209,6 +214,8 @@ CUDA_VISIBLE_DEVICES=0 python src/train_sft.py \
     --predict_with_generate
 ```
 
+We recommend using `--per_device_eval_batch_size=1` and `--max_target_length 128` for INT8 evaluation.
+
 ### CLI Demo
 
 ```bash

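The new changelog entry advertises `--quantization_bit 4/8` as a plain command-line flag (e.g. `CUDA_VISIBLE_DEVICES=0 python src/train_sft.py ... --quantization_bit 4`). The diff below reads it, together with `quantization_type` and `double_quantization`, from the parsed model arguments. As a rough illustration of how such a flag reaches the loading code, here is a hypothetical, minimal dataclass parsed with `HfArgumentParser`; it is a sketch only, not the repo's actual `ModelArguments` definition, and the defaults shown are assumptions.

```python
# Hypothetical sketch (not the repo's actual ModelArguments): how a
# --quantization_bit flag can be declared and parsed with HfArgumentParser.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class QuantizationArguments:
    quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the model (4 or 8)."}
    )
    quantization_type: Optional[str] = field(
        default="nf4",
        metadata={"help": "Quantization data type for 4-bit training: fp4 or nf4 (assumed default)."}
    )
    double_quantization: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to use nested (double) quantization in 4-bit training (assumed default)."}
    )


if __name__ == "__main__":
    # e.g. python sketch.py --quantization_bit 4 --quantization_type nf4
    args, = HfArgumentParser(QuantizationArguments).parse_args_into_dataclasses()
    print(args)
```
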
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -11,7 +11,8 @@ from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
     HfArgumentParser,
-    Seq2SeqTrainingArguments
+    Seq2SeqTrainingArguments,
+    BitsAndBytesConfig
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
@@ -167,12 +168,27 @@ def load_pretrained(
     # Quantization configurations (using bitsandbytes library).
     if model_args.quantization_bit is not None:
-        assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
-        require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-        #require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
-        #require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-        #require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
-        config_kwargs["load_in_8bit"] = True
+        if model_args.quantization_bit == 8:
+            require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
+            config_kwargs["load_in_8bit"] = True
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_8bit=True,
+                llm_int8_threshold=6.0
+            )
+        elif model_args.quantization_bit == 4:
+            require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
+            require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
+            require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
+            require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
+            config_kwargs["load_in_4bit"] = True
+            config_kwargs["quantization_config"] = BitsAndBytesConfig(
+                load_in_4bit=True,
+                bnb_4bit_compute_dtype=finetuning_args.compute_dtype,
+                bnb_4bit_use_double_quant=model_args.double_quantization,
+                bnb_4bit_quant_type=model_args.quantization_type
+            )
+        else:
+            raise NotImplementedError
 
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
@@ -183,7 +199,7 @@ def load_pretrained(
     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
         config=config,
-        torch_dtype=torch.float16, # the model weights are float16 type
+        torch_dtype=torch.bfloat16 if finetuning_args.compute_dtype == torch.bfloat16 else torch.float16,
         low_cpu_mem_usage=True,
         **config_kwargs
     )
@@ -237,13 +253,13 @@ def prepare_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     if stage != "sft" and training_args.predict_with_generate:
-        raise ValueError("`predict_with_generate` cannot be set as True in PT, RM and PPO stages.")
+        raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
 
     if training_args.do_train and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True while training.")
 
     if training_args.do_predict and (not training_args.predict_with_generate):
-        raise ValueError("Please enable `predict_with_generate` for saving model predictions.")
+        raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -257,6 +273,14 @@ def prepare_args(
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 
+    if model_args.quantization_bit is not None:
+        if training_args.fp16:
+            finetuning_args.compute_dtype = torch.float16
+        elif training_args.bf16:
+            finetuning_args.compute_dtype = torch.bfloat16
+        else:
+            finetuning_args.compute_dtype = torch.float32
+
     # Log on each process the small summary:
     logger.info(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"

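For readers who want the end result in one place, here is a minimal standalone sketch of the 4-bit (QLoRA-style) loading path that the diff above assembles inside `load_pretrained()`. It assumes `transformers>=4.30`, `bitsandbytes>=0.39.0`, and `accelerate` are installed (matching the `require_version` calls above); the model id, the dtype probe, and `device_map="auto"` are placeholders rather than the repo's code, and the keyword values mirror what the new code passes (`bnb_4bit_quant_type` from `model_args.quantization_type`, `bnb_4bit_use_double_quant` from `model_args.double_quantization`, `bnb_4bit_compute_dtype` from `finetuning_args.compute_dtype`).

```python
# Minimal sketch of the 4-bit loading path wired up above; illustrative only,
# not the repo's actual load_pretrained() implementation.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# stand-in for finetuning_args.compute_dtype (set from --fp16/--bf16 in prepare_args)
compute_dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float16

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,   # dtype used for matmuls on top of the 4-bit weights
    bnb_4bit_use_double_quant=True,         # model_args.double_quantization in the diff
    bnb_4bit_quant_type="nf4"               # model_args.quantization_type in the diff
)

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",                  # placeholder model id
    quantization_config=quantization_config,
    torch_dtype=compute_dtype,
    low_cpu_mem_usage=True,
    device_map="auto"                       # device placement via accelerate (not part of the diff)
)
```

The 8-bit branch above is analogous but simpler: it only sets `load_in_8bit=True` and `llm_int8_threshold=6.0` (the LLM.int8() outlier threshold), which is why it keeps the lower `bitsandbytes>=0.37.0` requirement.
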
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -1,5 +1,6 @@
 import os
 import json
+import torch
 from typing import List, Literal, Optional
 from dataclasses import asdict, dataclass, field
 
@@ -207,6 +208,10 @@ class FinetuningArguments:
                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"up_proj\", \"down_proj\"], \
                   BLOOM choices: [\"query_key_value\", \"dense\", \"dense_\"]"}
     )
+    compute_dtype: Optional[torch.dtype] = field(
+        default=None,
+        metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
+    )
 
     def __post_init__(self):
         if isinstance(self.lora_target, str):

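The new `compute_dtype` field is deliberately not meant to be set on the command line (`default=None`, "Do not specify this argument manually"): it is filled in programmatically by the `prepare_args` hunk above, which maps the `--fp16`/`--bf16` training flags to a `torch.dtype` that then feeds `bnb_4bit_compute_dtype`. A small hypothetical, standalone sketch of that mapping, shown only to make the precedence explicit:

```python
# Hypothetical standalone version of the compute-dtype selection added in
# prepare_args(); the real code stores the result on FinetuningArguments.
import torch
from transformers import Seq2SeqTrainingArguments


def select_compute_dtype(training_args: Seq2SeqTrainingArguments) -> torch.dtype:
    # mirror the diff: --fp16 wins, then --bf16, otherwise fall back to fp32
    if training_args.fp16:
        return torch.float16
    elif training_args.bf16:
        return torch.bfloat16
    return torch.float32


if __name__ == "__main__":
    training_args = Seq2SeqTrainingArguments(output_dir="out")  # neither fp16 nor bf16 set
    assert select_compute_dtype(training_args) == torch.float32
```

Keeping the attribute on `FinetuningArguments` with `default=None` lets the quantization code read a real `torch.dtype` later without asking users to pass one, which matches the field's own help text.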