|
import importlib.metadata |
|
import torch |
|
import logging |
|
# Configure the root logger at import time. Per the stdlib docs, basicConfig
# does nothing if the root logger already has handlers attached.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger; print_memory reports through it.
log = logging.getLogger(__name__)
|
|
|
def _release_tuple(version):
    """Return the leading numeric release components of a version string.

    Parsing stops at the first dot-separated piece that is not purely numeric,
    so pre/post/dev/local suffixes are ignored, e.g.
    "0.31.0" -> (0, 31, 0), "0.32.0.dev0" -> (0, 32, 0),
    "0.30.0rc1" -> (0, 30, 0), "0.31.0+cu121" -> (0, 31, 0).
    """
    import re

    parts = []
    for piece in version.split("."):
        match = re.match(r"[0-9]+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):
            # A trailing suffix such as "rc1" or "+local" ends the release segment.
            break
    return tuple(parts)


def _is_older(version, required):
    """Return True if *version*'s numeric release segment sorts before *required*'s."""
    have = _release_tuple(version)
    need = _release_tuple(required)
    width = max(len(have), len(need))
    # Zero-pad so that "0.31" and "0.31.0" compare as equal.
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have < need


def check_diffusers_version(required_version="0.31.0"):
    """Ensure a sufficiently new ``diffusers`` package is installed.

    Versions are compared numerically on their release segment (so 0.100.0 is
    newer than 0.31.0), not lexicographically as strings. Pre-release, dev and
    local suffixes are ignored, so "0.31.0.dev0" satisfies a "0.31.0"
    requirement.

    Args:
        required_version: Minimum acceptable diffusers version.

    Raises:
        AssertionError: if diffusers is not installed, or its version is older
            than ``required_version``.
    """
    try:
        version = importlib.metadata.version('diffusers')
    except importlib.metadata.PackageNotFoundError as err:
        raise AssertionError("diffusers is not installed.") from err

    if _is_older(version, required_version):
        raise AssertionError(f"diffusers version {version} is installed, but version {required_version} or higher is required.")
|
|
|
def remove_specific_blocks(model, block_indices_to_remove):
    """Drop the transformer blocks at the given positions from *model*.

    Args:
        model: Object exposing a ``transformer_blocks`` attribute that holds an
            iterable of ``nn.Module`` blocks (presumably a diffusers
            transformer — confirm at call sites). It is modified in place.
        block_indices_to_remove: Iterable of positions to drop. Positions are
            matched against ``enumerate`` indices, so negative or out-of-range
            values remove nothing. Any iterable is accepted, including a
            one-shot generator.

    Returns:
        The same *model* object, whose ``transformer_blocks`` is now a new
        ``nn.ModuleList`` holding the surviving blocks in their original order.
    """
    import torch.nn as nn

    # Materialise the indices once. A set gives O(1) membership tests. It also
    # avoids repeated ``in`` checks against the caller's object, which would
    # consume a generator or iterator partway through the loop and leave later
    # requested blocks in place.
    doomed = set(block_indices_to_remove)

    model.transformer_blocks = nn.ModuleList(
        block
        for position, block in enumerate(model.transformer_blocks)
        if position not in doomed
    )
    return model
|
|
|
def print_memory(device):
    """Log current and peak CUDA memory usage for *device*.

    Emits three INFO records through the module logger ``log``: currently
    allocated, peak allocated and peak reserved memory. Byte counts are divided
    by 1024**3, so the figures are GiB even though the labels say "GB".

    Args:
        device: Any device spec accepted by ``torch.cuda.memory_allocated``
            (e.g. a ``torch.device``, a CUDA index, or ``None`` for the
            current device).
    """
    bytes_per_gib = 1024 ** 3
    readings = (
        ("Allocated memory", torch.cuda.memory_allocated(device)),
        ("Max allocated memory", torch.cuda.max_memory_allocated(device)),
        ("Max reserved memory", torch.cuda.max_memory_reserved(device)),
    )
    for label, num_bytes in readings:
        # Plain numeric formatting. The original ``{memory=:.3f}`` self-documenting
        # specifier also printed the variable name ("memory=0.123"), duplicating
        # the label. Lazy %-args defer formatting until the record is emitted.
        log.info("%s: %.3f GB", label, num_bytes / bytes_per_gib)