import gc
import inspect
import logging
from typing import Optional, Tuple, Union

import torch

# `get_logger` was referenced here but never imported, which made the module
# fail with NameError on import; the stdlib logger is a drop-in replacement.
logger = logging.getLogger(__name__)


def _is_cuda_device(device: Union[str, torch.device]) -> bool:
    """Return True if *device* is a CUDA device and CUDA is usable in this process."""
    return torch.device(device).type == "cuda" and torch.cuda.is_available()


def reset_memory(device: Union[str, torch.device]) -> None:
    """Free unreferenced memory and reset CUDA allocator statistics for *device*.

    Python garbage collection runs first. This releases objects (including
    tensors) that only reference cycles keep alive, before the CUDA caching
    allocator is asked to release its cached blocks. After that, the peak and
    accumulated allocator statistics for *device* are reset.

    For a non-CUDA device, or when CUDA is not available, only garbage
    collection runs. The ``torch.cuda`` statistics calls expect a CUDA
    device and would otherwise raise.

    Args:
        device: Device whose CUDA memory statistics should be reset,
            e.g. ``"cuda:0"`` or ``torch.device("cuda")``.
    """
    gc.collect()
    if not _is_cuda_device(device):
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)


def print_memory(device: Union[str, torch.device]) -> None:
    """Print current and peak CUDA memory usage for *device*, in GiB.

    Three lines are written to stdout: currently allocated memory, peak
    allocated memory and peak reserved memory. Each uses the
    ``name=value GB`` format produced by the f-string ``=`` specifier, so
    the local variable names below are part of the printed output.

    Args:
        device: Device to report on, e.g. ``"cuda:0"``.
    """
    memory_allocated = torch.cuda.memory_allocated(device) / 1024**3
    max_memory_allocated = torch.cuda.max_memory_allocated(device) / 1024**3
    max_memory_reserved = torch.cuda.max_memory_reserved(device) / 1024**3
    print(f"{memory_allocated=:.3f} GB")
    print(f"{max_memory_allocated=:.3f} GB")
    print(f"{max_memory_reserved=:.3f} GB")