@@ -37,6 +37,11 @@ from .config import (
     FinetuningArguments
 )
 
+from .template import (
+    prompt_template_alpaca,
+    prompt_template_ziya
+)
+
 from .other import (
     get_logger,
     load_trainable_params,
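
The two imported helpers live in the new template module, which this diff does not show. Judging from the inline prompt-building code removed in the preprocess_data hunk below, they presumably look roughly like the following sketch (reconstructed from the removed lines, not copied from the PR):

    # Hypothetical reconstruction of the new template module -- inferred from
    # the prompt-building code this PR deletes from preprocess_data.

    def prompt_template_alpaca(query, history=None):
        # Alpaca-style instruction prompt with optional multi-turn history
        prompt = "Below is an instruction that describes a task. "
        prompt += "Write a response that appropriately completes the request.\n"
        prompt += "Instruction:\n"
        if history:
            for old_query, response in history:
                prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
        prompt += "Human: {}\nAssistant: ".format(query)
        return prompt

    def prompt_template_ziya(query, history=None):
        # Ziya-style prompt using <human>/<bot> role tags
        prompt = ""
        if history:
            for old_query, response in history:
                prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
        prompt += "<human>: {}\n<bot>:".format(query)
        return prompt

One behavioral nuance of the refactor: the old alpaca path spliced prefix in after "Instruction:\n", while the new format_example prepends prefix to the finished prompt for both templates, so the prefix position changes slightly for alpaca-formatted data.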
@@ -224,6 +229,7 @@ def load_pretrained(
 
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
+        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
 
     print_trainable_params(model)
 
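
The conditional expression only downcasts when no quantization is requested: quantized weights (e.g. 8-bit via bitsandbytes) would be corrupted by a .half() cast, so they are left untouched. An equivalent, more explicit spelling of the added line (a sketch, same behavior):

    if not is_trainable:
        model.requires_grad_(False)      # freeze all parameters for inference
        if model_args.quantization_bit is None:
            model = model.half()         # cast fp32 weights to fp16
        # quantized models keep their existing dtype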
@@ -395,39 +401,19 @@ def preprocess_data(
 
     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
+    prompt_template = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
 
     # support question with a single answer or multiple answers
-    def format_example_alpaca(examples):
-        for i in range(len(examples["prompt"])):
-            if examples["prompt"][i] and examples["response"][i]:
-                query, answer = examples["prompt"][i], examples["response"][i]
-                if examples["query"][i]:
-                    query += "\n" + examples["query"][i]
-                prompt = "Below is an instruction that describes a task. "
-                prompt += "Write a response that appropriately completes the request.\n"
-                prompt += "Instruction:\n" + prefix
-                if examples["history"][i]:
-                    for old_query, response in examples["history"][i]:
-                        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
-                prompt += "Human: {}\nAssistant: ".format(query)
-                yield prompt, answer
-
-    def format_example_ziya(examples):
+    def format_example(examples):
         for i in range(len(examples["prompt"])):
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
                 if examples["query"][i]:
                     query += "\n" + examples["query"][i]
-                prompt = ""
-                if examples["history"][i]:
-                    for old_query, response in examples["history"][i]:
-                        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
-                prompt += "<human>: {}\n<bot>:".format(query)
+                prompt = prompt_template(query, examples["history"][i])
                 prompt = prefix + prompt
                 yield prompt, answer
 
-    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
-
     def preprocess_pretrain_dataset(examples):
         # build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
         text_ids = tokenizer(examples["prompt"])["input_ids"]
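
To make the refactored generator concrete, here is a small hand-run with a made-up record (column names follow the code above; assumes data_args.prompt_template == "ziya" and an empty source_prefix):

    examples = {
        "prompt":   ["Translate to French."],
        "query":    ["Good morning"],
        "response": ["Bonjour"],
        "history":  [[]],
    }

    for prompt, answer in format_example(examples):
        print(repr(prompt))  # '<human>: Translate to French.\nGood morning\n<bot>:'
        print(repr(answer))  # 'Bonjour'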