Leonard Bruns
Add Vista example
d323598
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Literal, Optional, Union
import torch
from einops import rearrange, repeat
from ...util import append_dims, default
class Guider(ABC):
    """Abstract base class for the guiders defined in this module.

    A concrete guider must implement ``__call__``. ``prepare_inputs`` is an
    optional hook: the base implementation does nothing and returns ``None``.
    """

    @abstractmethod
    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        """Map the tensor ``x`` (and its ``sigma``) to the guided tensor.

        The base implementation has no behaviour of its own; subclasses must
        override it.
        """

    def prepare_inputs(self, x, s, c, cond_mask, uc):
        """Hook for assembling inputs from ``x``, ``s``, ``c``, ``cond_mask`` and ``uc``.

        The concrete guiders in this module return a 4-tuple
        ``(x, s, c_out, cond_mask)``; this default returns ``None``.
        """
        return None
class VanillaCFG(Guider):
    """Classifier-free guidance with a single, fixed guidance scale.

    ``prepare_inputs`` doubles the batch (unconditional half first, conditional
    half second) and ``__call__`` blends the two halves of the resulting
    tensor as ``x_u + scale * (x_c - x_u)``.
    """

    # Conditioning entries that differ between the two branches and therefore
    # get concatenated along the batch dimension instead of being shared.
    _BRANCH_KEYS = ("vector", "crossattn", "concat")

    def __init__(self, scale: float):
        """Store the guidance ``scale`` used by ``__call__``."""
        self.scale = scale

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Blend the unconditional and conditional halves of ``x``.

        Args:
            x: Tensor whose first dimension holds the unconditional half
                followed by the conditional half (split with ``chunk(2)``).
            sigma: Unused by this guider.

        Returns:
            ``x_u + scale * (x_c - x_u)``, with half the batch size of ``x``.
        """
        x_u, x_c = x.chunk(2)
        return x_u + self.scale * (x_c - x_u)

    @staticmethod
    def _values_match(cond, uncond) -> bool:
        """Return True when two shared conditioning values are equal.

        Tensor comparisons are element-wise, so they are reduced with
        ``all()``; calling ``bool()`` on a multi-element comparison result
        would raise instead of answering the question.
        """
        matches = cond == uncond
        if isinstance(matches, torch.Tensor):
            return bool(matches.all())
        return bool(matches)

    def prepare_inputs(self, x, s, c, cond_mask, uc):
        """Build the doubled batch for one guided step.

        Args:
            x: Tensor to be duplicated along dim 0.
            s: Tensor to be duplicated along dim 0.
            c: Mapping of conditional entries.
            cond_mask: Tensor to be duplicated along dim 0.
            uc: Mapping of unconditional entries; must contain every key of
                ``c``.

        Returns:
            ``(x, s, c_out, cond_mask)`` where ``x``, ``s`` and ``cond_mask``
            are concatenated with themselves and ``c_out`` holds, per key,
            either ``cat(uc[k], c[k])`` (for "vector", "crossattn", "concat")
            or the shared value ``c[k]``.

        Raises:
            AssertionError: if a shared (non-concatenated) entry differs
                between ``c`` and ``uc``.
        """
        c_out = {}
        for k in c:
            if k in self._BRANCH_KEYS:
                # Unconditional first, matching the chunk order in __call__.
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert self._values_match(c[k], uc[k]), (
                    f"conditioning entry {k!r} must be identical in c and uc"
                )
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out, torch.cat([cond_mask] * 2)
class IdentityGuider(Guider):
    """Guider that applies no guidance at all.

    Neither the tensor passed to ``__call__`` nor the inputs passed to
    ``prepare_inputs`` are altered; ``uc`` is ignored.
    """

    def __call__(self, x: torch.Tensor, sigma: float) -> torch.Tensor:
        """Return ``x`` unchanged; ``sigma`` is not used."""
        return x

    def prepare_inputs(self, x, s, c, cond_mask, uc):
        """Pass the inputs through, handing back a shallow copy of ``c``.

        Returns:
            ``(x, s, c_out, cond_mask)`` where ``c_out`` is a new dict with the
            same keys and the same value objects as ``c``.
        """
        passthrough = {key: c[key] for key in c}
        return x, s, passthrough, cond_mask
class LinearPredictionGuider(Guider):
    """Classifier-free guidance whose scale varies linearly across frames.

    The per-frame scale runs from ``min_scale`` (first frame) to ``max_scale``
    (last frame) over ``num_frames`` frames. ``prepare_inputs`` doubles the
    batch (unconditional half first) and ``__call__`` blends the two halves
    frame by frame.
    """

    def __init__(
        self,
        num_frames: int = 25,
        max_scale: float = 2.5,
        min_scale: float = 1.0,
        additional_cond_keys: Optional[Union[list[str], str]] = None
    ):
        """Set up the per-frame scale and the extra guided conditioning keys.

        Args:
            num_frames: Number of frames per sample; also the length of the
                scale schedule.
            max_scale: Scale applied to the last frame.
            min_scale: Scale applied to the first frame.
            additional_cond_keys: Extra conditioning key(s), besides "vector",
                "crossattn" and "concat", whose conditional and unconditional
                values are concatenated in ``prepare_inputs``. A single string
                is treated as a one-element list.
        """
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.num_frames = num_frames
        # Shape (1, num_frames): one scale per frame, broadcast over the batch.
        self.scale = torch.linspace(min_scale, max_scale, num_frames).unsqueeze(0)
        # NOTE(review): `default` presumably substitutes the fallback when the
        # argument is None - confirm against the util module.
        additional_cond_keys = default(additional_cond_keys, list())
        if isinstance(additional_cond_keys, str):
            additional_cond_keys = [additional_cond_keys]
        self.additional_cond_keys = additional_cond_keys

    def __call__(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Blend the unconditional and conditional halves with per-frame scales.

        Args:
            x: Tensor whose first dimension holds the unconditional half
                followed by the conditional half; each half is laid out as
                ``(batch * num_frames, ...)``.
            sigma: Unused by this guider.

        Returns:
            ``x_u + scale_t * (x_c - x_u)`` flattened back to
            ``(batch * num_frames, ...)``.
        """
        x_u, x_c = x.chunk(2)
        x_u = rearrange(x_u, "(b t) ... -> b t ...", t=self.num_frames)
        x_c = rearrange(x_c, "(b t) ... -> b t ...", t=self.num_frames)
        scale = repeat(self.scale, "1 t -> b t", b=x_u.shape[0])
        # NOTE(review): append_dims presumably pads trailing singleton dims so
        # the (b, t) scale broadcasts against x_u - confirm against util.
        scale = append_dims(scale, x_u.ndim).to(x_u.device)
        return rearrange(x_u + scale * (x_c - x_u), "b t ... -> (b t) ...")

    @staticmethod
    def _values_match(cond, uncond) -> bool:
        """Return True when two shared conditioning values are equal.

        Tensor comparisons are element-wise, so they are reduced with
        ``all()``; calling ``bool()`` on a multi-element comparison result
        would raise instead of answering the question.
        """
        matches = cond == uncond
        if isinstance(matches, torch.Tensor):
            return bool(matches.all())
        return bool(matches)

    def prepare_inputs(self, x, s, c, cond_mask, uc):
        """Build the doubled batch for one guided step.

        Args:
            x: Tensor to be duplicated along dim 0.
            s: Tensor to be duplicated along dim 0.
            c: Mapping of conditional entries.
            cond_mask: Tensor to be duplicated along dim 0.
            uc: Mapping of unconditional entries; must contain every key of
                ``c``.

        Returns:
            ``(x, s, c_out, cond_mask)`` where ``x``, ``s`` and ``cond_mask``
            are concatenated with themselves and ``c_out`` holds, per key,
            either ``cat(uc[k], c[k])`` (for "vector", "crossattn", "concat"
            and ``additional_cond_keys``) or the shared value ``c[k]``.

        Raises:
            AssertionError: if a shared (non-concatenated) entry differs
                between ``c`` and ``uc``.
        """
        # Built once per call instead of once per key.
        guided_keys = {"vector", "crossattn", "concat", *self.additional_cond_keys}
        c_out = {}
        for k in c:
            if k in guided_keys:
                # Unconditional first, matching the chunk order in __call__.
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert self._values_match(c[k], uc[k]), (
                    f"conditioning entry {k!r} must be identical in c and uc"
                )
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out, torch.cat([cond_mask] * 2)
class TrianglePredictionGuider(LinearPredictionGuider):
    """Per-frame guidance whose scale follows one or more triangle waves.

    The parent's linear schedule in ``self.scale`` is replaced by a triangle
    wave evaluated at ``num_frames`` evenly spaced points in ``[0, 1]`` and
    mapped onto ``[min_scale, max_scale]``. Blending and input preparation are
    inherited unchanged.
    """

    def __init__(
        self,
        num_frames: int = 25,
        max_scale: float = 2.5,
        min_scale: float = 1.0,
        period: Union[float, list[float]] = 1.0,
        period_fusing: Literal["mean", "multiply", "max"] = "max",
        additional_cond_keys: Optional[Union[list[str], str]] = None
    ):
        """Build the triangle-wave scale schedule.

        Args:
            num_frames: Number of frames per sample; also the schedule length.
            max_scale: Scale where the fused wave equals 1.
            min_scale: Scale where the fused wave equals 0.
            period: One wave period, or an iterable of periods, measured on the
                normalised ``[0, 1]`` frame axis.
            period_fusing: How the waves of several periods are combined:
                element-wise "mean", "multiply" (product) or "max".
            additional_cond_keys: Forwarded to the parent constructor.

        Raises:
            ValueError: if ``period`` is an empty iterable.
            NotImplementedError: if ``period_fusing`` is not a supported mode.
        """
        super().__init__(num_frames, max_scale, min_scale, additional_cond_keys)
        values = torch.linspace(0, 1, num_frames)
        # A bare number means a single wave. Integers are accepted as well as
        # floats so that e.g. ``period=1`` does not fail when iterated.
        if isinstance(period, (int, float)):
            periods = [period]
        else:
            periods = list(period)
        if not periods:
            raise ValueError("period must contain at least one value")

        waves = [self.triangle_wave(values, p) for p in periods]
        if period_fusing == "mean":
            fused = sum(waves) / len(waves)
        elif period_fusing == "multiply":
            fused = torch.prod(torch.stack(waves), dim=0)
        elif period_fusing == "max":
            fused = torch.max(torch.stack(waves), dim=0).values
        else:
            raise NotImplementedError(
                f"unsupported period_fusing: {period_fusing!r}"
            )
        # Map the fused wave onto [min_scale, max_scale]; shape (1, num_frames)
        # to match the layout of the schedule set by the parent constructor.
        self.scale = (fused * (max_scale - min_scale) + min_scale).unsqueeze(0)

    def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
        """Evaluate a triangle wave of the given ``period`` at ``values``.

        The wave is 0 at integer multiples of ``period`` and rises linearly to
        1 at odd multiples of ``period / 2``.
        """
        phase = values / period
        return 2 * (phase - torch.floor(phase + 0.5)).abs()