|
@ -29,7 +29,7 @@ from utils import ( |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def torch_gc(): |
|
|
def torch_gc(): |
|
|
if not torch.cuda.is_available(): |
|
|
if torch.cuda.is_available(): |
|
|
num_gpus = torch.cuda.device_count() |
|
|
num_gpus = torch.cuda.device_count() |
|
|
for device_id in range(num_gpus): |
|
|
for device_id in range(num_gpus): |
|
|
with torch.cuda.device(device_id): |
|
|
with torch.cuda.device(device_id): |
|
|