@@ -83,8 +83,8 @@ def format_example(query):
 def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))
 
-    inputs = tokenizer([format_example(input)], return_tensors="pt")
-    inputs = inputs.to(model.device)
+    input_ids = tokenizer([format_example(input)], return_tensors="pt")["input_ids"]
+    input_ids = input_ids.to(model.device)
     gen_kwargs = {
         "do_sample": True,
         "top_p": top_p,
@@ -94,8 +94,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
         "repetition_penalty": 1.0
     }
     with torch.no_grad():
-        generation_output = model.generate(**inputs, **gen_kwargs)
-        outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+        generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
+        outputs = generation_output.tolist()[0][len(input_ids[0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)
     history = history + [(input, response)]
     chatbot[-1] = (parse_text(input), parse_text(response))
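Note: the diff does not state its motivation, but a plausible reading (an assumption, not a fact from the patch) is that tokenizer(...) returns a BatchEncoding whose keys (input_ids, attention_mask, and for some tokenizers token_type_ids) are all forwarded when called as model.generate(**inputs, **gen_kwargs); a model whose forward() does not accept those extra keyword arguments raises a TypeError. The minimal sketch below shows the key structure with a stock Hugging Face tokenizer; the "gpt2" checkpoint is a hypothetical stand-in, not the model from this repo:

    from transformers import AutoTokenizer

    # Any tokenizer illustrates the same key structure; "gpt2" is only a demo choice.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    enc = tokenizer(["hello world"], return_tensors="pt")
    print(enc.keys())  # dict_keys(['input_ids', 'attention_mask'])

    # model.generate(**enc, **gen_kwargs) would forward every key above, while
    # model.generate(input_ids=enc["input_ids"], **gen_kwargs) passes only the
    # token tensor, matching the "+" side of the diff.
    input_ids = enc["input_ids"]

Selecting the tensor first also makes the later slice outputs[len(input_ids[0]):] read directly off the same variable, instead of reaching back into the encoding dict.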