From 5d021d4ad514974dd9dcc5240871713cf53a87f2 Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Wed, 7 Jun 2023 10:52:35 +0800
Subject: [PATCH] fix inference, add prompt template

---
 src/api_demo.py       | 23 ++++-------------------
 src/cli_demo.py       | 22 ++++------------------
 src/utils/__init__.py |  2 ++
 src/utils/common.py   | 32 +++++++++-----------------------
 src/utils/template.py | 16 ++++++++++++++++
 src/web_demo.py       | 30 +++++++++---------------------
 6 files changed, 44 insertions(+), 81 deletions(-)
 create mode 100644 src/utils/template.py

diff --git a/src/api_demo.py b/src/api_demo.py
index ca5e05d..678b5a5 100644
--- a/src/api_demo.py
+++ b/src/api_demo.py
@@ -23,7 +23,9 @@ from fastapi import FastAPI, Request
 from utils import (
     load_pretrained,
     prepare_infer_args,
-    get_logits_processor
+    get_logits_processor,
+    prompt_template_alpaca,
+    prompt_template_ziya
 )
 
 
@@ -96,23 +98,6 @@ async def create_item(request: Request):
 if __name__ == "__main__":
     model_args, data_args, finetuning_args = prepare_infer_args()
     model, tokenizer = load_pretrained(model_args, finetuning_args)
-
-    def format_example_alpaca(query, history):
-        prompt = "Below is an instruction that describes a task. "
-        prompt += "Write a response that appropriately completes the request.\n"
-        prompt += "Instruction:\n"
-        for old_query, response in history:
-            prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
-        prompt += "Human: {}\nAssistant:".format(query)
-        return prompt
-
-    def format_example_ziya(query, history):
-        prompt = ""
-        for old_query, response in history:
-            prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
-        prompt += "<human>: {}\n<bot>:".format(query)
-        return prompt
-
-    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+    format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
 
     uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
diff --git a/src/cli_demo.py b/src/cli_demo.py
index 9aae2f7..fd0a1c1 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -6,7 +6,9 @@
 from utils import (
     load_pretrained,
     prepare_infer_args,
-    get_logits_processor
+    get_logits_processor,
+    prompt_template_alpaca,
+    prompt_template_ziya
 )
 from threading import Thread
 from transformers import TextIteratorStreamer
@@ -18,23 +20,7 @@ def main():
     model_name = "BLOOM" if "bloom" in model_args.model_name_or_path else "LLaMA"
     model, tokenizer = load_pretrained(model_args, finetuning_args)
 
-    def format_example_alpaca(query, history):
-        prompt = "Below is an instruction that describes a task. "
-        prompt += "Write a response that appropriately completes the request.\n"
-        prompt += "Instruction:\n"
-        for old_query, response in history:
-            prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
-        prompt += "Human: {}\nAssistant:".format(query)
-        return prompt
-
-    def format_example_ziya(query, history):
-        prompt = ""
-        for old_query, response in history:
-            prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
-        prompt += "<human>: {}\n<bot>:".format(query)
-        return prompt
-
-    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+    format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
     streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
     def predict_and_print(query, history: list):
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
index 69c69c4..152052e 100644
--- a/src/utils/__init__.py
+++ b/src/utils/__init__.py
@@ -14,4 +14,6 @@ from .seq2seq import ComputeMetrics, Seq2SeqPeftTrainer
 from .pairwise import PairwiseDataCollatorWithPadding, PairwisePeftTrainer
 from .ppo import PPOPeftTrainer
 
+from .template import prompt_template_alpaca, prompt_template_ziya
+
 from .other import get_logits_processor, plot_loss
diff --git a/src/utils/common.py b/src/utils/common.py
index ffbffe4..9354f18 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -37,6 +37,11 @@ from .config import (
     FinetuningArguments
 )
 
+from .template import (
+    prompt_template_alpaca,
+    prompt_template_ziya
+)
+
 from .other import (
     get_logger,
     load_trainable_params,
@@ -224,6 +229,7 @@ def load_pretrained(
 
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
+        model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16
 
     print_trainable_params(model)
 
@@ -395,39 +401,19 @@ def preprocess_data(
 
     column_names = list(dataset.column_names)
     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
+    prompt_template = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
 
     # support question with a single answer or multiple answers
-    def format_example_alpaca(examples):
-        for i in range(len(examples["prompt"])):
-            if examples["prompt"][i] and examples["response"][i]:
-                query, answer = examples["prompt"][i], examples["response"][i]
-                if examples["query"][i]:
-                    query += "\n" + examples["query"][i]
-                prompt = "Below is an instruction that describes a task. "
-                prompt += "Write a response that appropriately completes the request.\n"
-                prompt += "Instruction:\n" + prefix
-                if examples["history"][i]:
-                    for old_query, response in examples["history"][i]:
-                        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
-                prompt += "Human: {}\nAssistant: ".format(query)
-                yield prompt, answer
-
-    def format_example_ziya(examples):
+    def format_example(examples):
         for i in range(len(examples["prompt"])):
             if examples["prompt"][i] and examples["response"][i]:
                 query, answer = examples["prompt"][i], examples["response"][i]
                 if examples["query"][i]:
                     query += "\n" + examples["query"][i]
-                prompt = ""
-                if examples["history"][i]:
-                    for old_query, response in examples["history"][i]:
-                        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
-                prompt += "<human>: {}\n<bot>:".format(query)
+                prompt = prompt_template(query, examples["history"][i])
                 prompt = prefix + prompt
                 yield prompt, answer
 
-    format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
-
     def preprocess_pretrain_dataset(examples):
         # build grouped texts with format `<bos> X1 X2 X3 ...` (without <eos>)
         text_ids = tokenizer(examples["prompt"])["input_ids"]
diff --git a/src/utils/template.py b/src/utils/template.py
new file mode 100644
index 0000000..2cafd5e
--- /dev/null
+++ b/src/utils/template.py
@@ -0,0 +1,16 @@
+def prompt_template_alpaca(query, history=None):
+    prompt = ""
+    if history:
+        for old_query, response in history:
+            prompt += "Human:{}\nAssistant:{}\n".format(old_query, response)
+    prompt += "Human:{}\nAssistant:".format(query)
+    return prompt
+
+
+def prompt_template_ziya(query, history=None):
+    prompt = ""
+    if history:
+        for old_query, response in history:
+            prompt += "<human>:{}\n<bot>:{}\n".format(old_query, response)
+    prompt += "<human>:{}\n<bot>:".format(query)
+    return prompt
diff --git a/src/web_demo.py b/src/web_demo.py
index 54bf634..9b69c3e 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -7,7 +7,14 @@ import mdtex2html
 import gradio as gr
 
 from threading import Thread
-from utils import load_pretrained, prepare_infer_args, get_logits_processor
+from utils import (
+    load_pretrained,
+    prepare_infer_args,
+    get_logits_processor,
+    prompt_template_alpaca,
+    prompt_template_ziya
+)
+
 from transformers import TextIteratorStreamer
 from transformers.utils.versions import require_version
 
@@ -18,26 +25,7 @@ require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0")
 
 model_args, data_args, finetuning_args = prepare_infer_args()
 model, tokenizer = load_pretrained(model_args, finetuning_args)
-
-def format_example_alpaca(query, history):
-    prompt = "Below is an instruction that describes a task. "
-    prompt += "Write a response that appropriately completes the request.\n"
-    prompt += "Instruction:\n"
-    for old_query, response in history:
-        prompt += "Human: {}\nAssistant: {}\n".format(old_query, response)
-    prompt += "Human: {}\nAssistant:".format(query)
-    return prompt
-
-
-def format_example_ziya(query, history):
-    prompt = ""
-    for old_query, response in history:
-        prompt += "<human>: {}\n<bot>: {}\n".format(old_query, response)
-    prompt += "<human>: {}\n<bot>:".format(query)
-    return prompt
-
-
-format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+format_example = prompt_template_alpaca if data_args.prompt_template == "alpaca" else prompt_template_ziya
 streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)