Spaces:
Running
Running
import torch | |
def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    r"""Forward process of flow matching.

    Linearly interpolates between ``x0`` and ``n``:
    ``x_t = (1 - t) * x0 + t * n``. The result equals ``x0`` at ``t = 0``
    and ``n`` at ``t = 1``.

    Args:
        x0: Tensor at the ``t = 0`` end of the interpolation (presumably the
            clean data sample -- confirm against the caller).
        n: Tensor at the ``t = 1`` end of the interpolation (presumably the
            noise sample -- confirm against the caller).
        t: Interpolation weight. It is combined with ``x0`` and ``n`` by plain
            elementwise arithmetic, so it must already be broadcastable
            against them; no reshaping is done here.

    Returns:
        The interpolated tensor ``(1 - t) * x0 + t * n``.
    """
    return (1.0 - t) * x0 + t * n
def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
    r"""Loss target for flow matching.

    Returns ``n - x0``, which is the derivative with respect to ``t`` of the
    interpolation ``(1 - t) * x0 + t * n`` computed by ``flow_match_xt``.

    Note:
        The argument order here is ``(n, x0)``, the reverse of
        ``flow_match_xt(x0, n, t)``. Pass arguments by keyword if there is
        any risk of mixing them up.

    Args:
        n: Tensor at the ``t = 1`` end of the interpolation (presumably the
            noise sample -- confirm against the caller).
        x0: Tensor at the ``t = 0`` end of the interpolation (presumably the
            clean data sample -- confirm against the caller).

    Returns:
        The elementwise difference ``n - x0``.
    """
    return n - x0