jbilcke-hf's picture
jbilcke-hf HF Staff
upgrading finetrainers (and losing my extra code + improvements)
80ebcb3
raw
history blame
317 Bytes
import torch
def flow_match_xt(x0: torch.Tensor, n: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    r"""Forward (noising) process of flow matching.

    Computes the elementwise linear interpolation ``(1 - t) * x0 + t * n``,
    which yields ``x0`` where ``t == 0`` and ``n`` where ``t == 1``.

    Args:
        x0: Starting tensor (presumably the clean sample -- confirm with caller).
        n: End tensor (presumably the noise sample -- confirm with caller).
        t: Interpolation weight. No reshaping is performed here, so it must
            already broadcast against ``x0`` and ``n``.

    Returns:
        The interpolated tensor, with the broadcast shape of the operands.
    """
    x0_weight = 1.0 - t
    x0_term = x0_weight * x0
    n_term = t * n
    return x0_term + n_term
def flow_match_target(n: torch.Tensor, x0: torch.Tensor) -> torch.Tensor:
    r"""Loss target for flow matching.

    Computes the elementwise difference ``n - x0``. Note the argument order:
    ``n`` comes first and ``x0`` is subtracted from it.

    Args:
        n: Minuend tensor (presumably the noise sample -- confirm with caller).
        x0: Subtrahend tensor (presumably the clean sample -- confirm with
            caller). Must broadcast against ``n``.

    Returns:
        The tensor ``n - x0``.
    """
    target = n - x0
    return target