jbilcke-hf's picture
jbilcke-hf HF Staff
initial commit log 🪵🦫
91fb4ef
raw
history blame contribute delete
733 Bytes
import gc
import inspect
from typing import Optional, Tuple, Union
import torch
# Module-level logger. Nothing visible in this module uses it; the helpers
# below write to stdout with print().
# NOTE(review): `get_logger` is neither imported nor defined in the visible
# source, so this line raises NameError at import time unless the name is
# provided elsewhere. It presumably comes from a project/third-party logging
# helper (e.g. accelerate.logging) -- confirm and restore the missing import.
logger = get_logger(__name__)
def reset_memory(device: Union[str, torch.device]) -> None:
    """Release cached CUDA memory and reset the allocator statistics for *device*.

    Runs a Python garbage collection, empties the CUDA caching allocator, then
    resets first the peak and then the accumulated memory statistics.

    No check is made that CUDA is available or that *device* refers to a CUDA
    device; the statistic-reset calls receive *device* unchanged.

    Args:
        device: Device whose allocator statistics are reset.
    """
    # Collect first so objects kept alive only by reference cycles drop their
    # tensors before the allocator cache is emptied.
    gc.collect()
    torch.cuda.empty_cache()
    stat_resetters = (
        torch.cuda.reset_peak_memory_stats,
        torch.cuda.reset_accumulated_memory_stats,
    )
    for reset_stats in stat_resetters:
        reset_stats(device)
def print_memory(device: Union[str, torch.device]) -> None:
    """Print a CUDA memory summary for *device* to stdout.

    Three lines are printed, each ``<name>=<value> GB`` with the value in units
    of 1024**3 bytes, formatted to three decimal places:

    * ``memory_allocated`` -- currently allocated memory,
    * ``max_memory_allocated`` -- peak allocated memory,
    * ``max_memory_reserved`` -- peak reserved memory.

    All three readings are taken before anything is printed.

    Args:
        device: Device whose allocator statistics are reported.
    """
    bytes_per_gb = 1024**3
    readings = {
        "memory_allocated": torch.cuda.memory_allocated(device) / bytes_per_gb,
        "max_memory_allocated": torch.cuda.max_memory_allocated(device) / bytes_per_gb,
        "max_memory_reserved": torch.cuda.max_memory_reserved(device) / bytes_per_gb,
    }
    for label, gigabytes in readings.items():
        print(f"{label}={gigabytes:.3f} GB")