diff --git a/src/cli_demo.py b/src/cli_demo.py index 4426200..6ee12ce 100644 --- a/src/cli_demo.py +++ b/src/cli_demo.py @@ -4,14 +4,16 @@ import torch -from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor -from transformers import HfArgumentParser +from utils import ( + load_pretrained, + prepare_infer_args, + get_logits_processor +) def main(): - parser = HfArgumentParser((ModelArguments, FinetuningArguments)) - model_args, finetuning_args = parser.parse_args_into_dataclasses() + model_args, data_args, finetuning_args = prepare_infer_args() model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA" model, tokenizer = load_pretrained(model_args, finetuning_args) @@ -24,14 +26,26 @@ def main(): model.eval() - def format_example(query): + def format_example_alpaca(query, history): prompt = "Below is an instruction that describes a task. " prompt += "Write a response that appropriately completes the request.\n" - prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query) + prompt += "Instruction:\n" + for old_query, response in history: + prompt += "Human: {}\nAssistant: {}\n".format(old_query, response) + prompt += "Human: {}\nAssistant:".format(query) return prompt + def format_example_ziya(query, history): + prompt = "" + for old_query, response in history: + prompt += ": {}\n: {}\n".format(old_query, response) + prompt += ": {}\n:".format(query) + return prompt + + format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya + def predict(query, history: list): - input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"] + input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"] input_ids = input_ids.to(model.device) gen_kwargs = { "do_sample": True, @@ -65,6 +79,7 @@ def main(): if query.strip() == "clear": history = [] + print("History has been removed.") continue response, history = predict(query, history) diff --git a/src/export_model.py b/src/export_model.py index f202401..7198518 100644 --- a/src/export_model.py +++ b/src/export_model.py @@ -3,19 +3,15 @@ # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model -from transformers import HfArgumentParser, TrainingArguments -from utils import ModelArguments, FinetuningArguments, load_pretrained +from utils import load_pretrained, prepare_args def main(): - parser = HfArgumentParser((ModelArguments, TrainingArguments, FinetuningArguments)) - model_args, training_args, finetuning_args = parser.parse_args_into_dataclasses() - + model_args, _, training_args, finetuning_args = prepare_args(stage="sft") model, tokenizer = load_pretrained(model_args, finetuning_args) model.save_pretrained(training_args.output_dir, max_shard_size="10GB") tokenizer.save_pretrained(training_args.output_dir) - print("model and tokenizer have been saved at:", training_args.output_dir) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 9e536b8..69c69c4 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -1,6 +1,7 @@ from .common import ( load_pretrained, prepare_args, + prepare_infer_args, prepare_data, preprocess_data ) @@ -13,5 +14,4 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer from .ppo import PPOPeftTrainer -from .config import ModelArguments, FinetuningArguments from .other import get_logits_processor, plot_loss diff --git a/src/utils/common.py b/src/utils/common.py index 66fa4c3..396c91d 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -264,6 +264,18 @@ def prepare_args( return model_args, data_args, training_args, finetuning_args +def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]: + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments)) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file. + model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses() + + return model_args, data_args, finetuning_args + + def prepare_data( model_args: ModelArguments, data_args: DataTrainingArguments @@ -347,7 +359,8 @@ def preprocess_data( column_names = list(dataset.column_names) prefix = data_args.source_prefix if data_args.source_prefix is not None else "" - def format_example(examples): # support question with a single answer or multiple answers + # support question with a single answer or multiple answers + def format_example_alpaca(examples): for i in range(len(examples["prompt"])): if examples["prompt"][i] and examples["response"][i]: query, answer = examples["prompt"][i], examples["response"][i] @@ -357,12 +370,27 @@ def preprocess_data( prompt += "Write a response that appropriately completes the request.\n" prompt += "Instruction:\n" + prefix if examples["history"][i]: - history = examples["history"][i] - for old_query, response in history: + for old_query, response in examples["history"][i]: prompt += "Human: {}\nAssistant: {}\n".format(old_query, response) prompt += "Human: {}\nAssistant: ".format(query) yield prompt, answer + def format_example_ziya(examples): + for i in range(len(examples["prompt"])): + if examples["prompt"][i] and examples["response"][i]: + query, answer = examples["prompt"][i], examples["response"][i] + if examples["query"][i]: + query += "\n" + examples["query"][i] + prompt = "" + if examples["history"][i]: + for old_query, response in examples["history"][i]: + prompt += ": {}\n: {}\n".format(old_query, response) + prompt += ": {}\n:".format(query) + prompt = prefix + prompt + yield prompt, answer + + format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya + def preprocess_pretrain_dataset(examples): # build grouped texts with format ` X1 X2 X3 ...` (without ) text_ids = tokenizer(examples["prompt"])["input_ids"] diff --git a/src/utils/config.py b/src/utils/config.py index deb13ad..3d9af49 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -136,6 +136,10 @@ class DataTrainingArguments: default=0, metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."} ) + prompt_template: Optional[Literal["alpaca", "ziya"]] = field( + default="alpaca", + metadata={"help": "Which template to use for constructing prompts in training."} + ) def __post_init__(self): # support mixing multiple datasets dataset_names = [ds.strip() for ds in self.dataset.split(",")] diff --git a/src/web_demo.py b/src/web_demo.py index 7445d0e..426fe52 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -7,14 +7,12 @@ import torch import mdtex2html import gradio as gr -from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor -from transformers import HfArgumentParser +from utils import load_pretrained, prepare_infer_args, get_logits_processor from transformers.utils.versions import require_version require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems -parser = HfArgumentParser((ModelArguments, FinetuningArguments)) -model_args, finetuning_args = parser.parse_args_into_dataclasses() +model_args, data_args, finetuning_args = prepare_infer_args() model, tokenizer = load_pretrained(model_args, finetuning_args) if torch.cuda.device_count() > 1: @@ -75,17 +73,31 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT return text -def format_example(query): +def format_example_alpaca(query, history): prompt = "Below is an instruction that describes a task. " prompt += "Write a response that appropriately completes the request.\n" - prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query) + prompt += "Instruction:\n" + for old_query, response in history: + prompt += "Human: {}\nAssistant: {}\n".format(old_query, response) + prompt += "Human: {}\nAssistant:".format(query) return prompt +def format_example_ziya(query, history): + prompt = "" + for old_query, response in history: + prompt += ": {}\n: {}\n".format(old_query, response) + prompt += ": {}\n:".format(query) + return prompt + + +format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya + + def predict(input, chatbot, max_length, top_p, temperature, history): chatbot.append((parse_text(input), "")) - input_ids = tokenizer([format_example(input)], return_tensors="pt")["input_ids"] + input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"] input_ids = input_ids.to(model.device) gen_kwargs = { "do_sample": True,