diff --git a/src/utils/config.py b/src/utils/config.py
index e7a7ae2..ccf3c73 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -141,9 +141,9 @@ class DataTrainingArguments:
         default=0,
         metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
     )
-    prompt_template: Optional[Literal["alpaca", "vicuna", "ziya"]] = field(
+    prompt_template: Optional[str] = field(
         default="alpaca",
-        metadata={"help": "Which template to use for constructing prompts in training."}
+        metadata={"help": "Which template to use for constructing prompts in training and inference."}
     )

     def __post_init__(self): # support mixing multiple datasets
diff --git a/src/utils/template.py b/src/utils/template.py
index 4239182..b778623 100644
--- a/src/utils/template.py
+++ b/src/utils/template.py
@@ -7,20 +7,31 @@ class Template:

     name: str

+    def __post_init__(self):
+        assert hasattr(self, "_format_{}".format(self.name)), "Template {} does not exist.".format(self.name)
+
     def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str:
         return getattr(self, "_format_{}".format(self.name))(query, history, prefix)

+    def _format_vanilla(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
+        prompt = prefix
+        if history:
+            for old_query, response in history:
+                prompt += old_query + "\n" + response + "\n"
+        prompt += query
+        return prompt
+
     def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
         if prefix:
             prompt = prefix
         else:
             prompt = "Below is an instruction that describes a task. "
-            prompt += "Write a response that appropriately completes the request.\n"
+            prompt += "Write a response that appropriately completes the request.\n\n"
         prompt += "Instruction:\n"
         if history:
             for old_query, response in history:
-                prompt += "Human:{}\nAssistant:{}\n".format(old_query, response)
-        prompt += "Human:{}\nAssistant:".format(query)
+                prompt += "Human:\n{}\n\nAssistant:\n{}\n\n".format(old_query, response)
+        prompt += "Human:\n{}\n\nAssistant:".format(query)
         return prompt

     def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
@@ -32,9 +43,16 @@ class Template:
         if history:
             for old_query, response in history:
                 prompt += "USER: {} ASSISTANT: {}".format(old_query, response)
-        prompt += "USER: {} ASSISTANT:".format(query)
+        prompt += "USER: {} ASSISTANT: ".format(query)
         return prompt

+    def _format_belle(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
+        prompt = prefix
+        if history:
+            for old_query, response in history:
+                prompt += "Human: {}\n\nBelle: {}\n\n".format(old_query, response)
+        prompt += "Human: {}\n\nBelle: ".format(query)
+        return prompt
+
     def _format_ziya(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str:
         prompt = prefix
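
For reference, a minimal usage sketch (not part of the diff) of how the updated Template class might be driven. It assumes Template is a dataclass and that src/ is on the Python path; the import path, template name, and sample strings are illustrative only.

# Hypothetical usage sketch: exercises the new "belle" template added above.
# Assumptions: Template is a dataclass and src/ is on PYTHONPATH.
from utils.template import Template

# __post_init__ validates the name, so an unknown template (e.g. "alpaka") raises AssertionError.
template = Template(name="belle")

prompt = template.get_prompt(
    query="What is the capital of France?",
    history=[("Hello!", "Hi, how can I help you?")],  # list of (query, response) pairs
    prefix=""
)
print(prompt)
# Expected shape: "Human: Hello!\n\nBelle: Hi, how can I help you?\n\nHuman: What is the capital of France?\n\nBelle: "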