|
|
@ -10,8 +10,6 @@ from transformers.modeling_utils import PreTrainedModel |
|
|
|
from transformers.generation.utils import LogitsProcessorList |
|
|
|
from transformers.generation.logits_process import LogitsProcessor |
|
|
|
|
|
|
|
from accelerate.logging import get_logger |
|
|
|
|
|
|
|
from peft.utils.other import WEIGHTS_NAME |
|
|
|
|
|
|
|
|
|
|
@ -20,16 +18,19 @@ VALUE_HEAD_FILE_NAME = "value_head.bin" |
|
|
|
FINETUNING_ARGS_NAME = "finetuning_args.json" |
|
|
|
|
|
|
|
|
|
|
|
logger = get_logger(__name__, log_level="INFO") |
|
|
|
def get_logger(name: str) -> logging.Logger: |
|
|
|
return logging.getLogger(name) |
|
|
|
|
|
|
|
|
|
|
|
logging.basicConfig( |
|
|
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", |
|
|
|
datefmt="%m/%d/%Y %H:%M:%S", |
|
|
|
level=logging.INFO, |
|
|
|
handlers=[logging.StreamHandler(sys.stdout)] |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def get_main_logger(name: str) -> logging.Logger: |
|
|
|
return get_logger(name, log_level="INFO") |
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
class AverageMeter: |
|
|
|