From 4eb17bcf6c8ac51a3ec8cc5459064d1b35c82634 Mon Sep 17 00:00:00 2001
From: hiyouga <hiyouga@buaa.edu.cn>
Date: Tue, 6 Jun 2023 17:39:41 +0800
Subject: [PATCH] support distributed quantized training

---
 README.md                 |  2 +-
 src/utils/common.py       |  9 +++++----
 src/utils/other.py        | 11 ++++++-----
 src/utils/pairwise.py     |  4 ++--
 src/utils/peft_trainer.py |  4 ++--
 src/utils/ppo.py          |  4 ++--
 src/utils/seq2seq.py      |  4 ++--
 7 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 4bdef7e..b24b2f6 100644
--- a/README.md
+++ b/README.md
@@ -9,7 +9,7 @@
 
 ## Changelog
 
-[23/06/03] Now we support quantized training and inference (aka QLoRA). Try `--quantization_bit 4/8` argument to work with quantized model. (experimental feature)
+[23/06/03] Now we support quantized training and inference (aka [QLoRA](https://github.com/artidoro/qlora)). Try the `--quantization_bit 4/8` argument to work with quantized models. (experimental feature)
 
 [23/05/31] Now we support training the BLOOM & BLOOMZ models in this repo. Try `--model_name_or_path bigscience/bloomz-7b1-mt` argument to use the BLOOMZ model.
 
diff --git a/src/utils/common.py b/src/utils/common.py
index 50e3f59..c6653d6 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -38,7 +38,7 @@ from .config import (
 )
 
 from .other import (
-    get_logger,
+    get_main_logger,
     load_trainable_params,
     load_valuehead_params,
     print_trainable_params,
@@ -53,7 +53,7 @@ require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
 require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
 
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 def _init_adapter(
@@ -190,9 +190,10 @@ def load_pretrained(
         else:
             raise NotImplementedError
         is_mergeable = False
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}  # load the full quantized model onto this rank's GPU
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if model_args.quantization_bit is not None or (not is_trainable): # automatically load in CUDA
+    if not is_trainable:
         config_kwargs["device_map"] = "auto"
 
     # Load and prepare pretrained models (without valuehead).
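
Note on the hunk above: a bitsandbytes-quantized model cannot be re-dispatched after loading, and under DDP every rank needs its own full replica, so the quantized weights are pinned to the GPU given by LOCAL_RANK instead of being sharded with `device_map="auto"` (which is now reserved for inference). A minimal standalone sketch of the same loading pattern; the checkpoint name and the 4-bit kwargs are illustrative, not taken from this repo:

    import os
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Each torchrun/accelerate worker reads its LOCAL_RANK and maps the whole
    # model onto that single GPU, mirroring config_kwargs["device_map"] above.
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)

    model = AutoModelForCausalLM.from_pretrained(
        "huggyllama/llama-7b",  # illustrative checkpoint
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
        ),
        device_map={"": local_rank},  # keep the whole model on this rank's device
    )
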
@@ -288,7 +289,7 @@ def prepare_args(
     logger.info(
         f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
         + f"  distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    )
+    , main_process_only=False)
     logger.info(f"Training/evaluation parameters {training_args}")
 
     # Set seed before initializing model.
diff --git a/src/utils/other.py b/src/utils/other.py
index 77ef30b..075f324 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -10,6 +10,8 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
 
+from accelerate.logging import get_logger
+
 from peft.utils.other import WEIGHTS_NAME
 
 
@@ -18,17 +20,16 @@ VALUE_HEAD_FILE_NAME = "value_head.bin"
 FINETUNING_ARGS_NAME = "finetuning_args.json"
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__, log_level="INFO")
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
     datefmt="%m/%d/%Y %H:%M:%S",
-    level=logging.INFO,
     handlers=[logging.StreamHandler(sys.stdout)]
 )
 
 
-def get_logger(name: str) -> logging.Logger:
-    return logging.getLogger(name)
+def get_main_logger(name: str) -> logging.LoggerAdapter:
+    return get_logger(name, log_level="INFO")
 
 
 class AverageMeter:
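
Note on the logging rework: `accelerate.logging.get_logger` returns an adapter that, by default, emits a record only on the main process, so the many `logger.info(...)` calls across the project stop printing once per rank; the explicit `main_process_only=False` in common.py's `prepare_args` keeps the rank/device banner visible from every worker. A small standalone sketch of the two behaviors (run it under `torchrun`/`accelerate launch` to see the difference; the messages are illustrative):

    import logging
    import sys

    from accelerate.logging import get_logger

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    logger = get_logger(__name__, log_level="INFO")  # same call get_main_logger makes

    logger.info("printed once, by the main process only")          # default
    logger.info("printed by every rank", main_process_only=False)  # per-process
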
@@ -57,7 +58,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
-            scores[:, 0] = 1.0
+            scores[..., 0] = 1.0
         return scores
 
 
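
Note on the ellipsis change: for the usual 2-D `(batch_size, vocab_size)` scores the two indexing forms are equivalent, but `scores[..., 0]` also covers tensors with extra leading dimensions. A toy sketch of what the processor does when it meets NaN/inf logits (values are made up):

    import torch

    scores = torch.tensor([[float("nan"), 0.3, 0.2]])  # toy (batch, vocab) logits
    if torch.isnan(scores).any() or torch.isinf(scores).any():
        scores.zero_()
        scores[..., 0] = 1.0  # put all mass on the first token as a safe fallback
    print(scores)  # tensor([[1., 0., 0.]])
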
diff --git a/src/utils/pairwise.py b/src/utils/pairwise.py
index 317599b..ee230fc 100644
--- a/src/utils/pairwise.py
+++ b/src/utils/pairwise.py
@@ -5,9 +5,9 @@ from .data_collator import DynamicDataCollatorWithPadding
 
 from .peft_trainer import PeftTrainer
 
-from .other import get_logger
+from .other import get_main_logger
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding):
diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py
index 160a142..eac919f 100644
--- a/src/utils/peft_trainer.py
+++ b/src/utils/peft_trainer.py
@@ -21,7 +21,7 @@ from peft.utils.other import WEIGHTS_NAME
 from .config import FinetuningArguments
 
 from .other import (
-    get_logger,
+    get_main_logger,
     get_state_dict,
     load_trainable_params,
     load_valuehead_params,
@@ -30,7 +30,7 @@ from .other import (
 )
 
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 class LogCallback(TrainerCallback):
diff --git a/src/utils/ppo.py b/src/utils/ppo.py
index 701d4b4..b9191e5 100644
--- a/src/utils/ppo.py
+++ b/src/utils/ppo.py
@@ -16,12 +16,12 @@ from .config import FinetuningArguments
 
 from .other import (
     AverageMeter,
-    get_logger,
+    get_main_logger,
     get_logits_processor
 )
 
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py
index 47bbcc1..18ea413 100644
--- a/src/utils/seq2seq.py
+++ b/src/utils/seq2seq.py
@@ -13,10 +13,10 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
 
 from .peft_trainer import PeftTrainer
 
-from .other import get_logger, IGNORE_INDEX
+from .other import get_main_logger, IGNORE_INDEX
 
 
-logger = get_logger(__name__)
+logger = get_main_logger(__name__)
 
 
 @dataclass