diff --git a/src/api_demo.py b/src/api_demo.py index ac761aa..81763cd 100644 --- a/src/api_demo.py +++ b/src/api_demo.py @@ -29,7 +29,7 @@ from utils import ( def torch_gc(): - if not torch.cuda.is_available(): + if torch.cuda.is_available(): num_gpus = torch.cuda.device_count() for device_id in range(num_gpus): with torch.cuda.device(device_id):