# ZiyueJiang's picture
# first commit for huggingface space
# 593f3bc
# MIT License
# Copyright (c) 2023 Alexander Tong
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# Copyright (c) [2023] [Alexander Tong]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
# This modified file is released under the same license.
import math
import torch
from typing import Union
from torch.distributions import LogisticNormal
class LogitNormalTrainingTimesteps:
    """Sampler of training time levels built on ``torch.distributions.LogisticNormal``.

    Each drawn time is the first coordinate of a ``LogisticNormal(loc, scale)``
    sample. With scalar ``loc``/``scale`` this is the logit-normal variable
    ``sigmoid(z)``, ``z ~ N(loc, scale)``, which lies in (0, 1).

    Parameters
    ----------
    T : float
        must be strictly positive; stored as ``self.T`` (not used by ``sample``).
    loc : float
        location of the underlying normal variable.
    scale : float
        scale of the underlying normal variable.
    """

    def __init__(self, T=1000.0, loc=0.0, scale=1.0):
        assert T > 0
        self.T = T
        self.dist = LogisticNormal(loc, scale)

    def sample(self, size, device):
        """Draw time levels with batch shape ``size`` and move them to ``device``.

        Parameters
        ----------
        size : sequence of int
            sample shape passed to the distribution, e.g. ``[bs]``.
        device : torch.device or str
            destination device of the returned tensor.

        Returns
        -------
        Tensor of shape ``size`` holding the first simplex coordinate of each sample.
        """
        # The distribution samples where its parameters live; the result is
        # transferred afterwards.
        simplex_points = self.dist.sample(size)
        first_coordinate = simplex_points.select(-1, 0)
        return first_coordinate.to(device)
def pad_t_like_x(t, x):
    """Reshape the time vector ``t`` so that it broadcasts against ``x``.

    Parameters
    ----------
    t : FloatTensor, shape (bs), or a Python float/int
        time levels; Python scalars are returned unchanged.
    x : Tensor, shape (bs, *dim)
        represents the source minibatch; only its number of dimensions is used.

    Returns
    -------
    t : Tensor of shape (bs, 1, ..., 1), with one trailing singleton axis per
        non-batch dimension of ``x``; or the unchanged scalar.

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    if not isinstance(t, (float, int)):
        singleton_dims = [1] * (x.dim() - 1)
        t = t.reshape(-1, *singleton_dims)
    return t
class ConditionalFlowMatcher:
    r"""Base class for conditional flow matching methods. This class implements the independent
    conditional flow matching methods from [1] and serves as a parent class for all other flow
    matching methods.

    It implements:
    - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
    - conditional flow matching ut(x1|x0) = x1 - x0
    - score function $\nabla log p_t(x|x0, x1)$

    When no time levels are supplied, they are drawn from ``self.time_sampler``
    (a ``LogitNormalTrainingTimesteps`` instance), not from a uniform law.
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.

        Parameters
        ----------
        sigma : Union[float, int]
            constant standard deviation of the probability path.
        """
        self.sigma = sigma
        # Sampler used by ``sample_location_and_conditional_flow`` when ``t`` is None.
        self.time_sampler = LogitNormalTrainingTimesteps()

    def compute_mu_t(self, x0, x1, t):
        """
        Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1 + (1 - t) * x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        t = pad_t_like_x(t, x0)
        return t * x1 + (1 - t) * x0

    def compute_sigma_t(self, t):
        """
        Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        t : FloatTensor, shape (bs)
            unused: the standard deviation is constant in time for this path.

        Returns
        -------
        standard deviation sigma

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t
        return self.sigma

    def sample_xt(self, x0, x1, t, epsilon):
        """
        Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            noise sample from N(0, 1)

        Returns
        -------
        xt : Tensor, shape (bs, *dim)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = self.compute_sigma_t(t)
        # Subclasses may return a per-sample sigma_t tensor; pad it to broadcast with x0.
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            unused by this straight-line path
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt (unused here)

        Returns
        -------
        ut : conditional vector field ut(x1|x0) = x1 - x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t, xt
        return x1 - x0

    def sample_noise_like(self, x):
        """Return standard normal noise with the same shape, dtype and device as ``x``."""
        return torch.randn_like(x)

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        """
        Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
        and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        (optionally) t : Tensor of shape (bs), Python number, or 0-dim tensor
            represents the time levels. A scalar is broadcast to every sample
            of the batch. If None, drawn from ``self.time_sampler``
            (logit-normal by default) and cast to the dtype of x0.
        return_noise : bool
            return the noise sample epsilon

        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
        ut : conditional vector field ut(x1|x0) = x1 - x0
        (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon

        Raises
        ------
        ValueError
            if ``t`` does not provide exactly one time level per sample of x0.

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        batch_size = x0.shape[0]
        if t is None:
            t = self.time_sampler.sample([batch_size], x0.device).type_as(x0)
        elif isinstance(t, (float, int)) or (torch.is_tensor(t) and t.dim() == 0):
            # A single time level: broadcast it so the returned t always has shape (bs,).
            t = torch.as_tensor(t, dtype=x0.dtype, device=x0.device).expand(batch_size).clone()
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if len(t) != batch_size:
            raise ValueError("t has to have batch size dimension")
        eps = self.sample_noise_like(x0)
        xt = self.sample_xt(x0, x1, t, eps)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        if return_noise:
            return t, xt, ut, eps
        return t, xt, ut

    def compute_lambda(self, t):
        """Compute the lambda function, see Eq.(23) [4].

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        lambda : score weighting function, 2 * sigma_t / sigma**2

        References
        ----------
        [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
        """
        sigma_t = self.compute_sigma_t(t)
        # The 1e-8 term avoids division by zero when sigma == 0.
        return 2 * sigma_t / (self.sigma**2 + 1e-8)
class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
    r"""Trigonometric (variance-preserving) interpolants of Albergo et al. 2023.

    Subclass of ConditionalFlowMatcher that replaces ``compute_mu_t`` and
    ``compute_conditional_flow`` with the trigonometric interpolant of [3].
    All other behaviour (noise level, time sampling) comes from the parent class.

    [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
    """

    def compute_mu_t(self, x0, x1, t):
        r"""Compute the mean of the probability path (Eq.5) from [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        # Angle pi*t/2, padded so it broadcasts over the non-batch axes of x0.
        angle = math.pi / 2 * pad_t_like_x(t, x0)
        return torch.cos(angle) * x0 + torch.sin(angle) * x1

    def compute_conditional_flow(self, x0, x1, t, xt):
        r"""Compute the conditional vector field similar to [3], see Eq.(21) [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt (unused here)

        Returns
        -------
        ut : conditional vector field
            ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
            i.e. the time derivative of ``compute_mu_t``.

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        angle = math.pi / 2 * pad_t_like_x(t, x0)
        direction = torch.cos(angle) * x1 - torch.sin(angle) * x0
        return math.pi / 2 * direction