import math
from typing import Optional, Union
import torch
from diffusers import CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler
from diffusers.training_utils import compute_loss_weighting_for_sd3
# Default values copied from https://github.com/huggingface/diffusers/blob/8957324363d8b239d82db4909fbf8c0875683e3d/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py#L47
def resolution_dependent_timestep_flow_shift(
    latents: torch.Tensor,
    sigmas: torch.Tensor,
    base_image_seq_len: int = 256,
    max_image_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> torch.Tensor:
    image_or_video_sequence_length = 0
    if latents.ndim == 4:
        image_or_video_sequence_length = latents.shape[2] * latents.shape[3]
    elif latents.ndim == 5:
        image_or_video_sequence_length = latents.shape[2] * latents.shape[3] * latents.shape[4]
    else:
        raise ValueError(f"Expected 4D or 5D tensor, got {latents.ndim}D tensor")

    # Linearly interpolate the shift between (base_image_seq_len, base_shift) and
    # (max_image_seq_len, max_shift) based on the latent sequence length.
    m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len)
    b = base_shift - m * base_image_seq_len
    mu = m * image_or_video_sequence_length + b

    sigmas = default_flow_shift(sigmas, shift=mu)
    return sigmas
def default_flow_shift(sigmas: torch.Tensor, shift: float = 1.0) -> torch.Tensor:
    # Rescale sigmas toward 1.0: sigma' = (shift * sigma) / (1 + (shift - 1) * sigma).
    sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
    return sigmas
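

# Illustrative sketch (not part of the original module): how the two shift helpers
# above might be used together. The 5D latent shape and step count below are made-up
# example values, not taken from any particular model.
def _example_flow_shift_usage() -> torch.Tensor:
    latents = torch.randn(1, 16, 4, 30, 45)  # (batch, channels, frames, height, width)
    sigmas = torch.linspace(1.0, 1e-3, steps=50)
    # Longer latent sequences produce a larger shift, concentrating more of the
    # schedule at high-noise sigmas.
    return resolution_dependent_timestep_flow_shift(latents, sigmas)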


def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: Optional[float] = None,
    logit_std: Optional[float] = None,
    mode_scale: Optional[float] = None,
    device: torch.device = torch.device("cpu"),
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    r"""
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
    return u
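

# Illustrative sketch (not in the original file): drawing per-sample noise levels with
# the SD3 "logit_normal" scheme. The mean/std/scale values mirror the defaults used in
# `prepare_sigmas` below; the batch size and seed are arbitrary.
def _example_timestep_density() -> torch.Tensor:
    generator = torch.Generator().manual_seed(0)
    u = compute_density_for_timestep_sampling(
        weighting_scheme="logit_normal",
        batch_size=4,
        logit_mean=0.0,
        logit_std=1.0,
        mode_scale=1.29,
        generator=generator,
    )
    # u lies in (0, 1); downstream it is scaled by num_train_timesteps and used as an
    # index into the scheduler's sigma table.
    return u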


def get_scheduler_alphas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> Optional[torch.Tensor]:
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        # Flow-matching schedulers are parameterized by sigmas only; there are no alphas.
        return None
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        return scheduler.alphas_cumprod.clone()
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")


def get_scheduler_sigmas(scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler]) -> torch.Tensor:
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        return scheduler.sigmas.clone()
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        # DDIM-style schedulers expose timesteps; normalize them to [0, 1) so both
        # scheduler families share a sigma-like representation.
        return scheduler.timesteps.clone().float() / float(scheduler.config.num_train_timesteps)
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")


def prepare_sigmas(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    sigmas: torch.Tensor,
    batch_size: int,
    num_train_timesteps: int,
    flow_weighting_scheme: str = "none",
    flow_logit_mean: float = 0.0,
    flow_logit_std: float = 1.0,
    flow_mode_scale: float = 1.29,
    device: torch.device = torch.device("cpu"),
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        weights = compute_density_for_timestep_sampling(
            weighting_scheme=flow_weighting_scheme,
            batch_size=batch_size,
            logit_mean=flow_logit_mean,
            logit_std=flow_logit_std,
            mode_scale=flow_mode_scale,
            device=device,
            generator=generator,
        )
        indices = (weights * num_train_timesteps).long()
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        # TODO(aryan): Currently, only uniform sampling is supported. Add more sampling schemes.
        weights = torch.rand(size=(batch_size,), device=device, generator=generator)
        indices = (weights * num_train_timesteps).long()
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")

    return sigmas[indices]
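

# Illustrative sketch (not in the original file): sampling a batch of sigmas from a
# flow-matching scheduler. Constructing `FlowMatchEulerDiscreteScheduler()` with its
# defaults is an assumption for demonstration; a trainer would use its configured
# scheduler instead.
def _example_prepare_sigmas() -> torch.Tensor:
    scheduler = FlowMatchEulerDiscreteScheduler()
    scheduler_sigmas = get_scheduler_sigmas(scheduler)
    return prepare_sigmas(
        scheduler=scheduler,
        sigmas=scheduler_sigmas,
        batch_size=4,
        num_train_timesteps=scheduler.config.num_train_timesteps,
        flow_weighting_scheme="logit_normal",
    )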


def prepare_loss_weights(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    alphas: Optional[torch.Tensor] = None,
    sigmas: Optional[torch.Tensor] = None,
    flow_weighting_scheme: str = "none",
) -> torch.Tensor:
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        return compute_loss_weighting_for_sd3(sigmas=sigmas, weighting_scheme=flow_weighting_scheme)
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        # SNR is computed as (alphas / (1 - alphas)), but for some reason CogVideoX uses 1 / (1 - alphas).
        # TODO(aryan): Experiment if using alphas / (1 - alphas) gives better results.
        return 1 / (1 - alphas)
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")


def prepare_target(
    scheduler: Union[CogVideoXDDIMScheduler, FlowMatchEulerDiscreteScheduler],
    noise: torch.Tensor,
    latents: torch.Tensor,
) -> torch.Tensor:
    if isinstance(scheduler, FlowMatchEulerDiscreteScheduler):
        # Flow matching regresses the velocity, i.e. the straight-line direction from data to noise.
        target = noise - latents
    elif isinstance(scheduler, CogVideoXDDIMScheduler):
        # The CogVideoX DDIM objective targets the clean latents directly.
        target = latents
    else:
        raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
    return target
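

# Illustrative end-to-end sketch (not in the original file): one way these helpers
# could compose into a flow-matching training step. `model`, `latents`, and the simple
# linear noising below are placeholders and assumptions, not the original trainer's code.
def _example_training_step(model, latents: torch.Tensor) -> torch.Tensor:
    scheduler = FlowMatchEulerDiscreteScheduler()
    scheduler_sigmas = get_scheduler_sigmas(scheduler)
    sigmas = prepare_sigmas(
        scheduler=scheduler,
        sigmas=scheduler_sigmas,
        batch_size=latents.shape[0],
        num_train_timesteps=scheduler.config.num_train_timesteps,
        flow_weighting_scheme="logit_normal",
    )
    noise = torch.randn_like(latents)
    # Broadcast per-sample sigmas over the latent dimensions and interpolate between
    # data and noise (the flow-matching forward process).
    sigmas_ = sigmas.view(-1, *([1] * (latents.ndim - 1)))
    noisy_latents = (1.0 - sigmas_) * latents + sigmas_ * noise
    timesteps = (sigmas * scheduler.config.num_train_timesteps).long()
    pred = model(noisy_latents, timesteps)
    target = prepare_target(scheduler=scheduler, noise=noise, latents=latents)
    # "none" yields uniform loss weights; see prepare_loss_weights above.
    weights = prepare_loss_weights(scheduler=scheduler, sigmas=sigmas, flow_weighting_scheme="none")
    weights_ = weights.view(-1, *([1] * (latents.ndim - 1)))
    loss = (weights_ * (pred - target) ** 2).mean()
    return loss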