@@ -3,11 +3,12 @@
 # Usage: python web_demo.py --checkpoint_dir path_to_checkpoint
 
-import torch
 import mdtex2html
 import gradio as gr
+from threading import Thread
 
 from utils import load_pretrained, prepare_infer_args, get_logits_processor
+from transformers import TextIteratorStreamer
 from transformers.utils.versions import require_version
@@ -83,6 +84,7 @@ def format_example_ziya(query, history):
 format_example = format_example_alpaca if data_args.prompt_template == "alpaca" else format_example_ziya
+streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
 
 
 def predict(input, chatbot, max_length, top_p, temperature, history):
@@ -97,12 +99,14 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
         "num_beams": 1,
         "max_length": max_length,
         "repetition_penalty": 1.0,
-        "logits_processor": get_logits_processor()
+        "logits_processor": get_logits_processor(),
+        "streamer": streamer
     }
-    with torch.no_grad():
-        generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
-    outputs = generation_output.tolist()[0][len(input_ids[0]):]
-    response = tokenizer.decode(outputs, skip_special_tokens=True)
+    thread = Thread(target=model.generate, kwargs=gen_kwargs)
+    thread.start()
+    response = ""
+    for new_text in streamer:
+        response += new_text
     history = history + [(input, response)]
     chatbot[-1] = (parse_text(input), parse_text(response))
     yield chatbot, history
@@ -129,7 +133,7 @@ with gr.Blocks() as demo:
                 submitBtn = gr.Button("Submit", variant="primary")
         with gr.Column(scale=1):
             emptyBtn = gr.Button("Clear History")
-            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
+            max_length = gr.Slider(0, 2048, value=1024, step=1.0, label="Maximum length", interactive=True)
             top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
             temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)
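
For reference, below is a minimal self-contained sketch of the streaming pattern this patch adopts: model.generate runs on a background thread and feeds a TextIteratorStreamer, which the consumer iterates to receive decoded text chunks as they are produced. The model name and prompt here are illustrative placeholders, not part of the patch, which loads its fine-tuned checkpoint via load_pretrained().

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    # Placeholder model for illustration only.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    input_ids = tokenizer(["An example prompt"], return_tensors="pt")["input_ids"]

    # skip_prompt=True omits the echoed prompt from the stream;
    # timeout=60.0 unblocks the iterator if no token arrives within 60 s.
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until completion, so it runs in a worker thread
    # while the main thread consumes the stream incrementally.
    thread = Thread(target=model.generate, kwargs={
        "input_ids": input_ids, "max_length": 64, "streamer": streamer
    })
    thread.start()

    response = ""
    for new_text in streamer:  # yields text chunks as soon as they are decoded
        response += new_text
    print(response)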