Spaces:
Runtime error
Runtime error
File size: 733 Bytes
91fb4ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import gc
import inspect
import logging
from typing import Optional, Tuple, Union

import torch
# Module-level logger. `get_logger` is neither defined nor imported in this
# file, so calling it raised NameError as soon as the module was imported;
# the standard-library factory provides an equivalent named logger.
logger = logging.getLogger(__name__)
def reset_memory(device: Union[str, torch.device]) -> None:
    """Release cached memory and zero the CUDA allocator statistics.

    Python garbage collection runs first, so unreachable objects (and any
    tensors they hold) are freed before the CUDA caching allocator is asked
    to return its unused blocks. The peak and accumulated memory counters for
    *device* are then reset, which gives later measurements a clean baseline.

    Args:
        device: The CUDA device whose statistics are reset, e.g. ``"cuda:0"``.
            The ``torch.cuda`` statistics functions used here do not accept
            non-CUDA devices.
    """
    gc.collect()
    torch.cuda.empty_cache()

    # Peak counters are reset before accumulated ones, matching call order.
    for reset_stats in (
        torch.cuda.reset_peak_memory_stats,
        torch.cuda.reset_accumulated_memory_stats,
    ):
        reset_stats(device)
def print_memory(device: Union[str, torch.device]) -> None:
    """Print the CUDA memory counters for *device* to stdout.

    All three counters are read first. Each one is then printed on its own
    line as ``<counter>=<value> GB``. The value is the byte count divided by
    ``1024**3`` (strictly GiB, though the label says "GB"), formatted with
    three decimals.

    Args:
        device: The CUDA device to query, e.g. ``"cuda:0"``.
    """
    bytes_per_gb = 1024**3
    readings = (
        ("memory_allocated", torch.cuda.memory_allocated(device)),
        ("max_memory_allocated", torch.cuda.max_memory_allocated(device)),
        ("max_memory_reserved", torch.cuda.max_memory_reserved(device)),
    )
    for label, raw_bytes in readings:
        print(f"{label}={raw_bytes / bytes_per_gb:.3f} GB")
|