diff --git a/src/cli_demo.py b/src/cli_demo.py
index 3da88aa..4426200 100644
--- a/src/cli_demo.py
+++ b/src/cli_demo.py
@@ -4,7 +4,7 @@
 
 
 import torch
-from utils import ModelArguments, FinetuningArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
 from transformers import HfArgumentParser
 
 
@@ -35,12 +35,12 @@ def main():
         input_ids = input_ids.to(model.device)
         gen_kwargs = {
             "do_sample": True,
-            "top_p": 0.9,
-            "top_k": 40,
-            "temperature": 0.7,
+            "top_p": 0.7,
+            "temperature": 0.95,
             "num_beams": 1,
             "max_new_tokens": 256,
-            "repetition_penalty": 1.5
+            "repetition_penalty": 1.5,
+            "logits_processor": get_logits_processor()
         }
         with torch.no_grad():
             generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
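
Note on the new import: get_logits_processor lives in src/utils and its body is not part of this diff. A minimal sketch of the kind of helper being imported, assuming (as in the ChatGLM demo code this pattern comes from) that its job is to keep sampling from crashing on NaN/Inf logits; the class name and the token index forced below are illustrative, not the repo's exact code:

    import torch
    from transformers import LogitsProcessor, LogitsProcessorList

    class InvalidScoreLogitsProcessor(LogitsProcessor):
        # If the model emits NaN/Inf scores, zero them out and force one
        # valid token to dominate so multinomial sampling cannot fail.
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            if torch.isnan(scores).any() or torch.isinf(scores).any():
                scores.zero_()
                scores[..., 0] = 5e4  # illustrative choice of fallback token
            return scores

    def get_logits_processor() -> LogitsProcessorList:
        return LogitsProcessorList([InvalidScoreLogitsProcessor()])
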
diff --git a/src/utils/common.py b/src/utils/common.py
index 9009906..65ae293 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -46,7 +46,8 @@ from .other import (
 )
 
 check_min_version("4.29.1")
-require_version("datasets>=2.10.0", "To fix: pip install datasets>=2.10.0")
+require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
+require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
 require_version("trl>=0.4.1", "To fix: pip install trl>=0.4.1")
 
@@ -84,8 +85,7 @@ def init_adapter(
                 param.data = param.data.to(torch.float32)
 
     if finetuning_args.finetuning_type != "lora" and model_args.checkpoint_dir is not None:
-        if len(model_args.checkpoint_dir) > 1:
-            logger.warning("Only LoRA tuning accepts multiple checkpoints.")
+        assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
         load_trainable_params(model, model_args.checkpoint_dir[0]) # load model checkpoints for non-peft methods
 
     if finetuning_args.finetuning_type == "lora":
@@ -154,8 +154,7 @@ def load_pretrained(
     config_kwargs = {}
     if model_args.quantization_bit is not None:
         assert model_args.quantization_bit == 8, "We only accept 8-bit quantization."
-
-        require_version("bitsandbytes>=0.37.0", "bitsandbytes library is required to use this feature.")
+        require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.1")
         from bitsandbytes.cuda_setup.main import get_compute_capability, get_cuda_lib_handle, is_cublasLt_compatible
         cuda = get_cuda_lib_handle()
         cc = get_compute_capability(cuda)
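
For context on the hunk above: the compute-capability probe gates 8-bit loading via bitsandbytes. A self-contained illustration of the loading idiom this path relies on (the model name is a placeholder; this is not the repo's exact loading code):

    from transformers import AutoModelForCausalLM

    # load_in_8bit needs a CUDA GPU whose compute capability supports
    # cublasLt, which is what is_cublasLt_compatible(cc) verifies above.
    model = AutoModelForCausalLM.from_pretrained(
        "model-name-here",   # placeholder
        load_in_8bit=True,   # weights are quantized on load by bitsandbytes
        device_map="auto",   # dispatch layers across available devices
    )
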
@@ -179,7 +178,6 @@ def load_pretrained(
 
     if not is_trainable:
         model.requires_grad_(False) # fix all model params
-        model = model.half() # cast all params to float16 for inference
 
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
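
Two notes on the last hunk: dropping the unconditional model.half() presumably avoids casting already-quantized int8 weights at inference time, and the rm/ppo branch wraps the model with trl's value head. A minimal sketch of what that wrapper provides (standalone, with an illustrative model name):

    from trl import AutoModelForCausalLMWithValueHead

    # Attaches a scalar value head on top of the LM's hidden states, as
    # required for reward modeling and PPO; generation still works as usual.
    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
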
diff --git a/src/utils/config.py b/src/utils/config.py
index 22b66fb..98d5907 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -49,6 +49,14 @@ class ModelArguments:
         default=None,
         metadata={"help": "The number of bits to quantize the model."}
     )
+    quantization_type: Optional[Literal["fp4", "nf4"]] = field(
+        default="nf4",
+        metadata={"help": "Quantization data type to use."}
+    )
+    double_quantization: Optional[bool] = field(
+        default=True,
+        metadata={"help": "Compress the quantization statistics through double quantization."}
+    )
     checkpoint_dir: Optional[str] = field(
         default=None,
         metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
@@ -206,14 +214,14 @@ class FinetuningArguments:
         assert self.finetuning_type in ["none", "freeze", "lora", "full"], "Invalid fine-tuning method."
 
     def save_to_json(self, json_path: str):
-        """Save the content of this instance in JSON format inside `json_path`."""
+        """Saves the content of this instance in JSON format inside `json_path`."""
         json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
         with open(json_path, "w", encoding="utf-8") as f:
             f.write(json_string)
 
     @classmethod
     def load_from_json(cls, json_path: str):
-        """Create an instance from the content of `json_path`."""
+        """Creates an instance from the content of `json_path`."""
         with open(json_path, "r", encoding="utf-8") as f:
             text = f.read()
         return cls(**json.loads(text))
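
Round-trip usage of the two (de)serialization helpers above, assuming the remaining dataclass fields all have defaults (the field value is illustrative):

    from dataclasses import asdict

    args = FinetuningArguments(finetuning_type="lora")
    args.save_to_json("finetuning_args.json")
    restored = FinetuningArguments.load_from_json("finetuning_args.json")
    assert asdict(restored) == asdict(args)
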
diff --git a/src/utils/ppo.py b/src/utils/ppo.py
index 5e754e4..701d4b4 100644
--- a/src/utils/ppo.py
+++ b/src/utils/ppo.py
@@ -9,7 +9,6 @@ from transformers.modeling_utils import PreTrainedModel
 
 from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
 from trl.core import LengthSampler
-from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
 
 from .peft_trainer import PeftTrainer, LogCallback
 
diff --git a/src/web_demo.py b/src/web_demo.py
index 83ccdf9..7445d0e 100644
--- a/src/web_demo.py
+++ b/src/web_demo.py
@@ -7,7 +7,7 @@ import torch
 import mdtex2html
 import gradio as gr
 
-from utils import ModelArguments, FinetuningArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
 from transformers import HfArgumentParser
 from transformers.utils.versions import require_version
 
@@ -93,7 +93,8 @@ def predict(input, chatbot, max_length, top_p, temperature, history):
         "temperature": temperature,
         "num_beams": 1,
         "max_length": max_length,
-        "repetition_penalty": 1.0
+        "repetition_penalty": 1.5,
+        "logits_processor": get_logits_processor()
     }
     with torch.no_grad():
         generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
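
On the repetition_penalty bump from 1.0 (a no-op) to 1.5: transformers implements this kwarg with RepetitionPenaltyLogitsProcessor, so the explicit equivalent of what generate() builds internally is:

    from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor

    # penalty > 1.0 down-weights tokens already present in input_ids,
    # discouraging the model from looping; 1.0 disables the penalty entirely.
    processors = LogitsProcessorList([RepetitionPenaltyLogitsProcessor(penalty=1.5)])
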