diff --git a/src/cli_demo.py b/src/cli_demo.py
index fd24e99..441e6ab 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -29,8 +29,8 @@ def main():
         return prompt
 
     def predict(query, history: list):
-        inputs = tokenizer([format_example(query)], return_tensors="pt")
-        inputs = inputs.to(model.device)
+        input_ids = tokenizer([format_example(query)], return_tensors="pt")["input_ids"]
+        input_ids = input_ids.to(model.device)
         gen_kwargs = {
             "do_sample": True,
             "top_p": 0.9,
@@ -41,8 +41,8 @@ def main():
             "repetition_penalty": 1.5
         }
         with torch.no_grad():
-            generation_output = model.generate(**inputs, **gen_kwargs)
-        outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+            generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
+        outputs = generation_output.tolist()[0][len(input_ids[0]):]
         response = tokenizer.decode(outputs, skip_special_tokens=True)
         history = history + [(query, response)]
         return response, history
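
For reference, the revised flow shared by both demos looks roughly like the following self-contained sketch. The checkpoint name and the prompt are placeholders, not the project's actual model or its format_example() template:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint; the demos load the project's own fine-tuned model.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Keep only input_ids from the tokenizer output instead of splatting the
# whole BatchEncoding into generate().
input_ids = tokenizer(["Hello, who are you?"], return_tensors="pt")["input_ids"]
input_ids = input_ids.to(model.device)

gen_kwargs = {"do_sample": True, "top_p": 0.9, "max_new_tokens": 64}
with torch.no_grad():
    generation_output = model.generate(input_ids=input_ids, **gen_kwargs)

# generate() returns the prompt followed by the new tokens; slice off the
# prompt length so only the response is decoded.
outputs = generation_output.tolist()[0][len(input_ids[0]):]
print(tokenizer.decode(outputs, skip_special_tokens=True))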
diff --git a/src/web_demo.py b/src/web_demo.py
index 5129ea8..ca76659 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -83,8 +83,8 @@ def format_example(query):
 def predict(input, chatbot, max_length, top_p, temperature, history):
     chatbot.append((parse_text(input), ""))
 
-    inputs = tokenizer([format_example(input)], return_tensors="pt")
-    inputs = inputs.to(model.device)
+    input_ids = tokenizer([format_example(input)], return_tensors="pt")["input_ids"]
+    input_ids = input_ids.to(model.device)
     gen_kwargs = {
         "do_sample": True,
         "top_p": top_p,
@@ -94,8 +94,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
         "repetition_penalty": 1.0
     }
     with torch.no_grad():
-        generation_output = model.generate(**inputs, **gen_kwargs)
-    outputs = generation_output.tolist()[0][len(inputs["input_ids"][0]):]
+        generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
+    outputs = generation_output.tolist()[0][len(input_ids[0]):]
     response = tokenizer.decode(outputs, skip_special_tokens=True)
     history = history + [(input, response)]
     chatbot[-1] = (parse_text(input), parse_text(response))
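
Note on the design choice (an inference from the change, not stated in the diff): tokenizer(...) returns a BatchEncoding whose keys vary by tokenizer, and some emit token_type_ids alongside input_ids and attention_mask. Since generate() rejects model kwargs the architecture does not use, splatting **inputs can fail for some checkpoints; passing input_ids explicitly sidesteps that. Each call here tokenizes a single unpadded query, so dropping the attention_mask is harmless, though transformers may log a warning while inferring it. A quick way to see the variable keys (the tokenizer name is a placeholder):

from transformers import AutoTokenizer

# Some tokenizers add token_type_ids, which not every model's generate() accepts.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(list(tokenizer(["hello"], return_tensors="pt").keys()))
# e.g. ['input_ids', 'token_type_ids', 'attention_mask']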