@@ -264,6 +264,18 @@ def prepare_args(
     return model_args, data_args, training_args, finetuning_args
 
 
+def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
+        model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
+
+    return model_args, data_args, finetuning_args
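# Annotation (not part of the patch): a minimal usage sketch for the new helper.
# Any inference entry point can call
#
#     model_args, data_args, finetuning_args = prepare_infer_args()
#
# and pass settings either as CLI flags matching the three dataclasses' fields, or
# as a single .json file given as the only command-line argument.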
+
+
 def prepare_data(
         model_args: ModelArguments,
         data_args: DataTrainingArguments
@@ -347,7 +359,8 @@ def preprocess_data(
     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
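# Annotation: the optional data_args.source_prefix is folded into every prompt
# built by the formatters below.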
 
-    def format_example(examples): # support question with a single answer or multiple answers
+    # support question with a single answer or multiple answers
+    def format_example_alpaca(examples):
         for i in range(len(examples["prompt"])):
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
@@ -357,12 +370,27 @@ def preprocess_data(
                 prompt += "Write a response that appropriately completes the request.\n"
                 prompt += "Instruction:\n" + prefix
                 if examples["history"][i]:
-                    history = examples["history"][i]
-                    for old_query, response in history:
+                    for old_query, response in examples["history"][i]:
                         prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
prompt += "Human: {}\nAssistant: ".format(query) |
|
|
|
yield prompt, answer |
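# Annotation: with an empty prefix, one history turn ("How are you?" / "Fine,
# thanks.") and the current query "Hi", the lines above render the tail of the
# Alpaca-style prompt as
#
#     ...Write a response that appropriately completes the request.
#     Instruction:
#     Human: How are you?
#     Assistant: Fine, thanks.
#     Human: Hi
#     Assistant:
#
# and examples["response"][i] is yielded as the target answer.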
 
+    def format_example_ziya(examples):
+        for i in range(len(examples["prompt"])):
+            if examples["prompt"][i] and examples["response"][i]:
+                query, answer = examples["prompt"][i], examples["response"][i]
+                if examples["query"][i]:
+                    query += "\n" + examples["query"][i]
+                prompt = ""
+                if examples["history"][i]:
+                    for old_query, response in examples["history"][i]:
+                        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
+                prompt += "<human>: {}\n<bot>:".format(query)
+                prompt = prefix + prompt
+                yield prompt, answer
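# Annotation: the same toy example under the Ziya template (prefix, empty here, is
# prepended to the whole prompt):
#
#     <human>: How are you?
#     <bot>: Fine, thanks.
#     <human>: Hi
#     <bot>: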
+
+    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
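# Annotation: prompt_template is read from DataTrainingArguments; any value other
# than "alpaca" selects the Ziya format here.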
+
     def preprocess_pretrain_dataset(examples):
         # build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
         text_ids = tokenizer(examples["prompt"])["input_ids"]
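# Annotation: tokenizer() over the list of raw texts returns one input_ids sequence
# per text; per the comment above, the rest of the function (outside this hunk)
# presumably concatenates and re-chunks them into fixed-length blocks.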