@@ -2,6 +2,7 @@ import os
 import sys
 import torch
 import hashlib
+from itertools import chain
 from typing import List, Literal, Optional, Tuple

 import transformers
@@ -84,6 +85,8 @@ def init_adapter(
                 param.data = param.data.to(torch.float32)

     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
+        if len(model_args.checkpoint_dir) > 1:
+            logger.warning("Only LoRA tuning accepts multiple checkpoints.")
         load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods

     if finetuning_args.finetuning_type == "lora":
@@ -117,6 +120,9 @@ def init_adapter(
             )
             model = get_peft_model(model, lora_config)

+    if model_args.checkpoint_dir is not None:
+        logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
+
     return model
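
For reviewers less familiar with the LoRA path touched above, here is a minimal sketch, not this repository's code, of how a saved LoRA checkpoint is typically attached to a base model with `peft`; the model name and checkpoint path are placeholders.

```python
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Placeholder identifiers, for illustration only.
base = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b")
# Attach the trained LoRA adapter weights on top of the (frozen) base model.
model = PeftModel.from_pretrained(base, "path/to/lora-checkpoint")
```
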
@@ -131,19 +137,14 @@ def load_pretrained(

     Support both training and inference.
     """

-    if finetuning_args is None: # load the fine-tuning arguments
-        if model_args.checkpoint_dir is None:
-            logger.warning("Checkpoint is not found at evaluation, load the original model.")
-            finetuning_args = FinetuningArguments(finetuning_type="none")
-        elif os.path.exists(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME)):
-            finetuning_args = FinetuningArguments.load_from_json(os.path.join(model_args.checkpoint_dir[-1], FINETUNING_ARGS_NAME))
-        else:
-            raise ValueError("Missing fine-tuning arguments in the provided dictionary.")
+    if (not is_trainable) and (model_args.checkpoint_dir is None):
+        logger.warning("Checkpoint is not found at evaluation, load the original model.")
+        finetuning_args = FinetuningArguments(finetuning_type="none")

     if model_args.checkpoint_dir is not None: # load fine-tuned model from checkpoint
         for checkpoint_dir in model_args.checkpoint_dir:
             if not os.path.isfile(os.path.join(checkpoint_dir, FINETUNING_ARGS_NAME)):
                 raise ValueError("The fine-tuning arguments are not found in the provided dictionary.")
         logger.info("Load fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
         if finetuning_args.finetuning_type != "lora" and len(model_args.checkpoint_dir) > 1:
             logger.warning("Only LoRA tuning accepts multiple checkpoints.")

     assert stage == "sft" or finetuning_args.finetuning_type == "lora", "RM and PPO training can only be performed with LoRA method."
@@ -350,7 +351,7 @@ def preprocess_data(
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
                 if examples["query"][i]:
-                    query += examples["query"][i]
+                    query += "\n" + examples["query"][i]
                 prompt = "Below is an instruction that describes a task. "
                 prompt += "Write a response that appropriately completes the request.\n"
                 prompt += "Instruction:\n" + prefix
@@ -361,6 +362,20 @@ def preprocess_data(
                 prompt += "Human: {}\nAssistant: ".format(query)
                 yield prompt, answer

+    def preprocess_pretrain_dataset(examples):
+        # build grouped texts with format `<s>??`
+        text_ids = tokenizer(examples["prompt"])["input_ids"]
+        concatenated_ids = list(chain(*text_ids))
+        total_length = len(concatenated_ids)
+        # we drop the small remainder, and if the total_length < block_size, we exclude this batch
+        total_length = (total_length // data_args.max_source_length) * data_args.max_source_length
+        # split by chunks of max_source_length
+        result = [concatenated_ids[i: i+data_args.max_source_length] for i in range(0, total_length, data_args.max_source_length)]
+        return {
+            "input_ids": result,
+            "labels": result.copy()
+        }
+
     def preprocess_supervised_dataset(examples):
         # build inputs with format `X <s> Y </s>` and labels with format `<ignore> ... <ignore> <s> Y </s>`
         model_inputs = {"input_ids": [], "labels": []}
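
The grouping logic added in `preprocess_pretrain_dataset` is easier to follow outside the diff, so here is a self-contained sketch of the same idea; the token ids and `block_size` are made up and stand in for real tokenizer output and `data_args.max_source_length`.

```python
from itertools import chain

block_size = 4  # stands in for data_args.max_source_length
text_ids = [[1, 101, 102], [1, 201, 202, 203], [1, 301]]  # fake per-example token ids

concatenated_ids = list(chain(*text_ids))                          # flatten the batch
total_length = (len(concatenated_ids) // block_size) * block_size  # drop the small remainder
blocks = [concatenated_ids[i: i + block_size] for i in range(0, total_length, block_size)]

print(blocks)  # [[1, 101, 102, 1], [201, 202, 203, 1]]
```
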
@@ -425,7 +440,9 @@ def preprocess_data(
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))
         print("label_ids:\n{}".format(example["labels"]))
-        print("labels:\n{}".format(tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]])))
+        print("labels:\n{}".format(
+            tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]]))
+        )

     def print_pairwise_dataset_example(example):
         print("accept_ids:\n{}".format(example["accept_ids"]))
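
A small aside on the label-printing change above: masked label positions hold `IGNORE_INDEX` (-100 by the usual transformers convention), which is not a valid vocabulary id, so they are swapped for the pad token id before decoding. The values below are assumed for illustration.

```python
IGNORE_INDEX = -100  # label-masking convention used by transformers loss functions
pad_token_id = 0     # assumed pad token id

labels = [IGNORE_INDEX, IGNORE_INDEX, 345, 678, 2]
printable = [d if d != IGNORE_INDEX else pad_token_id for d in labels]
print(printable)  # [0, 0, 345, 678, 2]
```
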
@@ -437,11 +454,11 @@ def preprocess_data(
         print("input_ids:\n{}".format(example["input_ids"]))
         print("inputs:\n{}".format(tokenizer.decode(example["input_ids"])))

-    if stage == "sft":
-        if (not training_args.do_train) and training_args.predict_with_generate: # with generation
-            preprocess_function = preprocess_evaluation_dataset
-        else: # without generation
-            preprocess_function = preprocess_supervised_dataset
+    if stage == "pt":
+        preprocess_function = preprocess_pretrain_dataset
+    elif stage == "sft":
+        preprocess_function = preprocess_evaluation_dataset \
+            if training_args.predict_with_generate else preprocess_supervised_dataset
     elif stage == "rm":
         preprocess_function = preprocess_pairwise_dataset
     elif stage == "ppo":
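
For context on how the selected `preprocess_function` is consumed, a hedged sketch, not this repository's code, of applying a batched preprocessing function with the `datasets` library; the dummy function and column name are made up.

```python
from datasets import Dataset

def dummy_preprocess(examples):
    # Stand-in for the real preprocess_function: fake "token ids" from word lengths.
    ids = [[len(tok) for tok in text.split()] for text in examples["prompt"]]
    return {"input_ids": ids, "labels": [list(x) for x in ids]}

dataset = Dataset.from_dict({"prompt": ["hello world", "pre-training text goes here"]})
dataset = dataset.map(
    dummy_preprocess,
    batched=True,               # the function receives and returns dicts of lists
    remove_columns=["prompt"],  # drop the raw text column after preprocessing
)
print(dataset[0]["input_ids"])  # [5, 5]
```
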