# MIT License
#
# Copyright (c) 2023 Alexander Tong
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Copyright (c) [2023] [Alexander Tong]
# Copyright (c) [2025] [Ziyue Jiang]
# SPDX-License-Identifier: MIT
# This file has been modified by Ziyue Jiang on 2025/03/19
# Original file was released under MIT, with the full license text
# available at https://github.com/atong01/conditional-flow-matching/blob/1.0.7/LICENSE.
# This modified file is released under the same license.

import math
from typing import Union

import torch
from torch.distributions import LogisticNormal


class LogitNormalTrainingTimesteps:
    """Sampler that draws training time levels from a logistic-normal distribution.

    The sampler wraps ``torch.distributions.LogisticNormal(loc, scale)`` and
    returns the first coordinate of every draw as the time level ``t``.

    Parameters
    ----------
    T : float
        must be strictly positive; it is stored on the instance but is not
        used anywhere else in this class
    loc : float
        ``loc`` argument forwarded to ``LogisticNormal``
    scale : float
        ``scale`` argument forwarded to ``LogisticNormal``
    """

    def __init__(self, T=1000.0, loc=0.0, scale=1.0):
        assert T > 0
        self.T = T
        self.dist = LogisticNormal(loc, scale)

    def sample(self, size, device):
        """Draw time levels and move them to ``device``.

        Parameters
        ----------
        size : sequence of int (e.g. ``[bs]``) or torch.Size
            sample shape handed to ``LogisticNormal.sample``
        device : torch.device or str
            device on which the returned tensor is placed

        Returns
        -------
        t : Tensor
            the first coordinate (index 0 of the last axis) of every draw
        """
        # Every draw carries a trailing axis; only its first entry is used as t.
        t = self.dist.sample(size)[..., 0].to(device)
        return t


def pad_t_like_x(t, x):
    """Function to reshape the time vector t by the number of dimensions of x.

    Parameters
    ----------
    x : Tensor, shape (bs, *dim)
        represents the source minibatch
    t : FloatTensor, shape (bs)
        Python floats / ints are returned unchanged

    Returns
    -------
    t : Tensor, shape (bs, 1, ..., 1) with as many dimensions as x
        (or the unchanged scalar when t is a float / int)

    Example
    -------
    x: Tensor (bs, C, W, H)
    t: Vector (bs)
    pad_t_like_x(t, x): Tensor (bs, 1, 1, 1)
    """
    if isinstance(t, (float, int)):
        # Scalars broadcast against any tensor already; nothing to reshape.
        return t
    return t.reshape(-1, *([1] * (x.dim() - 1)))


def _cos_sin(angle):
    """Return ``(cos(angle), sin(angle))`` for a tensor or a Python scalar."""
    if torch.is_tensor(angle):
        return torch.cos(angle), torch.sin(angle)
    # torch.cos / torch.sin reject plain Python numbers, so use math for them.
    return math.cos(angle), math.sin(angle)


class ConditionalFlowMatcher:
    r"""Base class for conditional flow matching methods.

    This class implements the independent conditional flow matching methods
    from [1] and serves as a parent class for all other flow matching methods.

    It implements:
    - Drawing data from gaussian probability path N(t * x1 + (1 - t) * x0, sigma) function
    - conditional flow matching ut(x1|x0) = x1 - x0
    - score function $\nabla log p_t(x|x0, x1)$
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        r"""Initialize the ConditionalFlowMatcher class.

        It requires the hyper-parameter $\sigma$. A default
        ``LogitNormalTrainingTimesteps`` sampler is created and used whenever
        no time levels are passed to ``sample_location_and_conditional_flow``.

        Parameters
        ----------
        sigma : Union[float, int]
        """
        self.sigma = sigma
        self.time_sampler = LogitNormalTrainingTimesteps()

    def compute_mu_t(self, x0, x1, t):
        """
        Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)

        Returns
        -------
        mean mu_t: t * x1 + (1 - t) * x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        t = pad_t_like_x(t, x0)
        return t * x1 + (1 - t) * x0

    def compute_sigma_t(self, t):
        """
        Compute the standard deviation of the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        t : FloatTensor, shape (bs)
            unused here: the standard deviation is constant in time

        Returns
        -------
        standard deviation sigma

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t
        return self.sigma

    def sample_xt(self, x0, x1, t, epsilon):
        """
        Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma), see (Eq.14) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
        epsilon : Tensor, shape (bs, *dim)
            noise sample from N(0, 1)

        Returns
        -------
        xt : Tensor, shape (bs, *dim)

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        mu_t = self.compute_mu_t(x0, x1, t)
        sigma_t = self.compute_sigma_t(t)
        # A scalar sigma passes through untouched; a per-sample sigma is
        # reshaped so that it broadcasts over the non-batch axes of x0.
        sigma_t = pad_t_like_x(sigma_t, x0)
        return mu_t + sigma_t * epsilon

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            unused here
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt (unused here)

        Returns
        -------
        ut : conditional vector field ut(x1|x0) = x1 - x0

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        del t, xt
        return x1 - x0

    def sample_noise_like(self, x):
        """Return standard normal noise with the same shape, dtype and device as x."""
        return torch.randn_like(x)

    def sample_location_and_conditional_flow(self, x0, x1, t=None, return_noise=False):
        """
        Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma))
        and the conditional vector field ut(x1|x0) = x1 - x0, see Eq.(15) [1].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        (optionally) t : Tensor, shape (bs)
            represents the time levels
            if None, drawn from ``self.time_sampler``
            (a ``LogitNormalTrainingTimesteps`` instance by default)
        return_noise : bool
            return the noise sample epsilon

        Returns
        -------
        t : FloatTensor, shape (bs)
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt
        ut : conditional vector field ut(x1|x0) = x1 - x0
        (optionally) eps: Tensor, shape (bs, *dim) such that xt = mu_t + sigma_t * epsilon

        References
        ----------
        [1] Improving and Generalizing Flow-Based Generative Models with minibatch optimal transport, Preprint, Tong et al.
        """
        if t is None:
            # One time level per batch element, matched to x0's device and dtype.
            t = self.time_sampler.sample([x0.shape[0]], x0.device).type_as(x0)
        assert len(t) == x0.shape[0], "t has to have batch size dimension"

        eps = self.sample_noise_like(x0)
        xt = self.sample_xt(x0, x1, t, eps)
        ut = self.compute_conditional_flow(x0, x1, t, xt)
        if return_noise:
            return t, xt, ut, eps
        return t, xt, ut

    def compute_lambda(self, t):
        """Compute the lambda function, see Eq.(23) [4].

        Parameters
        ----------
        t : FloatTensor, shape (bs)

        Returns
        -------
        lambda : score weighting function

        References
        ----------
        [4] Simulation-free Schrodinger bridges via score and flow matching, Preprint, Tong et al.
        """
        sigma_t = self.compute_sigma_t(t)
        # The 1e-8 term keeps the division finite when sigma is 0.
        return 2 * sigma_t / (self.sigma**2 + 1e-8)


class VariancePreservingConditionalFlowMatcher(ConditionalFlowMatcher):
    """Albergo et al. 2023 trigonometric interpolants class.

    This class inherits the ConditionalFlowMatcher and override the
    compute_mu_t and compute_conditional_flow functions in order to compute
    [3]'s trigonometric interpolants.

    [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
    """

    def compute_mu_t(self, x0, x1, t):
        r"""Compute the mean of the probability path (Eq.5) from [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            a Python float / int is also accepted and applied to every sample

        Returns
        -------
        mean mu_t: cos(pi t/2)x0 + sin(pi t/2)x1

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        t = pad_t_like_x(t, x0)
        cos_t, sin_t = _cos_sin(math.pi / 2 * t)
        return cos_t * x0 + sin_t * x1

    def compute_conditional_flow(self, x0, x1, t, xt):
        r"""Compute the conditional vector field similar to [3].

        ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(pi*t/2) x0),
        see Eq.(21) [3].

        Parameters
        ----------
        x0 : Tensor, shape (bs, *dim)
            represents the source minibatch
        x1 : Tensor, shape (bs, *dim)
            represents the target minibatch
        t : FloatTensor, shape (bs)
            a Python float / int is also accepted and applied to every sample
        xt : Tensor, shape (bs, *dim)
            represents the samples drawn from probability path pt (unused here)

        Returns
        -------
        ut : conditional vector field
        ut(x1|x0) = pi/2 (cos(pi*t/2) x1 - sin(\pi*t/2) x0)

        References
        ----------
        [3] Stochastic Interpolants: A Unifying Framework for Flows and Diffusions, Albergo et al.
        """
        del xt
        t = pad_t_like_x(t, x0)
        cos_t, sin_t = _cos_sin(math.pi / 2 * t)
        return math.pi / 2 * (cos_t * x1 - sin_t * x0)