@@ -56,7 +56,6 @@ require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
 require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
 
 
-
 logger = get_logger(__name__)
 
@@ -92,10 +91,12 @@ def _init_adapter(
 
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert is_mergeable and len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
+            assert is_mergeable and len(
+                model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
             load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
         else:
-            assert is_mergeable or len(model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
+            assert is_mergeable or len(
+                model_args.checkpoint_dir) == 1, "Quantized model only accepts a single checkpoint."
 
     if finetuning_args.finetuning_type == "lora":
         logger.info("Fine-tuning method: LoRA")
@@ -105,7 +106,8 @@ def _init_adapter(
             assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
                 "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
 
-            if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
+            if (is_trainable and model_args.resume_lora_training) or (
+                    not is_mergeable):  # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
             else:
                 checkpoints_to_merge = model_args.checkpoint_dir
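
Note (illustrative sketch, outside the patch): `checkpoints_to_merge` from the hunk above feeds the usual peft merge loop, roughly as below. The checkpoint paths are placeholders and `model` is assumed to be the already-loaded base model.

    from peft import PeftModel

    checkpoints_to_merge = ["lora_ckpt_a", "lora_ckpt_b"]  # hypothetical checkpoint directories
    for checkpoint in checkpoints_to_merge:
        model = PeftModel.from_pretrained(model, checkpoint)  # attach the LoRA weights
        model = model.merge_and_unload()  # fold them into the base weights
    # the remaining checkpoint (lastest_checkpoint) is loaded last and stays trainable when resuming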
@@ -184,9 +186,11 @@ def load_pretrained(
             )
         elif model_args.quantization_bit == 4:
             require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
-            require_version("transformers>=4.30.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git")
+            require_version("transformers>=4.30.0.dev0",
+                            "To fix: pip install git+https://github.com/huggingface/transformers.git")
             require_version("peft>=0.4.0.dev0", "To fix: pip install git+https://github.com/huggingface/peft.git")
-            require_version("accelerate>=0.20.0.dev0", "To fix: pip install git+https://github.com/huggingface/accelerate.git")
+            require_version("accelerate>=0.20.0.dev0",
+                            "To fix: pip install git+https://github.com/huggingface/accelerate.git")
             config_kwargs["load_in_4bit"] = True
             config_kwargs["quantization_config"] = BitsAndBytesConfig(
                 load_in_4bit=True,
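
Note (illustrative sketch, outside the patch): the 4-bit branch above builds a BitsAndBytesConfig inside config_kwargs, which is presumably handed to from_pretrained when the model is loaded. A minimal standalone version, with example values and a placeholder model name, looks roughly like this:

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,  # example value, not necessarily the repo's default
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    model = AutoModelForCausalLM.from_pretrained(
        "base-model-name",  # placeholder
        quantization_config=bnb_config,
        device_map="auto",
    )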
@@ -241,11 +245,11 @@ def load_pretrained(
 def prepare_args(
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Tuple[ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments]:
 
-
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments, FinetuningArguments))
 
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
-        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args, finetuning_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1]))
     else:
         model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
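
Note (illustrative sketch, outside the patch): prepare_args accepts either a single JSON config file or plain CLI flags via HfArgumentParser. A toy equivalent with a stand-in dataclass:

    import sys
    from dataclasses import dataclass
    from transformers import HfArgumentParser

    @dataclass
    class ToyArguments:  # stand-in for the project's argument dataclasses
        model_name_or_path: str = "gpt2"
        learning_rate: float = 5e-5

    parser = HfArgumentParser(ToyArguments)
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        (toy_args,) = parser.parse_json_file(json_file=sys.argv[1])  # all args from one JSON file
    else:
        (toy_args,) = parser.parse_args_into_dataclasses()  # args from the command line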
@@ -310,7 +314,6 @@ def prepare_args(
 
-
 def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
 
     parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))
 
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
@@ -331,7 +334,6 @@ def prepare_data(
     model_args: ModelArguments,
     data_args: DataTrainingArguments
 ) -> Dataset:
-
     def checksum(file_path, hash):
         with open(file_path, "rb") as datafile:
             binary_data = datafile.read()
@@ -361,7 +363,7 @@ def prepare_data(
                 checksum(data_file, dataset_attr.file_sha1)
             else:
                 logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json.")
-
+            print(extension)
             raw_datasets = load_dataset(
                 extension if extension in ["csv", "json"] else "text",
                 data_files=data_file,
@@ -406,7 +408,6 @@ def preprocess_data(
     training_args: Seq2SeqTrainingArguments,
     stage: Literal["pt", "sft", "rm", "ppo"]
 ) -> Dataset:
-
     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
     prompt_template = Template(data_args.prompt_template)
@@ -429,7 +430,8 @@ def preprocess_data(
         # we drop the small remainder, and if the total_length < block_size, we exclude this batch
         total_length = (total_length // data_args.max_source_length) * data_args.max_source_length
         # split by chunks of max_source_length
-        result = [concatenated_ids[i: i+data_args.max_source_length] for i in range(0, total_length, data_args.max_source_length)]
+        result = [concatenated_ids[i: i + data_args.max_source_length] for i in
+                  range(0, total_length, data_args.max_source_length)]
         return {
             "input_ids": result,
             "labels": result.copy()
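
Note (illustrative sketch, outside the patch): the reflowed comprehension above packs the concatenated token ids into fixed-length blocks and drops the remainder. A self-contained example with made-up values:

    concatenated_ids = list(range(10))  # pretend token ids
    max_source_length = 4

    total_length = (len(concatenated_ids) // max_source_length) * max_source_length  # 8, remainder dropped
    result = [concatenated_ids[i: i + max_source_length]
              for i in range(0, total_length, max_source_length)]
    # result == [[0, 1, 2, 3], [4, 5, 6, 7]]; labels are a copy of input_ids for causal LM pretraining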