|
@ -195,8 +195,6 @@ def load_pretrained( |
|
|
bnb_4bit_use_double_quant=model_args.double_quantization, |
|
|
bnb_4bit_use_double_quant=model_args.double_quantization, |
|
|
bnb_4bit_quant_type=model_args.quantization_type |
|
|
bnb_4bit_quant_type=model_args.quantization_type |
|
|
) |
|
|
) |
|
|
else: |
|
|
|
|
|
raise NotImplementedError |
|
|
|
|
|
is_mergeable = False |
|
|
is_mergeable = False |
|
|
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} |
|
|
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} |
|
|
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) |
|
|
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit)) |
|
@ -273,8 +271,8 @@ def prepare_args( |
|
|
if training_args.do_predict and (not training_args.predict_with_generate): |
|
|
if training_args.do_predict and (not training_args.predict_with_generate): |
|
|
raise ValueError("Please enable `predict_with_generate` to save model predictions.") |
|
|
raise ValueError("Please enable `predict_with_generate` to save model predictions.") |
|
|
|
|
|
|
|
|
if model_args.quantization_bit is not None and finetuning_args.finetuning_type == "full": |
|
|
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": |
|
|
raise ValueError("Quantization is incompatible with the full-parameter tuning.") |
|
|
raise ValueError("Quantization is only compatible with the LoRA method.") |
|
|
|
|
|
|
|
|
if model_args.quantization_bit is not None and (not training_args.do_train): |
|
|
if model_args.quantization_bit is not None and (not training_args.do_train): |
|
|
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") |
|
|
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") |
|
@ -358,7 +356,14 @@ def prepare_data( |
|
|
) |
|
|
) |
|
|
elif dataset_attr.load_from == "file": |
|
|
elif dataset_attr.load_from == "file": |
|
|
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) |
|
|
data_file = os.path.join(data_args.dataset_dir, dataset_attr.file_name) |
|
|
|
|
|
|
|
|
extension = dataset_attr.file_name.split(".")[-1] |
|
|
extension = dataset_attr.file_name.split(".")[-1] |
|
|
|
|
|
if extension == "csv": |
|
|
|
|
|
file_type = "csv" |
|
|
|
|
|
elif extension == "json" or extension == "jsonl": |
|
|
|
|
|
file_type = "json" |
|
|
|
|
|
else: |
|
|
|
|
|
file_type = "text" |
|
|
|
|
|
|
|
|
if dataset_attr.file_sha1 is not None: |
|
|
if dataset_attr.file_sha1 is not None: |
|
|
checksum(data_file, dataset_attr.file_sha1) |
|
|
checksum(data_file, dataset_attr.file_sha1) |
|
@ -366,7 +371,7 @@ def prepare_data( |
|
|
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") |
|
|
logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.") |
|
|
|
|
|
|
|
|
raw_datasets = load_dataset( |
|
|
raw_datasets = load_dataset( |
|
|
extension if extension in ["csv", "json"] else "text", |
|
|
file_type, |
|
|
data_files=data_file, |
|
|
data_files=data_file, |
|
|
cache_dir=model_args.cache_dir, |
|
|
cache_dir=model_args.cache_dir, |
|
|
use_auth_token=True if model_args.use_auth_token else None |
|
|
use_auth_token=True if model_args.use_auth_token else None |
|
|