|
|
@@ -12,7 +12,7 @@ from transformers import TextIteratorStreamer |
|
|
|
from transformers.utils.versions import require_version |
|
|
|
|
|
|
|
|
|
|
|
require_version("gradio==3.27.0", "To fix: pip install gradio==3.27.0") # higher version may cause problems |
|
|
|
require_version("gradio>=3.30.0", "To fix: pip install gradio>=3.30.0") |
|
|
|
model_args, data_args, finetuning_args = prepare_infer_args() |
|
|
|
model, tokenizer = load_pretrained(model_args, finetuning_args) |
|
|
|
|
|
|
@@ -93,6 +93,7 @@ def predict(input, chatbot, max_length, top_p, temperature, history): |
|
|
|
input_ids = tokenizer([format_example(input, history)], return_tensors="pt")["input_ids"] |
|
|
|
input_ids = input_ids.to(model.device) |
|
|
|
gen_kwargs = { |
|
|
|
"input_ids": input_ids, |
|
|
|
"do_sample": True, |
|
|
|
"top_p": top_p, |
|
|
|
"temperature": temperature, |
|
|
@@ -107,9 +108,9 @@ def predict(input, chatbot, max_length, top_p, temperature, history): |
|
|
|
response = "" |
|
|
|
for new_text in streamer: |
|
|
|
response += new_text |
|
|
|
history = history + [(input, response)] |
|
|
|
new_history = history + [(input, response)] |
|
|
|
chatbot[-1] = (parse_text(input), parse_text(response)) |
|
|
|
yield chatbot, history |
|
|
|
yield chatbot, new_history |
|
|
|
|
|
|
|
|
|
|
|
def reset_user_input(): |
|
|
@@ -145,4 +146,4 @@ with gr.Blocks() as demo: |
|
|
|
|
|
|
|
emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True) |
|
|
|
|
|
|
|
demo.queue().launch(server_name="0.0.0.0", share=False, inbrowser=True) |
|
|
|
demo.queue().launch(server_name="0.0.0.0", share=True, inbrowser=True) |
|
|
|