PartPacker / flow /flow_matching.py
ashawkey's picture
init
daa6779
raw
history blame
2.25 kB
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""
import numpy as np
import torch
class FlowMatchingScheduler:
def __init__(self, num_train_timesteps: int = 1000, shift: float = 1):
# set timesteps
self.num_train_timesteps = num_train_timesteps
self.shift = shift
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
sigmas = timesteps / num_train_timesteps
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
self.sigmas = sigmas # 1 --> 0
self.timesteps = sigmas * num_train_timesteps # num_train_timesteps --> 1
# set device
def to(self, device):
self.sigmas = self.sigmas.to(device=device)
self.timesteps = self.timesteps.to(device=device)
# add random noise to latent during training
def add_noise(self, latent: torch.Tensor, logit_mean: float = 1.0, logit_std: float = 1.0):
# latent: [B, ...]
# timesteps: [B]
# return: [B, ...] noisy_latent, [B, ...] noise, [B] timesteps
# logit-normal sampling
u = torch.normal(mean=logit_mean, std=logit_std, size=(latent.shape[0],), device=self.sigmas.device)
u = torch.nn.functional.sigmoid(u)
step_indices = (u * self.num_train_timesteps).long()
timesteps = self.timesteps[step_indices]
sigmas = self.sigmas[step_indices].flatten()
while len(sigmas.shape) < latent.ndim:
sigmas = sigmas.unsqueeze(-1)
noise = torch.randn_like(latent)
noisy_latent = (1.0 - sigmas) * latent + sigmas * noise
return noisy_latent, noise, timesteps