Spaces:
Runtime error
Runtime error
File size: 8,665 Bytes
91fb4ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 |
import gc
import inspect
from typing import Optional, Tuple, Union
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.utils.torch_utils import is_compiled_module
logger = get_logger(__name__)
def get_optimizer(
    params_to_optimize,
    optimizer_name: str = "adam",
    learning_rate: float = 1e-3,
    beta1: float = 0.9,
    beta2: float = 0.95,
    beta3: float = 0.98,
    epsilon: float = 1e-8,
    weight_decay: float = 1e-4,
    prodigy_decouple: bool = False,
    prodigy_use_bias_correction: bool = False,
    prodigy_safeguard_warmup: bool = False,
    use_8bit: bool = False,
    use_4bit: bool = False,
    use_torchao: bool = False,
    use_deepspeed: bool = False,
    use_cpu_offload_optimizer: bool = False,
    offload_gradients: bool = False,
) -> torch.optim.Optimizer:
    """Create the optimizer selected by ``optimizer_name`` and the given flags.

    Args:
        params_to_optimize: Parameters (or parameter groups) handed to the
            optimizer constructor unchanged.
        optimizer_name: Case-insensitive; one of "adam", "adamw", "prodigy",
            "came". Any other value logs a warning and falls back to "adamw".
        learning_rate: Default learning rate given to the optimizer. Parameter
            groups that carry their own ``"lr"`` entry keep it.
        beta1, beta2: Passed as ``betas`` (CAME additionally receives beta3).
        beta3: Used by Prodigy (``beta3``) and CAME (third element of ``betas``).
        epsilon: Passed as ``eps`` (CAME uses a fixed ``(1e-30, 1e-16)`` pair).
        weight_decay: Passed as ``weight_decay``.
        prodigy_decouple, prodigy_use_bias_correction, prodigy_safeguard_warmup:
            Forwarded to Prodigy as ``decouple``, ``use_bias_correction`` and
            ``safeguard_warmup``.
        use_8bit: Use an 8-bit Adam/AdamW (torchao if ``use_torchao`` else
            bitsandbytes).
        use_4bit: Use a 4-bit Adam/AdamW; requires ``use_torchao``.
        use_torchao: Take low-bit optimizer classes from torchao.
        use_deepspeed: Return an ``accelerate.utils.DummyOptim`` placeholder and
            skip every other option.
        use_cpu_offload_optimizer: Wrap the optimizer class in torchao's
            ``CPUOffloadOptimizer``.
        offload_gradients: Forwarded to ``CPUOffloadOptimizer``.

    Returns:
        The constructed optimizer (or the ``DummyOptim`` placeholder).

    Raises:
        ValueError: On incompatible flag combinations.
        ImportError: When a required optional package is not installed.
    """
    optimizer_name = optimizer_name.lower()

    # Use DeepSpeed optimizer
    if use_deepspeed:
        from accelerate.utils import DummyOptim

        return DummyOptim(
            params_to_optimize,
            lr=learning_rate,
            betas=(beta1, beta2),
            eps=epsilon,
            weight_decay=weight_decay,
        )

    if use_8bit and use_4bit:
        raise ValueError("Cannot set both `use_8bit` and `use_4bit` to True.")

    if (use_torchao and (use_8bit or use_4bit)) or use_cpu_offload_optimizer:
        try:
            import torchao

            # Presumably only here so the import is not flagged as unused.
            torchao.__version__
        except ImportError as err:
            raise ImportError(
                "To use optimizers from torchao, please install the torchao library: `USE_CPP=0 pip install torchao`."
            ) from err

    if not use_torchao and use_4bit:
        raise ValueError("4-bit Optimizers are only supported with torchao.")

    # Optimizer creation
    supported_optimizers = ["adam", "adamw", "prodigy", "came"]
    if optimizer_name not in supported_optimizers:
        logger.warning(
            f"Unsupported choice of optimizer: {optimizer_name}. Supported optimizers include {supported_optimizers}. Defaulting to `AdamW`."
        )
        optimizer_name = "adamw"

    if (use_8bit or use_4bit) and optimizer_name not in ["adam", "adamw"]:
        raise ValueError("`use_8bit` and `use_4bit` can only be used with the Adam and AdamW optimizers.")

    # bitsandbytes is only needed when the 8-bit classes do NOT come from
    # torchao; requiring it on the torchao path would be a spurious dependency.
    if use_8bit and not use_torchao:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            ) from err

    if optimizer_name == "adamw":
        if use_torchao:
            from torchao.prototype.low_bit_optim import AdamW4bit, AdamW8bit

            optimizer_class = AdamW8bit if use_8bit else AdamW4bit if use_4bit else torch.optim.AdamW
        else:
            optimizer_class = bnb.optim.AdamW8bit if use_8bit else torch.optim.AdamW

        init_kwargs = {
            # Default only: a per-group "lr" in params_to_optimize still wins.
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }

    elif optimizer_name == "adam":
        if use_torchao:
            from torchao.prototype.low_bit_optim import Adam4bit, Adam8bit

            optimizer_class = Adam8bit if use_8bit else Adam4bit if use_4bit else torch.optim.Adam
        else:
            optimizer_class = bnb.optim.Adam8bit if use_8bit else torch.optim.Adam

        init_kwargs = {
            # Default only: a per-group "lr" in params_to_optimize still wins.
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "eps": epsilon,
            "weight_decay": weight_decay,
        }

    elif optimizer_name == "prodigy":
        try:
            import prodigyopt
        except ImportError as err:
            raise ImportError(
                "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
            ) from err

        optimizer_class = prodigyopt.Prodigy

        if learning_rate <= 0.1:
            logger.warning(
                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
            )

        init_kwargs = {
            "lr": learning_rate,
            "betas": (beta1, beta2),
            "beta3": beta3,
            "eps": epsilon,
            "weight_decay": weight_decay,
            "decouple": prodigy_decouple,
            "use_bias_correction": prodigy_use_bias_correction,
            "safeguard_warmup": prodigy_safeguard_warmup,
        }

    elif optimizer_name == "came":
        try:
            import came_pytorch
        except ImportError as err:
            raise ImportError(
                "To use CAME, please install the came-pytorch library: `pip install came-pytorch`"
            ) from err

        optimizer_class = came_pytorch.CAME

        init_kwargs = {
            "lr": learning_rate,
            "eps": (1e-30, 1e-16),
            "betas": (beta1, beta2, beta3),
            "weight_decay": weight_decay,
        }

    if use_cpu_offload_optimizer:
        from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

        # Turn on the fused implementation whenever the chosen class offers one.
        if "fused" in inspect.signature(optimizer_class.__init__).parameters:
            init_kwargs.update({"fused": True})

        optimizer = CPUOffloadOptimizer(
            params_to_optimize, optimizer_class=optimizer_class, offload_gradients=offload_gradients, **init_kwargs
        )
    else:
        optimizer = optimizer_class(params_to_optimize, **init_kwargs)

    return optimizer
def get_gradient_norm(parameters):
    """Return the global L2 norm of the gradients held by ``parameters``.

    Parameters whose ``grad`` is ``None`` are skipped. The result is the square
    root of the sum of each gradient's squared 2-norm, as a Python number.
    """
    squared_total = sum(
        p.grad.detach().data.norm(2).item() ** 2
        for p in parameters
        if p.grad is not None
    )
    return squared_total**0.5
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
    """Fit ``src`` into the target box keeping its aspect ratio, then centre it.

    Args:
        src: ``(height, width)`` pair of the source grid.
        tgt_width: Width of the target box.
        tgt_height: Height of the target box.

    Returns:
        ``((top, left), (bottom, right))`` corners of the resized source,
        centred inside the ``tgt_height`` x ``tgt_width`` box.
    """
    src_h, src_w = src

    if src_h / src_w > tgt_height / tgt_width:
        # Source is relatively taller than the target: height is the limit.
        fit_h, fit_w = tgt_height, int(round(tgt_height / src_h * src_w))
    else:
        # Otherwise width is the limit.
        fit_h, fit_w = int(round(tgt_width / src_w * src_h)), tgt_width

    top = int(round((tgt_height - fit_h) / 2.0))
    left = int(round((tgt_width - fit_w) / 2.0))
    return (top, left), (top + fit_h, left + fit_w)
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    patch_size_t: int = None,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute the 3D rotary embedding tables via ``get_3d_rotary_pos_embed``.

    ``height``/``width`` and ``base_height``/``base_width`` are each floor-divided
    by ``vae_scale_factor_spatial * patch_size`` to obtain the grid sizes. When
    ``patch_size_t`` is ``None`` the embeddings use crop coordinates from
    ``get_resize_crop_region_for_grid`` and ``num_frames`` temporal positions;
    otherwise they use the "slice" grid type with ``ceil(num_frames / patch_size_t)``
    temporal positions and the base grid as ``max_size``.

    Returns:
        ``(freqs_cos, freqs_sin)``, both moved to ``device``.
    """
    # Combined spatial divisor shared by every grid dimension below.
    spatial_divisor = vae_scale_factor_spatial * patch_size
    grid_hw = (height // spatial_divisor, width // spatial_divisor)
    base_grid_w = base_width // spatial_divisor
    base_grid_h = base_height // spatial_divisor

    if patch_size_t is not None:
        # CogVideoX 1.5
        variant_kwargs = {
            "crops_coords": None,
            "temporal_size": (num_frames + patch_size_t - 1) // patch_size_t,
            "grid_type": "slice",
            "max_size": (base_grid_h, base_grid_w),
        }
    else:
        # CogVideoX 1.0
        variant_kwargs = {
            "crops_coords": get_resize_crop_region_for_grid(grid_hw, base_grid_w, base_grid_h),
            "temporal_size": num_frames,
        }

    cos_table, sin_table = get_3d_rotary_pos_embed(
        embed_dim=attention_head_dim,
        grid_size=grid_hw,
        **variant_kwargs,
    )
    return cos_table.to(device=device), sin_table.to(device=device)
def reset_memory(device: Union[str, torch.device]) -> None:
    """Run garbage collection and reset CUDA memory caches/statistics.

    Args:
        device: Device whose peak and accumulated CUDA memory statistics are
            reset.

    On hosts where CUDA is unavailable only ``gc.collect()`` runs; the CUDA
    statistic resets are skipped because they cannot succeed there.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.reset_accumulated_memory_stats(device)
def print_memory(device: Union[str, torch.device]) -> None:
    """Print current and peak CUDA memory figures for ``device`` in GB.

    Three lines are written to stdout, in order: ``memory_allocated``,
    ``max_memory_allocated`` and ``max_memory_reserved``, each as
    ``name=value GB`` with three decimals (bytes divided by 1024**3).
    """
    bytes_per_gb = 1024**3
    # All readings are taken before anything is printed.
    readings = {
        "memory_allocated": torch.cuda.memory_allocated(device) / bytes_per_gb,
        "max_memory_allocated": torch.cuda.max_memory_allocated(device) / bytes_per_gb,
        "max_memory_reserved": torch.cuda.max_memory_reserved(device) / bytes_per_gb,
    }
    for label, gigabytes in readings.items():
        print(f"{label}={gigabytes:.3f} GB")
def unwrap_model(accelerator: Accelerator, model):
    """Return ``model`` with the accelerator wrapper and any compile wrapper removed.

    The model is first passed through ``accelerator.unwrap_model``; if the result
    is a compiled module (per ``is_compiled_module``), its ``_orig_mod`` is
    returned instead.
    """
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        return unwrapped._orig_mod
    return unwrapped
|