Spaces:
Running
on
Zero
Running
on
Zero
""" | |
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
""" | |
import numpy as np | |
import torch | |
class FlowMatchingScheduler:
    """Discrete noise schedule used to corrupt latents for flow-matching training.

    The schedule is a table of ``num_train_timesteps`` noise levels ``sigmas``
    stored in descending order (index 0 holds sigma == 1, the last index holds
    the smallest sigma) together with the matching ``timesteps``
    (``sigma * num_train_timesteps``).

    Attributes:
        num_train_timesteps: Number of entries in the schedule.
        shift: Shift factor applied to the linear sigmas via
            ``shift * s / (1 + (shift - 1) * s)``; ``shift == 1`` leaves them
            unchanged.
        sigmas: 1-D float32 tensor of noise levels, descending from 1.
        timesteps: 1-D float32 tensor equal to ``sigmas * num_train_timesteps``.
    """

    def __init__(self, num_train_timesteps: int = 1000, shift: float = 1):
        """Build the sigma / timestep tables.

        Args:
            num_train_timesteps: Number of discrete training timesteps (>= 1).
            shift: Shift factor applied to the linearly spaced sigmas.

        Raises:
            ValueError: If ``num_train_timesteps`` is smaller than 1.
        """
        if num_train_timesteps < 1:
            raise ValueError(
                f"num_train_timesteps must be >= 1, got {num_train_timesteps}"
            )

        self.num_train_timesteps = num_train_timesteps
        self.shift = shift

        # Linear timesteps N, N-1, ..., 1. The reversed view has negative
        # strides, which torch.from_numpy cannot wrap, hence the copy().
        timesteps = np.linspace(
            1, num_train_timesteps, num_train_timesteps, dtype=np.float32
        )[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # Normalise to (0, 1] and apply the shift warp.
        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.sigmas = sigmas  # descending, starts at 1
        self.timesteps = sigmas * num_train_timesteps  # descending, starts at N

    def to(self, device):
        """Move the schedule tensors to ``device``.

        Args:
            device: Target device, as accepted by ``torch.Tensor.to``.

        Returns:
            This scheduler, so the call can be chained.
        """
        self.sigmas = self.sigmas.to(device=device)
        self.timesteps = self.timesteps.to(device=device)
        return self

    def add_noise(self, latent: torch.Tensor, logit_mean: float = 1.0, logit_std: float = 1.0):
        """Add random noise to a batch of latents at randomly drawn noise levels.

        One schedule index is drawn per sample with logit-normal sampling:
        ``u = sigmoid(N(logit_mean, logit_std))`` is scaled by
        ``num_train_timesteps`` and truncated to an integer index. Because the
        tables are stored in descending order, a larger ``u`` selects a
        smaller sigma (less noise).

        The random indices are created on the device of ``self.sigmas``; the
        selected sigmas are then multiplied with ``latent``, so the scheduler
        is expected to have been moved to the latent's device with ``to``.

        Args:
            latent: Clean latents of shape ``[B, ...]``.
            logit_mean: Mean of the normal distribution fed to the sigmoid.
            logit_std: Standard deviation of that normal distribution.

        Returns:
            Tuple ``(noisy_latent, noise, timesteps)`` where ``noisy_latent``
            and ``noise`` have the shape of ``latent`` and ``timesteps`` has
            shape ``[B]``.
        """
        batch_size = latent.shape[0]

        # Logit-normal sampling of a fraction u in [0, 1].
        u = torch.normal(
            mean=logit_mean,
            std=logit_std,
            size=(batch_size,),
            device=self.sigmas.device,
        )
        u = torch.sigmoid(u)

        # In float32 the sigmoid saturates to exactly 1.0 for large logits,
        # which would produce index == num_train_timesteps (out of range).
        # Clamp to the last valid entry of the schedule.
        step_indices = (u * self.num_train_timesteps).long()
        step_indices = step_indices.clamp(max=self.num_train_timesteps - 1)

        timesteps = self.timesteps[step_indices]

        # Shape sigmas as [B, 1, ..., 1] so they broadcast over the latent.
        sigmas = self.sigmas[step_indices].reshape(
            batch_size, *([1] * (latent.ndim - 1))
        )

        noise = torch.randn_like(latent)
        # Linear interpolation between the clean latent and pure noise.
        noisy_latent = (1.0 - sigmas) * latent + sigmas * noise
        return noisy_latent, noise, timesteps