Spaces:
Running
on
Zero
Running
on
Zero
# MIT License | |
# Copyright (c) 2023 Alexander Tong | |
# Permission is hereby granted, free of charge, to any person obtaining a copy | |
# of this software and associated documentation files (the "Software"), to deal | |
# in the Software without restriction, including without limitation the rights | |
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
# copies of the Software, and to permit persons to whom the Software is | |
# furnished to do so, subject to the following conditions: | |
# The above copyright notice and this permission notice shall be included in all | |
# copies or substantial portions of the Software. | |
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
# SOFTWARE. | |
# Copyright (c) [2023] [Alexander Tong] | |
# Copyright (c) [2025] [Ziyue Jiang] | |
# SPDX-License-Identifier: MIT | |
# This file has been modified by Ziyue Jiang on 2025/03/19 | |
# Original file was released under MIT, with the full license text available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
# This modified file is released under the same license. | |
import math | |
import torch | |
from typing import Union | |
from torch.distributions import LogisticNormal | |
class LogitNormalTrainingTimesteps:
    """Sampler for training timesteps backed by ``torch.distributions.LogisticNormal``.

    Each call to ``sample`` draws from ``LogisticNormal(loc, scale)`` and keeps
    only the first coordinate of every draw.
    """

    def __init__(self, T=1000.0, loc=0.0, scale=1.0):
        """Validate ``T`` and build the underlying distribution.

        Parameters
        ----------
        T : float
            must be strictly positive; stored as ``self.T`` (``sample`` does
            not read it)
        loc : float
            location parameter forwarded to ``LogisticNormal``
        scale : float
            scale parameter forwarded to ``LogisticNormal``
        """
        assert T > 0
        self.T = T
        self.dist = LogisticNormal(loc, scale)

    def sample(self, size, device):
        """Draw timesteps and move them to ``device``.

        Parameters
        ----------
        size : sequence of int
            sample shape forwarded to ``self.dist.sample``
        device : torch.device or str
            device the result is moved to

        Returns
        -------
        t : Tensor, shape (*size)
            first coordinate of every draw
        """
        draws = self.dist.sample(size)
        # The trailing axis of each draw is indexed away; only entry 0 is kept.
        first_coordinate = draws[..., 0]
        return first_coordinate.to(device)
def pad_t_like_x(t, x):
    """Reshape the time vector t so that it broadcasts against x.

    Parameters
    ----------
    t : FloatTensor, shape (bs), or a Python float / int
        time levels
    x : Tensor, shape (bs, *dim)
        tensor whose number of dimensions determines the padding

    Returns
    -------
    t : Tensor, shape (bs, 1, ..., 1)
        one singleton axis is appended per non-batch dimension of x;
        a Python float / int is returned unchanged

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    if not isinstance(t, (float, int)):
        # One size-1 axis for every dimension of x after the batch axis.
        singleton_axes = (1,) * (x.dim() - 1)
        return t.reshape(-1, *singleton_axes)
    # Python scalars already broadcast against any tensor.
    return t
class ConditionalFlowMatcher:
    r"""Base class for conditional flow matching methods. This class implements the independent
    conditional flow matching methods from [1] and serves as a parent class for all other flow
    matching methods.

    It implements:
    - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
    - conditional flow matching ut(x1|x0) = x1 - x0
    - the weighting lambda(t) associated with the score function $\nabla log p_t(x|x0, x1)$

    When no time levels are supplied, they are drawn from ``self.time_sampler``
    (a ``LogitNormalTrainingTimesteps`` built with its default arguments).
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        r"""Initialize the ConditionalFlowMatcher class. It requires the hyper-parameter $\sigma$.

        Parameters
        ----------
        sigma : Union[float, int]
            standard deviation of the probability path
        """
        self.sigma = sigma
        self.time_sampler = LogitNormalTrainingTimesteps()

    def compute_mu_t(self, x0, x1, t):
        """
        Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1 + (1 - t) * x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        # Append singleton axes so t broadcasts over the non-batch dims of x0.
        t = pad_t_like_x(t, x0)
        return t * x1 + (1 - t) * x0

    def compute_sigma_t(self, t):
        """
        Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        t : FloatTensor, shape (bs)
            unused here; the standard deviation is constant in time

        Returns
        -------
        standard deviation sigma

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t
        return self.sigma

    def sample_xt(self, x0, x1, t, epsilon):
        """
        Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            noise sample from N(0, 1)

        Returns
        -------
        xt : Tensor, shape (bs, *dim)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = self.compute_sigma_t(t)
        # A Python scalar passes through pad_t_like_x unchanged; a tensor
        # gets reshaped to broadcast against x0.
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            unused here
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt; unused here

        Returns
        -------
        ut : conditional vector field ut(x1|x0) = x1 - x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t, xt
        return x1 - x0

    def sample_noise_like(self, x):
        """Return standard normal noise with the same shape, dtype and device as x."""
        return torch.randn_like(x)

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        """
        Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
        and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        (optionally) t : Tensor, shape (bs)
            represents the time levels
            if None, drawn from self.time_sampler (logit-normal timesteps)
        return_noise : bool
            return the noise sample epsilon

        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
        ut : conditional vector field ut(x1|x0) = x1 - x0
        (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        if t is None:
            # One time level per batch element, matched to x0's device and dtype.
            t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0)
        assert len(t) == x0.shape[0], "t has to have batch size dimension"

        eps = self.sample_noise_like(x0)
        xt = self.sample_xt(x0, x1, t, eps)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        if return_noise:
            return t, xt, ut, eps
        return t, xt, ut

    def compute_lambda(self, t):
        """Compute the lambda function, see Eq.(23) [4].

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        lambda : score weighting function

        References
        ----------
        [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
        """
        sigma_t = self.compute_sigma_t(t)
        # The 1e-8 term keeps the division finite when sigma is 0.
        return 2 * sigma_t / (self.sigma**2 + 1e-8)
class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
    """Trigonometric interpolants of Albergo et al. 2023 [3].

    Subclass of ConditionalFlowMatcher that overrides compute_mu_t and
    compute_conditional_flow so that the path mean and its conditional vector
    field follow the trigonometric interpolant of [3].

    [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
    """

    def compute_mu_t(self, x0, x1, t):
        r"""Compute the mean of the probability path (Eq.5) from [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        # Angle pi/2 * t, with t reshaped to broadcast against x0.
        angle = math.pi / 2 * pad_t_like_x(t, x0)
        source_term = torch.cos(angle) * x0
        target_term = torch.sin(angle) * x1
        return source_term + target_term

    def compute_conditional_flow(self, x0, x1, t, xt):
        r"""Compute the conditional vector field similar to [3].

        ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
        see Eq.(21) [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt; unused here

        Returns
        -------
        ut : conditional vector field
            ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0)

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        del xt
        angle = math.pi / 2 * pad_t_like_x(t, x0)
        # Time derivative of the interpolant, up to the pi/2 factor applied below.
        unscaled_flow = torch.cos(angle) * x1 - torch.sin(angle) * x0
        return math.pi / 2 * unscaled_flow