Spaces:
Runtime error
Runtime error
import gc | |
import inspect | |
from typing import Optional, Tuple, Union | |
import torch | |
# Module-level logger named after this module.
# NOTE(review): `get_logger` is not imported in the visible import lines —
# presumably it is brought in by an import elsewhere in the file; confirm.
# It is not used by the memory helpers visible below.
logger = get_logger(__name__)
def reset_memory(device: Union[str, torch.device]) -> None:
    """Free cached memory and reset CUDA allocator statistics for ``device``.

    Runs Python's garbage collector and asks the CUDA caching allocator to
    release unused cached blocks. It then resets the peak and accumulated
    memory statistics of ``device``, so later peak measurements start from
    the current state.

    Args:
        device: Device whose statistics should be reset, e.g. ``"cuda"``,
            ``"cuda:0"`` or a ``torch.device``. The statistics reset is skipped
            (instead of raising) when ``device`` is not a CUDA device or when
            CUDA is unavailable. Garbage collection and the cache release
            still run in that case.
    """
    gc.collect()
    torch.cuda.empty_cache()

    # The torch.cuda stat-reset functions raise for non-CUDA devices (e.g.
    # "cpu") and when CUDA is unavailable, so skip them rather than crash
    # callers running on CPU.
    if torch.device(device).type != "cuda" or not torch.cuda.is_available():
        return

    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)
def print_memory(device: Union[str, torch.device]) -> None:
    """Print current and peak CUDA allocator usage for ``device``.

    Writes three lines to stdout, each formatted as ``<name>=<value> GB``
    with three decimals. The values are:

    - the currently allocated memory;
    - the peak allocated memory;
    - the peak reserved memory.

    All three are reported by ``torch.cuda`` and divided by ``1024**3``, so
    despite the "GB" label the unit is binary gigabytes (GiB).
    """
    bytes_per_gib = 1024**3
    # Query all three counters first, then print, in the same order as the
    # labels appear in the output.
    readings = (
        ("memory_allocated", torch.cuda.memory_allocated(device)),
        ("max_memory_allocated", torch.cuda.max_memory_allocated(device)),
        ("max_memory_reserved", torch.cuda.max_memory_reserved(device)),
    )
    for label, num_bytes in readings:
        print(f"{label}={num_bytes / bytes_per_gib:.3f} GB")