diff --git a/README.md b/README.md
index d38ec04..4bdef7e 100644
--- a/README.md
+++ b/README.md
@@ -21,13 +21,13 @@
 ## Supported Training Approaches
 
 - [(Continually) pre-training](https://s3-us-west-2.amazonaws.com/openai-assets/research-covers/language-unsupervised/language_understanding_paper.pdf)
-  - Full-parameter training
-  - Partial-parameter training
+  - Full-parameter tuning
+  - Partial-parameter tuning
   - [LoRA](https://arxiv.org/abs/2106.09685)
   - [QLoRA](https://arxiv.org/abs/2305.14314)
 - [Supervised fine-tuning](https://arxiv.org/abs/2109.01652)
-  - Full-parameter training
-  - Partial-parameter training
+  - Full-parameter tuning
+  - Partial-parameter tuning
   - [LoRA](https://arxiv.org/abs/2106.09685)
   - [QLoRA](https://arxiv.org/abs/2305.14314)
 - [RLHF](https://arxiv.org/abs/2203.02155)
diff --git a/src/utils/common.py b/src/utils/common.py
index b33fa15..50e3f59 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -261,8 +261,8 @@ def prepare_args(
     if training_args.do_predict and (not training_args.predict_with_generate):
         raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
-    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
-        raise ValueError("Quantization is only compatible with the LoRA method.")
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type == "full":
+        raise ValueError("Quantization is incompatible with full-parameter tuning.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
diff --git a/src/utils/config.py b/src/utils/config.py
index c8747be..ef29a17 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -148,7 +148,8 @@ class DataTrainingArguments:
     def __post_init__(self): # support mixing multiple datasets
         dataset_names = [ds.strip() for ds in self.dataset.split(",")]
-        dataset_info = json.load(open(os.path.join(self.dataset_dir, "dataset_info.json"), "r"))
+        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
+            dataset_info = json.load(f)
         self.dataset_list: List[DatasetAttr] = []
         for name in dataset_names: