From fe1d9308163699b7c4dd791915788855b2e6854f Mon Sep 17 00:00:00 2001
From: hiyouga
Date: Mon, 5 Jun 2023 16:43:44 +0800
Subject: [PATCH] implement stream generating

---
 src/cli_demo.py    | 27 +++++++++++++++++----------
 src/utils/other.py |  3 +--
 src/web_demo.py    | 24 ++++++++++++++----------
 3 files changed, 32 insertions(+), 22 deletions(-)

diff --git a/src/cli_demo.py b/src/cli_demo.py
index 72091d0..1bac015 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -3,12 +3,13 @@
 # Usage: python cli_demo.py --checkpoint_dir path_to_checkpoint
 
 
-import torch
 from utils import (
     load_pretrained,
     prepare_infer_args,
     get_logits_processor
 )
+from threading import Thread
+from transformers import TextIteratorStreamer
 
 
 def main():
@@ -34,25 +35,32 @@ def main():
         return prompt
 
     format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
-    def predict(query, history: list):
+    def predict_and_print(query, history: list):
         input_ids = tokenizer([format_example(query, history)], return_tensors="pt")["input_ids"]
         input_ids = input_ids.to(model.device)
         gen_kwargs = {
+            "input_ids": input_ids,
             "do_sample": True,
             "top_p": 0.7,
             "temperature": 0.95,
             "num_beams": 1,
             "max_new_tokens": 256,
             "repetition_penalty": 1.0,
-            "logits_processor": get_logits_processor()
+            "logits_processor": get_logits_processor(),
+            "streamer": streamer
         }
-        with torch.no_grad():
-            generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
-        outputs = generation_output.tolist()[0][len(input_ids[0]):]
-        response = tokenizer.decode(outputs, skip_special_tokens=True)
+        thread = Thread(target=model.generate, kwargs=gen_kwargs)
+        thread.start()
+        response = ""
+        print("{}: ".format(model_name), end="")
+        for new_text in streamer:
+            print(new_text, end="", flush=True)
+            response += new_text
+        print()
         history = history + [(query, response)]
-        return response, history
+        return history
 
     history = []
     print("欢迎使用 {} 模型,输入内容即可对话,clear清空对话历史,stop终止程序".format(model_name))
@@ -73,8 +81,7 @@ def main():
             print("History has been removed.")
             continue
 
-        response, history = predict(query, history)
-        print("{}:".format(model_name), response)
+        history = predict_and_print(query, history)
 
 
 if __name__ == "__main__":
diff --git a/src/utils/other.py b/src/utils/other.py
index 470e2e9..4e2d561 100644
--- a/src/utils/other.py
+++ b/src/utils/other.py
@@ -52,13 +52,12 @@ class AverageMeter:
 
 
 # Avoid runtime error in model.generate(do_sample=True).
-# Borrowed from: https://huggingface.co/THUDM/chatglm-6b/blob/658202d88ac4bb782b99e99ac3adff58b4d0b813/modeling_chatglm.py#L54
 class InvalidScoreLogitsProcessor(LogitsProcessor):
 
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
-            scores[..., 5] = 5e4
+            scores[:, 0] = 1.0
         return scores
 
 
diff --git a/src/web_demo.py b/src/web_demo.py
index c5c0ddf..77dd76d 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -3,11 +3,12 @@
 # Usage: python web_demo.py --checkpoint_dir path_to_checkpoint
 
 
-import torch
 import mdtex2html
 import gradio as gr
+from threading import Thread
 
 from utils import load_pretrained, prepare_infer_args, get_logits_processor
+from transformers import TextIteratorStreamer
 from transformers.utils.versions import require_version
 
 
@@ -83,6 +84,7 @@ def format_example_ziya(query, history):
 
 
 format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
 
 def predict(input, chatbot, max_length, top_p, temperature, history):
@@ -97,15 +99,17 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
         "num_beams": 1,
         "max_length": max_length,
         "repetition_penalty": 1.0,
-        "logits_processor": get_logits_processor()
+        "logits_processor": get_logits_processor(),
+        "streamer": streamer
     }
-    with torch.no_grad():
-        generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
-    outputs = generation_output.tolist()[0][len(input_ids[0]):]
-    response = tokenizer.decode(outputs, skip_special_tokens=True)
-    history = history + [(input, response)]
-    chatbot[-1] = (parse_text(input), parse_text(response))
-    yield chatbot, history
+    thread = Thread(target=model.generate, kwargs=gen_kwargs)
+    thread.start()
+    response = ""
+    for new_text in streamer:
+        response += new_text
+        history = history + [(input, response)]
+        chatbot[-1] = (parse_text(input), parse_text(response))
+        yield chatbot, history
 
 
 def reset_user_input():
@@ -129,7 +133,7 @@ with gr.Blocks() as demo:
                 submitBtn = gr.Button("Submit", variant="primary")
         with gr.Column(scale=1):
             emptyBtn = gr.Button("Clear History")
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+            max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
             top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
 
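Note: the patch above follows the standard TextIteratorStreamer pattern from transformers: generate() runs in a background thread and the foreground thread consumes decoded text chunks as they arrive. The sketch below is a minimal, self-contained illustration of that pattern only; the checkpoint name, prompt, and sampling values are placeholders rather than values from this repository, and project-specific helpers (load_pretrained, get_logits_processor, the prompt templates) are omitted.

    # Minimal sketch of streamed generation with TextIteratorStreamer (illustrative only).
    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    model_name = "huggyllama/llama-7b"  # placeholder checkpoint, not taken from this repo
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # skip_prompt drops the echoed prompt; skip_special_tokens hides BOS/EOS in the stream.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    input_ids = tokenizer(["Hello, who are you?"], return_tensors="pt")["input_ids"].to(model.device)
    gen_kwargs = {
        "input_ids": input_ids,
        "do_sample": True,
        "max_new_tokens": 256,
        "streamer": streamer,
    }

    # generate() blocks until completion, so it runs in a background thread while the
    # main thread reads decoded text chunks from the streamer as they become available.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()

    response = ""
    for new_text in streamer:  # iteration ends when generation finishes (or times out)
        print(new_text, end="", flush=True)
        response += new_text
    print()
    thread.join()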