From 13d1f0709c774bb5ec5fb2b4e3c66b2f1226afd2 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Tue, 6 Jun 2023 21:36:37 +0800 Subject: [PATCH] recover logging --- src/utils/common.py | 6 +++--- src/utils/other.py | 11 ++++++----- src/utils/pairwise.py | 4 ++-- src/utils/peft_trainer.py | 4 ++-- src/utils/ppo.py | 4 ++-- src/utils/seq2seq.py | 4 ++-- 6 files changed, 17 insertions(+), 16 deletions(-) diff --git a/src/utils/common.py b/src/utils/common.py index c6653d6..ffbffe4 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -38,7 +38,7 @@ from .config import ( ) from .other import ( - get_main_logger, + get_logger, load_trainable_params, load_valuehead_params, print_trainable_params, @@ -53,7 +53,7 @@ require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0") require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1") -logger = get_main_logger(__name__) +logger = get_logger(__name__) def _init_adapter( @@ -289,7 +289,7 @@ def prepare_args( logger.info( f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n" + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" - , main_process_only=False) + ) logger.info(f"Training/evaluation parameters {training_args}") # Set seed before initializing model. diff --git a/src/utils/other.py b/src/utils/other.py index 075f324..c851289 100644 --- a/src/utils/other.py +++ b/src/utils/other.py @@ -10,8 +10,6 @@ from transformers.modeling_utils import PreTrainedModel from transformers.generation.utils import LogitsProcessorList from transformers.generation.logits_process import LogitsProcessor -from accelerate.logging import get_logger - from peft.utils.other import WEIGHTS_NAME @@ -20,16 +18,19 @@ VALUE_HEAD_FILE_NAME = "value_head.bin" FINETUNING_ARGS_NAME = "finetuning_args.json" -logger = get_logger(__name__, log_level="INFO") +def get_logger(name: str) -> logging.Logger: + return logging.getLogger(name) + + logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)] ) -def get_main_logger(name: str) -> logging.Logger: - return get_logger(name, log_level="INFO") +logger = get_logger(__name__) class AverageMeter: diff --git a/src/utils/pairwise.py b/src/utils/pairwise.py index ee230fc..317599b 100644 --- a/src/utils/pairwise.py +++ b/src/utils/pairwise.py @@ -5,9 +5,9 @@ from .data_collator import DynamicDataCollatorWithPadding from .peft_trainer import PeftTrainer -from .other import get_main_logger +from .other import get_logger -logger = get_main_logger(__name__) +logger = get_logger(__name__) class PairwiseDataCollatorWithPadding(DynamicDataCollatorWithPadding): diff --git a/src/utils/peft_trainer.py b/src/utils/peft_trainer.py index eac919f..160a142 100644 --- a/src/utils/peft_trainer.py +++ b/src/utils/peft_trainer.py @@ -21,7 +21,7 @@ from peft.utils.other import WEIGHTS_NAME from .config import FinetuningArguments from .other import ( - get_main_logger, + get_logger, get_state_dict, load_trainable_params, load_valuehead_params, @@ -30,7 +30,7 @@ from .other import ( ) -logger = get_main_logger(__name__) +logger = get_logger(__name__) class LogCallback(TrainerCallback): diff --git a/src/utils/ppo.py b/src/utils/ppo.py index b9191e5..701d4b4 100644 --- a/src/utils/ppo.py +++ b/src/utils/ppo.py @@ -16,12 +16,12 @@ from .config import FinetuningArguments from .other import ( AverageMeter, - get_main_logger, + get_logger, get_logits_processor ) -logger = get_main_logger(__name__) +logger = get_logger(__name__) def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None: diff --git a/src/utils/seq2seq.py b/src/utils/seq2seq.py index 18ea413..47bbcc1 100644 --- a/src/utils/seq2seq.py +++ b/src/utils/seq2seq.py @@ -13,10 +13,10 @@ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction from .peft_trainer import PeftTrainer -from .other import get_main_logger, IGNORE_INDEX +from .other import get_logger, IGNORE_INDEX -logger = get_main_logger(__name__) +logger = get_logger(__name__) @dataclass