|
@@ -29,6 +29,8 @@ from peft import (
     get_peft_model
 )
 
+from peft.utils import CONFIG_NAME
+
 from trl import AutoModelForCausalLMWithValueHead
 
 from .config import (
|
@@ -37,10 +39,7 @@ from .config import (
     FinetuningArguments
 )
 
-from .template import (
-    prompt_template_alpaca,
-    prompt_template_ziya
-)
+from .template import Template
 
 from .other import (
     get_logger,
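
The two import hunks above swap the per-format template functions (`prompt_template_alpaca`, `prompt_template_ziya`) for a single `Template` class. The diff only shows that `Template(name)` is constructed from `data_args.prompt_template` and exposes `get_prompt(query, history, prefix)` (see the later hunks); everything else in this sketch — the prompt strings, field names, and error handling — is an assumption for illustration:

    class Template:
        # Minimal sketch, assuming the class dispatches on the template
        # name given at construction. Prompt strings are illustrative only.
        def __init__(self, name: str):
            if name == "alpaca":
                self.prompt = (
                    "Below is an instruction that describes a task. "
                    "Write a response that appropriately completes the request.\n\n"
                    "### Instruction:\n{query}\n\n### Response:\n"
                )
            elif name == "ziya":
                self.prompt = "<human>:{query}\n<bot>:"
            else:
                raise ValueError("Unknown prompt template: {}".format(name))

        def get_prompt(self, query: str, history=None, prefix="") -> str:
            # Prepend the optional source prefix, replay any (query, response)
            # history turns, then format the current query.
            prompt = prefix
            if history:
                for old_query, response in history:
                    prompt += self.prompt.format(query=old_query) + response + "\n"
            prompt += self.prompt.format(query=query)
            return prompt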
|
@@ -102,6 +101,9 @@ def _init_adapter(
         logger.info("Fine-tuning method: LoRA")
         lastest_checkpoint = None
 
+        assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
+            "The given checkpoint is not a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."
+
         if model_args.checkpoint_dir is not None:
             if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
                 checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
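
The new assert distinguishes a LoRA checkpoint from a full/freeze one by looking for peft's adapter config file: `CONFIG_NAME` is `"adapter_config.json"` in peft, which a peft-saved adapter directory always contains. A standalone sketch of the same check (the helper name is ours, not the repo's):

    import os

    from peft.utils import CONFIG_NAME  # "adapter_config.json"

    def is_lora_checkpoint(checkpoint_dir: str) -> bool:
        # A directory written by peft's save_pretrained contains an
        # adapter_config.json; full/freeze checkpoints do not, so this is
        # a cheap way to catch a mismatched --finetuning_type early.
        return os.path.exists(os.path.join(checkpoint_dir, CONFIG_NAME))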
|
@@ -401,7 +403,7 @@ def preprocess_data(
 
     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
-    prompt_template = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
+    prompt_template = Template(data_args.prompt_template)
 
     # support question with a single answer or multiple answers
     def format_example(examples):
|
@@ -410,8 +412,7 @@ def preprocess_data(
             query, answer = examples["prompt"][i], examples["response"][i]
             if examples["query"][i]:
                 query += "\n" + examples["query"][i]
-            prompt = prompt_template(query, examples["history"][i])
-            prompt = prefix + prompt
+            prompt = prompt_template.get_prompt(query, examples["history"][i], prefix)
             yield prompt, answer
 
     def preprocess_pretrain_dataset(examples):
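
Net effect of the last two hunks: prefix handling moves out of `preprocess_data` and into the template, collapsing the two-step `prompt = prefix + prompt_template(query, history)` into a single call. A usage sketch built on the hypothetical `Template` above (all values invented):

    template = Template("alpaca")
    prompt = template.get_prompt(
        query="What is LoRA?",
        history=[("Hi.", "Hello! How can I help?")],  # (query, response) turns
        prefix="",  # source prefix, empty by default as in preprocess_data
    )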
|
|