diff --git a/src/cli_demo.py b/src/cli_demo.py index 6ee12ce..90e0e7b 100644 --- a/src/cli_demo.py +++ b/src/cli_demo.py @@ -53,7 +53,7 @@ def main(): "temperature": 0.95, "num_beams": 1, "max_new_tokens": 256, - "repetition_penalty": 1.5, + "repetition_penalty": 1.0, "logits_processor": get_logits_processor() } with torch.no_grad(): diff --git a/src/web_demo.py b/src/web_demo.py index 426fe52..5cd05c3 100644 --- a/src/web_demo.py +++ b/src/web_demo.py @@ -105,7 +105,7 @@ def predict(input, chatbot, max_length, top_p, temperature, history): "temperature": temperature, "num_beams": 1, "max_length": max_length, - "repetition_penalty": 1.5, + "repetition_penalty": 1.0, "logits_processor": get_logits_processor() } with torch.no_grad():