# Author: jbilcke-hf (HF Staff)
# Commit 91fb4ef: "initial commit log 🪵🦫"
import gc
import inspect
from typing import Optional, Tuple, Union
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.utils.torch_utils import is_compiled_module
# Module logger (accelerate's get_logger); get_optimizer emits its warnings through it.
logger = get_logger(__name__)
def get_optimizer(
    params_to_optimize,
    optimizer_name: str = "adam",
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.98,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    prodigy_decouple: bool = False,
    prodigy_use_bias_correction: bool = False,
    prodigy_safeguard_warmup: bool = False,
    use_8bit: bool = False,
    use_4bit: bool = False,
    use_torchao: bool = False,
    use_deepspeed: bool = False,
    use_cpu_offload_optimizer: bool = False,
    offload_gradients: bool = False,
) -> torch.optim.Optimizer:
    """Create an optimizer for ``params_to_optimize``.

    Supported names (case-insensitive) are ``"adam"``, ``"adamw"``, ``"prodigy"`` and
    ``"came"``. Any other name logs a warning and falls back to AdamW.

    Args:
        params_to_optimize: Iterable of parameters or of param-group dicts. It is passed
            straight to the optimizer constructor, so a per-group ``"lr"`` takes
            precedence over ``learning_rate``.
        optimizer_name: Which optimizer to build.
        learning_rate: Default learning rate, used by every optimizer.
        beta1, beta2: First and second Adam-style betas.
        beta3: Third beta, used only by Prodigy and CAME.
        epsilon: Numerical-stability epsilon. CAME ignores it and uses a fixed
            ``(1e-30, 1e-16)`` pair.
        weight_decay: Weight decay coefficient.
        prodigy_decouple, prodigy_use_bias_correction, prodigy_safeguard_warmup:
            Forwarded to ``prodigyopt.Prodigy``.
        use_8bit: Use an 8-bit Adam/AdamW. It comes from bitsandbytes by default, or
            from torchao when ``use_torchao`` is set.
        use_4bit: Use a 4-bit Adam/AdamW. Requires ``use_torchao``.
        use_torchao: Take the low-bit optimizers from torchao instead of bitsandbytes.
        use_deepspeed: Return accelerate's ``DummyOptim`` instead of a real optimizer.
        use_cpu_offload_optimizer: Wrap the optimizer in torchao's ``CPUOffloadOptimizer``.
        offload_gradients: Forwarded to ``CPUOffloadOptimizer``.

    Returns:
        The constructed optimizer (or ``DummyOptim`` under DeepSpeed).

    Raises:
        ValueError: For conflicting or unsupported low-bit flag combinations.
        ImportError: When a required optional library is not installed.
    """
    optimizer_name = optimizer_name.lower()

    # Use the DeepSpeed optimizer: accelerate receives a DummyOptim that carries the
    # hyperparameters.
    # NOTE(review): presumably the real optimizer is built from the DeepSpeed config — confirm.
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=learning_rate,
            betas=(beta1, beta2),
            eps=epsilon,
            weight_decay=weight_decay,
        )

    if use_8bit and use_4bit:
        raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")

    # torchao is needed for its low-bit optimizers and for CPU offloading.
    if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
        try:
            import torchao  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
            ) from err

    if not use_torchao and use_4bit:
        raise ValueError("4-bit Optimizers are only supported with torchao.")

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy", "came"]
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
        )
        optimizer_name = "adamw"

    if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
        raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")

    # bitsandbytes is only required when the 8-bit optimizer does not come from torchao.
    if use_8bit and not use_torchao:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            ) from err

    if optimizer_name == "adamw":
        if use_torchao and (use_8bit or use_4bit):
            from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit

            optimizer_class = AdamW8bit if use_8bit else AdamW4bit
        elif use_8bit:
            optimizer_class = bnb.optim.AdamW8bit
        else:
            optimizer_class = torch.optim.AdamW
        init_kwargs = {
            # Pass lr explicitly (like the other branches) so learning_rate is honoured
            # when params_to_optimize carries no per-group lr.
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif optimizer_name == "adam":
        if use_torchao and (use_8bit or use_4bit):
            from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit

            optimizer_class = Adam8bit if use_8bit else Adam4bit
        elif use_8bit:
            optimizer_class = bnb.optim.Adam8bit
        else:
            optimizer_class = torch.optim.Adam
        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }
    elif optimizer_name == "prodigy":
        try:
            import prodigyopt
        except ImportError as err:
            raise ImportError(
                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
            ) from err

        optimizer_class = prodigyopt.Prodigy
        if learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
            )
        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "beta3": beta3,
            "eps": epsilon,
            "weight_decay": weight_decay,
            "decouple": prodigy_decouple,
            "use_bias_correction": prodigy_use_bias_correction,
            "safeguard_warmup": prodigy_safeguard_warmup,
        }
    elif optimizer_name == "came":
        try:
            import came_pytorch
        except ImportError as err:
            raise ImportError(
                "To use CAME, please install the came-pytorch library: `pip install came-pytorch`"
            ) from err

        optimizer_class = came_pytorch.CAME
        init_kwargs = {
            "lr": learning_rate,
            # CAME takes a pair of epsilons; these fixed values replace `epsilon`.
            "eps": (1e-30, 1e-16),
            "betas": (beta1, beta2, beta3),
            "weight_decay": weight_decay,
        }

    if use_cpu_offload_optimizer:
        from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

        # Request the fused implementation only when the optimizer class accepts it.
        if "fused" in inspect.signature(optimizer_class.__init__).parameters:
            init_kwargs.update({"fused": True})
        optimizer = CPUOffloadOptimizer(
            params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
        )
    else:
        optimizer = optimizer_class(params_to_optimize, **init_kwargs)

    return optimizer
def get_gradient_norm(parameters):
    """Return the global L2 norm of all gradients as a Python float.

    Each parameter's gradient norm is pulled to the host with ``.item()``. The result is
    the square root of the sum of the squared per-parameter norms. Parameters whose
    ``.grad`` is ``None`` are skipped, and an input with no gradients yields ``0.0``.
    """
    present_grads = (param.grad for param in parameters if param.grad is not None)
    squared_sum = sum(grad.detach().data.norm(2).item() ** 2 for grad in present_grads)
    return squared_sum**0.5
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Fit a ``(height, width)`` source grid inside a target grid and center it.

    The source is scaled, preserving its aspect ratio, so that it spans the full target
    height (if it is relatively taller) or the full target width (otherwise). It is then
    centered. Sizes and offsets are rounded with Python's ``round``.

    Returns:
        ``((top, left), (bottom, right))`` of the placed region in target coordinates.
    """
    src_h, src_w = src
    if src_h / src_w > tgt_height / tgt_width:
        # Relatively taller source: height-limited.
        fitted_h = tgt_height
        fitted_w = int(round(tgt_height / src_h * src_w))
    else:
        # Relatively wider (or equal) source: width-limited.
        fitted_w = tgt_width
        fitted_h = int(round(tgt_width / src_w * src_h))

    top = int(round((tgt_height - fitted_h) / 2.0))
    left = int(round((tgt_width - fitted_w) / 2.0))
    bottom_right = (top + fitted_h, left + fitted_w)
    return (top, left), bottom_right
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    patch_size_t: Optional[int] = None,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the 3D rotary embedding cos/sin tables via diffusers' ``get_3d_rotary_pos_embed``.

    Pixel sizes (``height``/``width`` and ``base_height``/``base_width``) are turned into
    token-grid sizes by integer division by ``vae_scale_factor_spatial * patch_size``.

    * ``patch_size_t is None`` (CogVideoX 1.0): crop coordinates come from
      ``get_resize_crop_region_for_grid`` and ``temporal_size`` is ``num_frames``.
    * Otherwise (CogVideoX 1.5): ``temporal_size`` is ``num_frames / patch_size_t``
      rounded up, with ``grid_type="slice"`` and the base grid as ``max_size``.

    Returns:
        ``(freqs_cos, freqs_sin)``, both moved to ``device``.
    """
    spatial_stride = vae_scale_factor_spatial * patch_size
    token_grid = (height // spatial_stride, width // spatial_stride)
    base_grid_h = base_height // spatial_stride
    base_grid_w = base_width // spatial_stride

    if patch_size_t is None:
        # CogVideoX 1.0: place the token grid inside the base grid.
        rope_kwargs = dict(
            crops_coords=get_resize_crop_region_for_grid(token_grid, base_grid_w, base_grid_h),
            grid_size=token_grid,
            temporal_size=num_frames,
        )
    else:
        # CogVideoX 1.5: frames are grouped into temporal patches (ceil division).
        rope_kwargs = dict(
            crops_coords=None,
            grid_size=token_grid,
            temporal_size=(num_frames + patch_size_t - 1) // patch_size_t,
            grid_type="slice",
            max_size=(base_grid_h, base_grid_w),
        )

    freqs_cos, freqs_sin = get_3d_rotary_pos_embed(embed_dim=attention_head_dim, **rope_kwargs)
    return freqs_cos.to(device=device), freqs_sin.to(device=device)
def reset_memory(device: Union[str, torch.device]) -> None:
    """Release cached CUDA memory and reset the allocator statistics for ``device``.

    Python garbage collection runs first, so that objects kept alive only by reference
    cycles are freed before ``torch.cuda.empty_cache()`` is called.
    """
    gc.collect()
    torch.cuda.empty_cache()
    stat_resetters = (
        torch.cuda.reset_peak_memory_stats,
        torch.cuda.reset_accumulated_memory_stats,
    )
    for reset_stats in stat_resetters:
        reset_stats(device)
def print_memory(device: Union[str, torch.device]) -> None:
    """Print current/peak allocated and peak reserved CUDA memory for ``device`` in GiB.

    All three readings are taken before anything is printed. Each is printed as
    ``<name>=<value with 3 decimals> GB`` on its own line.
    """
    bytes_per_gib = 1024**3
    readings = {
        "memory_allocated": torch.cuda.memory_allocated(device),
        "max_memory_allocated": torch.cuda.max_memory_allocated(device),
        "max_memory_reserved": torch.cuda.max_memory_reserved(device),
    }
    for label, num_bytes in readings.items():
        print(f"{label}={num_bytes / bytes_per_gib:.3f} GB")
def unwrap_model(accelerator: Accelerator, model):
    """Return the bare model behind any accelerate and ``torch.compile`` wrappers.

    The model first goes through ``accelerator.unwrap_model``. If the result is a compiled
    module (per diffusers' ``is_compiled_module``), its ``_orig_mod`` is returned instead.
    """
    unwrapped = accelerator.unwrap_model(model)
    if not is_compiled_module(unwrapped):
        return unwrapped
    return unwrapped._orig_mod