|
@ -7,20 +7,31 @@ class Template: |
|
|
|
|
|
|
|
|
name: str |
|
|
name: str |
|
|
|
|
|
|
|
|
|
|
|
def __post_init__(self): |
|
|
|
|
|
assert hasattr(self, "_format_{}".format(self.name)), "Template {} does not exist.".format(self.name) |
|
|
|
|
|
|
|
|
def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str: |
|
|
def get_prompt(self, query: str, history: Optional[list] = None, prefix: Optional[str] = "") -> str: |
|
|
return getattr(self, "_format_{}".format(self.name))(query, history, prefix) |
|
|
return getattr(self, "_format_{}".format(self.name))(query, history, prefix) |
|
|
|
|
|
|
|
|
|
|
|
def _format_vanilla(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: |
|
|
|
|
|
prompt = prefix |
|
|
|
|
|
if history: |
|
|
|
|
|
for old_query, response in history: |
|
|
|
|
|
prompt += old_query + "\n" + response + "\n" |
|
|
|
|
|
prompt += query |
|
|
|
|
|
return prompt |
|
|
|
|
|
|
|
|
def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: |
|
|
def _format_alpaca(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: |
|
|
if prefix: |
|
|
if prefix: |
|
|
prompt = prefix |
|
|
prompt = prefix |
|
|
else: |
|
|
else: |
|
|
prompt = "Below is an instruction that describes a task. " |
|
|
prompt = "Below is an instruction that describes a task. " |
|
|
prompt += "Write a response that appropriately completes the request.\n" |
|
|
prompt += "Write a response that appropriately completes the request.\n\n" |
|
|
prompt += "Instruction:\n" |
|
|
prompt += "Instruction:\n" |
|
|
if history: |
|
|
if history: |
|
|
for old_query, response in history: |
|
|
for old_query, response in history: |
|
|
prompt += "Human:{}\nAssistant:{}\n".format(old_query, response) |
|
|
prompt += "Human:\n{}\n\nAssistant:\n{}\n\n".format(old_query, response) |
|
|
prompt += "Human:{}\nAssistant:".format(query) |
|
|
prompt += "Human:\n{}\n\nAssistant:".format(query) |
|
|
return prompt |
|
|
return prompt |
|
|
|
|
|
|
|
|
def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: |
|
|
def _format_vicuna(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: |
|
@ -32,9 +43,16 @@ class Template: |
|
|
if history: |
|
|
if history: |
|
|
for old_query, response in history: |
|
|
for old_query, response in history: |
|
|
prompt += "USER: {} ASSISTANT: {}</s>".format(old_query, response) |
|
|
prompt += "USER: {} ASSISTANT: {}</s>".format(old_query, response) |
|
|
prompt += "USER: {} ASSISTANT:".format(query) |
|
|
prompt += "USER: {} ASSISTANT: ".format(query) |
|
|
return prompt |
|
|
return prompt |
|
|
|
|
|
|
|
|
|
|
|
def _format_belle(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: |
|
|
|
|
|
prompt = prefix |
|
|
|
|
|
if history: |
|
|
|
|
|
for old_query, response in history: |
|
|
|
|
|
prompt += "Human: {}\n\nBelle: {}\n\n".format(old_query, response) |
|
|
|
|
|
prompt += "Human: {}\n\nBelle: ".format(query) |
|
|
|
|
|
return prompt |
|
|
|
|
|
|
|
|
def _format_ziya(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: |
|
|
def _format_ziya(self, query: str, history: Optional[list], prefix: Optional[str] = "") -> str: |
|
|
prompt = prefix |
|
|
prompt = prefix |
|
|