
add ziya prompt template

Branch: main
hiyouga, 2 years ago
parent commit de09ee1315
Changed files:
  1. src/cli_demo.py (29)
  2. src/export_model.py (8)
  3. src/utils/__init__.py (2)
  4. src/utils/common.py (34)
  5. src/utils/config.py (4)
  6. src/web_demo.py (26)

src/cli_demo.py (29)

@@ -4,14 +4,16 @@
 import torch
-from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
-from transformers import HfArgumentParser
+from utils import (
+    load_pretrained,
+    prepare_infer_args,
+    get_logits_processor
+)


 def main():

-    parser = HfArgumentParser((ModelArguments, FinetuningArguments))
-    model_args, finetuning_args = parser.parse_args_into_dataclasses()
+    model_args, data_args, finetuning_args = prepare_infer_args()

     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args, finetuning_args)
@@ -24,14 +26,26 @@ def main():
     model.eval()

-    def format_example(query):
+    def format_example_alpaca(query, history):
         prompt = "Below is an instruction that describes a task. "
         prompt += "Write a response that appropriately completes the request.\n"
-        prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
+        prompt += "Instruction:\n"
+        for old_query, response in history:
+            prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
+        prompt += "Human: {}\nAssistant:".format(query)
         return prompt

+    def format_example_ziya(query, history):
+        prompt = ""
+        for old_query, response in history:
+            prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
+        prompt += "<human>: {}\n<bot>:".format(query)
+        return prompt
+
+    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+
     def predict(query, history: list):
-        input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
+        input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"]
         input_ids = input_ids.to(model.device)
         gen_kwargs = {
             "do_sample": True,
@@ -65,6 +79,7 @@ def main():
         if query.strip() == "clear":
             history = []
+            print("History has been removed.")
             continue

         response, history = predict(query, history)
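Note: the two chat templates above differ only in their turn markers; the second one follows the <human>/<bot> markers used by the Ziya-LLaMA chat models, which is presumably why it is named ziya. A minimal standalone sketch of both builders follows; the sample conversation is hypothetical and only illustrates the rendered format.

# Standalone sketch: reproduces the two prompt builders added to cli_demo.py.
# The sample history and query below are hypothetical.

def format_example_alpaca(query, history):
    prompt = "Below is an instruction that describes a task. "
    prompt += "Write a response that appropriately completes the request.\n"
    prompt += "Instruction:\n"
    for old_query, response in history:
        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
    prompt += "Human: {}\nAssistant:".format(query)
    return prompt

def format_example_ziya(query, history):
    prompt = ""
    for old_query, response in history:
        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
    prompt += "<human>: {}\n<bot>:".format(query)
    return prompt

history = [("Hi", "Hello, how can I help you?")]   # hypothetical prior turn
query = "What is the capital of France?"           # hypothetical new query

print(format_example_alpaca(query, history))
# Below is an instruction that describes a task. Write a response that appropriately completes the request.
# Instruction:
# Human: Hi
# Assistant: Hello, how can I help you?
# Human: What is the capital of France?
# Assistant:

print(format_example_ziya(query, history))
# <human>: Hi
# <bot>: Hello, how can I help you?
# <human>: What is the capital of France?
# <bot>: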

src/export_model.py (8)

@@ -3,19 +3,15 @@
 # Usage: python export_model.py --checkpoint_dir path_to_checkpoint --output_dir path_to_save_model

-from transformers import HfArgumentParser, TrainingArguments
-from utils import ModelArguments, FinetuningArguments, load_pretrained
+from utils import load_pretrained, prepare_args


 def main():

-    parser = HfArgumentParser((ModelArguments, TrainingArguments, FinetuningArguments))
-    model_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()
+    model_args, _, training_args, finetuning_args = prepare_args(stage="sft")

     model, tokenizer = load_pretrained(model_args, finetuning_args)
     model.save_pretrained(training_args.output_dir, max_shard_size="10GB")
     tokenizer.save_pretrained(training_args.output_dir)
     print("model and tokenizer have been saved at:", training_args.output_dir)

src/utils/__init__.py (2)

@@ -1,6 +1,7 @@
 from .common import (
     load_pretrained,
     prepare_args,
+    prepare_infer_args,
     prepare_data,
     preprocess_data
 )
@@ -13,5 +14,4 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
 from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
 from .ppo import PPOPeftTrainer
-from .config import ModelArguments, FinetuningArguments
 from .other import get_logits_processor, plot_loss

src/utils/common.py (34)

@@ -264,6 +264,18 @@ def prepare_args(
     return model_args, data_args, training_args, finetuning_args


+def prepare_infer_args() -> Tuple[ModelArguments, DataTrainingArguments, FinetuningArguments]:
+
+    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FinetuningArguments))
+
+    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): # Provide arguments with a json file.
+        model_args, data_args, finetuning_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+    else:
+        model_args, data_args, finetuning_args = parser.parse_args_into_dataclasses()
+
+    return model_args, data_args, finetuning_args
+
+
 def prepare_data(
     model_args: ModelArguments,
     data_args: DataTrainingArguments
@@ -347,7 +359,8 @@ def preprocess_data(
     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""

-    def format_example(examples): # support question with a single answer or multiple answers
+    # support question with a single answer or multiple answers
+    def format_example_alpaca(examples):
         for i in range(len(examples["prompt"])):
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
@@ -357,12 +370,27 @@ def preprocess_data(
                 prompt += "Write a response that appropriately completes the request.\n"
                 prompt += "Instruction:\n" + prefix
                 if examples["history"][i]:
-                    history = examples["history"][i]
-                    for old_query, response in history:
+                    for old_query, response in examples["history"][i]:
                         prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
                 prompt += "Human: {}\nAssistant: ".format(query)
                 yield prompt, answer

+    def format_example_ziya(examples):
+        for i in range(len(examples["prompt"])):
+            if examples["prompt"][i] and examples["response"][i]:
+                query, answer = examples["prompt"][i], examples["response"][i]
+                if examples["query"][i]:
+                    query += "\n" + examples["query"][i]
+                prompt = ""
+                if examples["history"][i]:
+                    for old_query, response in examples["history"][i]:
+                        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
+                prompt += "<human>: {}\n<bot>:".format(query)
+                prompt = prefix + prompt
+                yield prompt, answer
+
+    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+
     def preprocess_pretrain_dataset(examples):
         # build grouped texts with format `<s> X1 X2 X3 ...` (without </s>)
         text_ids = tokenizer(examples["prompt"])["input_ids"]
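Note: during preprocessing these formatters run over batched dataset columns and yield one (prompt, answer) pair per valid row. A minimal sketch of the ziya branch on a toy batch follows; the batch contents are hypothetical, the column names (prompt, query, response, history) are the ones indexed in common.py, and prefix is passed as a plain argument here to keep the sketch self-contained.

# Sketch: how the batched ziya formatter consumes a dataset batch.
# The toy batch is hypothetical; rows with an empty prompt or response are skipped.

def format_example_ziya(examples, prefix=""):
    for i in range(len(examples["prompt"])):
        if examples["prompt"][i] and examples["response"][i]:
            query, answer = examples["prompt"][i], examples["response"][i]
            if examples["query"][i]:
                query += "\n" + examples["query"][i]
            prompt = ""
            if examples["history"][i]:
                for old_query, response in examples["history"][i]:
                    prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
            prompt += "<human>: {}\n<bot>:".format(query)
            yield prefix + prompt, answer

toy_batch = {
    "prompt": ["Translate to French", ""],
    "query": ["Good morning", ""],
    "response": ["Bonjour", ""],
    "history": [[("Hello", "Bonjour")], []],
}

for prompt, answer in format_example_ziya(toy_batch):
    print(repr(prompt))
    # '<human>: Hello\n<bot>: Bonjour\n<human>: Translate to French\nGood morning\n<bot>:'
    print(repr(answer))
    # 'Bonjour'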

src/utils/config.py (4)

@@ -136,6 +136,10 @@ class DataTrainingArguments:
         default=0,
         metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
     )
+    prompt_template: Optional[Literal["alpaca", "ziya"]] = field(
+        default="alpaca",
+        metadata={"help": "Which template to use for constructing prompts in training."}
+    )

     def __post_init__(self): # support mixing multiple datasets
         dataset_names = [ds.strip() for ds in self.dataset.split(",")]
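Note: because DataTrainingArguments is parsed with HfArgumentParser, the new field surfaces as a --prompt_template command-line option defaulting to "alpaca". A minimal sketch of that mechanism follows; the dataclass here is a stripped-down stand-in for DataTrainingArguments (plain str instead of the Literal type, for brevity), not the real class.

# Sketch: HfArgumentParser turns dataclass fields into CLI flags,
# so --prompt_template ziya selects the ziya formatter.
from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class ToyDataArguments:
    prompt_template: Optional[str] = field(
        default="alpaca",
        metadata={"help": "Which template to use for constructing prompts in training."}
    )

parser = HfArgumentParser(ToyDataArguments)
(data_args,) = parser.parse_args_into_dataclasses(args=["--prompt_template", "ziya"])
print(data_args.prompt_template)  # ziya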

src/web_demo.py (26)

@@ -7,14 +7,12 @@ import torch
 import mdtex2html
 import gradio as gr

-from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
-from transformers import HfArgumentParser
+from utils import load_pretrained, prepare_infer_args, get_logits_processor
 from transformers.utils.versions import require_version


 require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems

-parser = HfArgumentParser((ModelArguments, FinetuningArguments))
-model_args, finetuning_args = parser.parse_args_into_dataclasses()
+model_args, data_args, finetuning_args = prepare_infer_args()
 model, tokenizer = load_pretrained(model_args, finetuning_args)

 if torch.cuda.device_count() > 1:
@@ -75,17 +73,31 @@ def parse_text(text): # copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT
     return text


-def format_example(query):
+def format_example_alpaca(query, history):
     prompt = "Below is an instruction that describes a task. "
     prompt += "Write a response that appropriately completes the request.\n"
-    prompt += "Instruction:\nHuman: {}\nAssistant: ".format(query)
+    prompt += "Instruction:\n"
+    for old_query, response in history:
+        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
+    prompt += "Human: {}\nAssistant:".format(query)
     return prompt


+def format_example_ziya(query, history):
+    prompt = ""
+    for old_query, response in history:
+        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
+    prompt += "<human>: {}\n<bot>:".format(query)
+    return prompt
+
+
+format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+
+
 def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))
-    input_ids = tokenizer([format_example(input)], return_tensors="pt")["input_ids"]
+    input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"]
     input_ids = input_ids.to(model.device)
     gen_kwargs = {
         "do_sample": True,
