# Leonard Bruns
# Add Vista example
# d323598
from __future__ import annotations
import math
from contextlib import nullcontext
from typing import Optional, Union
import kornia
import numpy as np
import open_clip
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import ListConfig
from transformers import CLIPTextModel, CLIPTokenizer
from ...util import (
append_dims,
autocast,
count_params,
default,
disabled_train,
expand_dims_like,
instantiate_from_config,
)
from ..diffusionmodules.openaimodel import Timestep
class AbstractEmbModel(nn.Module):
    """Common base class for conditioning embedders.

    Holds three bookkeeping values -- ``is_trainable``, ``ucg_rate`` and
    ``input_key`` -- each stored in a private attribute that starts out as
    ``None`` and is exposed through a property with getter, setter and deleter.
    """

    def __init__(self):
        super().__init__()
        self._is_trainable = None
        self._ucg_rate = None
        self._input_key = None

    # --- is_trainable -------------------------------------------------------
    @property
    def is_trainable(self) -> bool:
        """Value last assigned to ``is_trainable`` (``None`` until set)."""
        return self._is_trainable

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    # --- ucg_rate -----------------------------------------------------------
    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        """Value last assigned to ``ucg_rate`` (``None`` until set)."""
        return self._ucg_rate

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    # --- input_key ----------------------------------------------------------
    @property
    def input_key(self) -> str:
        """Value last assigned to ``input_key`` (``None`` until set)."""
        return self._input_key

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @input_key.deleter
    def input_key(self):
        del self._input_key
class GeneralConditioner(nn.Module):
    """Run a list of embedders over a batch and merge their outputs by kind.

    Every embedder output is routed to a conditioning slot chosen by its
    tensor rank (see ``OUTPUT_DIM2KEYS``). Outputs that land in the same slot
    are concatenated along the dimension given by ``KEY2CATDIM``, except for
    the special case handled in ``forward``.
    """

    # Rank of an embedder output -> name of the conditioning slot it feeds.
    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    # Conditioning slot -> dimension along which same-slot outputs are concatenated.
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

    def __init__(self, emb_models: Union[list, ListConfig]):
        """Instantiate and configure one embedder per entry of ``emb_models``.

        Each config is built with ``instantiate_from_config`` and may carry:

        * ``is_trainable`` (default ``False``): when false, the embedder's
          ``train`` is replaced by ``disabled_train``, its parameters stop
          requiring gradients and it is put in eval mode.
        * ``ucg_rate`` (default ``0.0``).
        * ``input_key`` or ``input_keys`` (one of them is required).
        * ``legacy_ucg_value``: when present, stored as ``legacy_ucg_val`` and
          a ``np.random.RandomState`` is attached as ``ucg_prng``.

        Raises:
            KeyError: if a config has neither ``input_key`` nor ``input_keys``.
        """
        super().__init__()
        embedders = []
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"Embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                # Freeze: block later ``.train()`` calls, drop grads, switch to eval.
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )
            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(f"Need either `input_key` or `input_keys` for embedder {embedder.__class__.__name__}")
            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                embedder.ucg_prng = np.random.RandomState()
            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: dict) -> dict:
        """Randomly overwrite entries of ``batch[embedder.input_key]`` in place.

        Each entry is independently replaced by ``embedder.legacy_ucg_val`` with
        probability ``embedder.ucg_rate``, drawn from ``embedder.ucg_prng``.
        The same (mutated) ``batch`` object is returned.
        """
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            # choice(2, p=[1 - p, p]) yields 1 with probability p.
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def forward(self, batch: dict, force_zero_embeddings: Optional[list] = None) -> dict:
        """Embed ``batch`` with every embedder and merge the results.

        Args:
            batch: mapping from input keys to embedder inputs.
            force_zero_embeddings: input keys whose embeddings are replaced by
                zeros; ``None`` is treated as an empty list.

        Returns:
            Dict keyed by conditioning slot (``"vector"``, ``"crossattn"``,
            ``"concat"``) holding the merged embeddings.
        """
        output = {}
        force_zero_embeddings = default(force_zero_embeddings, [])
        for embedder in self.embedders:
            # Frozen embedders run without building an autograd graph.
            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
            with embedding_context():
                if hasattr(embedder, "input_key") and embedder.input_key is not None:
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    if embedder.input_key in batch:
                        # TODO this should be a parameter
                        # The input is embedded one sample at a time and the
                        # per-sample results are concatenated along dim 0.
                        samples = batch[embedder.input_key]
                        emb_out_1s = [
                            embedder(samples[i].unsqueeze(0)) for i in range(samples.shape[0])
                        ]
                        emb_out = torch.concat(emb_out_1s, 0)
                    elif embedder.add_sequence_dim:  # concatenation
                        # Key missing from the batch: substitute an all-zero
                        # embedding sized from the embedder's own attributes.
                        emb_dim = embedder.num_features * embedder.outdim
                        emb_out = torch.zeros((batch["cond_aug"].shape[0], 1, emb_dim), device=batch["cond_aug"].device)
                    else:  # addition
                        continue
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(*[batch[k] for k in embedder.input_keys])
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"Encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            for emb in emb_out:
                out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    # Per-sample Bernoulli keep-mask with keep probability
                    # 1 - ucg_rate; expand_dims_like presumably reshapes the
                    # mask so it broadcasts against emb -- confirm in util.
                    emb = (
                        expand_dims_like(
                            torch.bernoulli(
                                (1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)
                            ),
                            emb
                        )
                        * emb
                    )
                if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings:
                    emb = torch.zeros_like(emb)
                if out_key in output:
                    if emb.shape[-1] == 768 and out_key == "vector":
                        # NOTE(review): 768-wide "vector" embeddings are summed
                        # in place rather than concatenated; the width is
                        # hard-coded here -- confirm the intended embedders.
                        output[out_key] += emb
                    else:
                        output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key])
                else:
                    output[out_key] = emb
        return output

    def get_unconditional_conditioning(
            self,
            batch_c: dict,
            batch_uc: Optional[dict] = None,
            force_cond_zero_embeddings: Optional[list[str]] = None,
            force_uc_zero_embeddings: Optional[list[str]] = None
    ):
        """Compute conditional and unconditional conditioning without dropout.

        Every embedder's ``ucg_rate`` is set to ``0.0`` for the two forward
        passes and restored afterwards. The restore happens in a ``finally``
        block so that an exception in either pass cannot leave the embedders
        with dropout permanently disabled.

        Args:
            batch_c: batch for the conditional pass.
            batch_uc: batch for the unconditional pass; ``batch_c`` is reused
                when this is ``None``.
            force_cond_zero_embeddings: keys zeroed in the conditional pass.
            force_uc_zero_embeddings: keys zeroed in the unconditional pass.

        Returns:
            Tuple ``(c, uc)`` of the two ``forward`` outputs.
        """
        ucg_rates = [embedder.ucg_rate for embedder in self.embedders]
        for embedder in self.embedders:
            embedder.ucg_rate = 0.0
        try:
            c = self(batch_c, force_cond_zero_embeddings)
            uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
        finally:
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate
        return c, uc
class FrozenCLIPEmbedder(AbstractEmbModel):
    """Text embedder built on the Hugging Face CLIP text transformer.

    ``layer`` selects what ``forward`` returns: ``"last"`` for the last hidden
    state, ``"pooled"`` for the pooled output with an extra dimension inserted
    at index 1, or ``"hidden"`` for the hidden state at ``layer_idx``.
    """

    def __init__(
            self,
            # version="path_to/openai/clip-vit-large-patch14/pytorch_model.bin",
            version="openai/clip-vit-large-patch14",
            device="cuda",
            max_length=77,
            freeze=True,
            layer="last",
            layer_idx=None,
            always_return_pooled=False
    ):  # clip-vit-base-patch32
        super().__init__()
        assert layer in ["last", "pooled", "hidden"]
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.layer = layer
        self.layer_idx = layer_idx
        self.return_pooled = always_return_pooled
        if layer == "hidden":
            # A hidden-state index is mandatory in this mode and must lie
            # within [-12, 12].
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Put the transformer in eval mode and stop gradients on all parameters."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    @autocast
    def forward(self, text):
        """Tokenize ``text`` and return the embedding selected by ``self.layer``.

        When ``self.return_pooled`` is set, a tuple of the selected embedding
        and the pooled output is returned instead of the embedding alone.
        """
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt"
        )
        input_ids = encoded["input_ids"].to(self.device)
        outputs = self.transformer(
            input_ids=input_ids,
            output_hidden_states=self.layer == "hidden"
        )

        if self.layer == "pooled":
            z = outputs.pooler_output[:, None]
        elif self.layer == "last":
            z = outputs.last_hidden_state
        else:
            z = outputs.hidden_states[self.layer_idx]

        return (z, outputs.pooler_output) if self.return_pooled else z

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """Uses the OpenCLIP vision transformer encoder for images.

    The text transformer of the loaded OpenCLIP model is deleted; only the
    visual tower is used. Depending on the configuration, ``forward`` returns
    the image embedding alone or paired with a token/sequence tensor.
    """

    def __init__(
            self,
            arch="ViT-H-14",
            # version="path_to/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/open_clip_pytorch_model.bin",
            version="laion2b_s32b_b79k",
            device="cuda",
            max_length=77,
            freeze=True,
            antialias=True,
            ucg_rate=0.0,
            unsqueeze_dim=False,
            repeat_to_max_len=False,
            num_image_crops=0,
            output_tokens=False,
            init_device=None
    ):
        """Load the OpenCLIP model and store the output-shaping options.

        Args:
            arch: OpenCLIP architecture name.
            version: value passed as ``pretrained`` to OpenCLIP.
            device: stored as ``self.device``.
            max_length: sequence length used by the repeat/pad output modes.
            freeze: when true, ``freeze()`` is called after loading.
            antialias: forwarded to the resize in ``preprocess``.
            ucg_rate: drop probability used for embedding dropout.
            unsqueeze_dim: insert a dimension at index 1 of the embedding.
            repeat_to_max_len: repeat the embedding ``max_length`` times;
                ignored when ``num_image_crops > 0``.
            num_image_crops: number of crops per sample; a positive value
                enables the pad-to-max-length mode.
            output_tokens: also return the tokens from the visual tower.
            init_device: device the model is created on (``"cpu"`` if ``None``).
        """
        super().__init__()
        model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=torch.device(default(init_device, "cpu")),
            pretrained=version
        )
        del model.transformer
        self.model = model
        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.antialias = antialias
        # Per-channel normalization statistics used by ``preprocess``.
        self.register_buffer("mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer("std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        """Resize to 224x224, rescale and apply the CLIP normalization."""
        x = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias
        )
        # normalize to [0,1] (assumes the input is in [-1, 1] -- TODO confirm)
        x = (x + 1.0) / 2.0
        # renormalize according to clip
        x = kornia.enhance.normalize(x, self.mean, self.std)
        return x

    def freeze(self):
        """Put the model in eval mode and stop gradients on all parameters."""
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    @autocast
    def forward(self, image, no_dropout=False):
        """Embed ``image`` and shape the result according to the configuration.

        Args:
            image: input image tensor; a 5-D input is treated as crops by
                ``encode_with_vision_transformer``.
            no_dropout: skip the embedding dropout applied here.

        Returns:
            * ``(tokens, z)`` when ``output_tokens`` is set;
            * ``(z repeated to max_length along dim 1, z)`` when
              ``repeat_to_max_len`` is set;
            * ``(z zero-padded to max_length along dim 1, its first entry
              along dim 1)`` when ``pad_to_max_len`` is set;
            * ``z`` otherwise.
        """
        z = self.encode_with_vision_transformer(image)
        tokens = None
        if self.output_tokens:
            z, tokens = z[0], z[1]
        z = z.to(image.dtype)
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            # Per-sample Bernoulli keep-mask with keep probability 1 - ucg_rate.
            z = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
                )[:, None]
                * z
            )
            if tokens is not None:
                # Tokens get their own, independently sampled keep-mask.
                tokens = (
                    expand_dims_like(
                        torch.bernoulli(
                            (1.0 - self.ucg_rate) * torch.ones(tokens.shape[0], device=tokens.device)
                        ),
                        tokens
                    )
                    * tokens
                )
        if self.unsqueeze_dim:
            z = z[:, None]
        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z
        elif self.repeat_to_max_len:
            if z.dim() == 2:
                z_ = z[:, None]
            else:
                z_ = z
            return repeat(z_, "b 1 d -> b n d", n=self.max_length), z
        elif self.pad_to_max_len:
            assert z.dim() == 3
            # new_zeros inherits z's dtype and device, so the padded result
            # keeps the dtype z was cast to above instead of being promoted
            # to the default float type by torch.cat.
            z_pad = torch.cat(
                (
                    z,
                    z.new_zeros(z.shape[0], self.max_length - z.shape[1], z.shape[2])
                ),
                1
            )
            return z_pad, z_pad[:, 0, ...]
        else:
            return z

    def encode_with_vision_transformer(self, img):
        """Run the visual tower on ``img``.

        A 5-D input has its second dimension (which must equal
        ``max_crops``) folded into the batch before preprocessing. When
        ``max_crops > 0`` the embeddings are regrouped per sample and each
        crop embedding is independently dropped with probability
        ``ucg_rate``.

        Returns:
            ``(x, tokens)`` when ``output_tokens`` is set, otherwise ``x``.
        """
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)
        if self.output_tokens:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        else:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None
        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            x = (
                torch.bernoulli(
                    (1.0 - self.ucg_rate) * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
                )
                * x
            )
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message"
                )
        if self.output_tokens:
            return x, tokens
        else:
            return x

    def encode(self, text):
        """Alias for calling the module; ``text`` is passed through as the input."""
        return self(text)
class ConcatTimestepEmbedderND(AbstractEmbModel):
    """Embed every column of the input separately and join the results.

    Each scalar is passed through ``Timestep(outdim)``; the per-column
    embeddings of a row are concatenated along the last dimension.
    """

    def __init__(self, outdim, num_features=None, add_sequence_dim=False):
        super().__init__()
        self.timestep = Timestep(outdim)
        self.outdim = outdim
        self.num_features = num_features
        self.add_sequence_dim = add_sequence_dim

    def forward(self, x):
        """Return the concatenated embeddings for ``x``.

        A 1-D input is treated as a single column. When ``num_features`` is
        set, the number of columns must match it. With ``add_sequence_dim``
        an extra dimension is inserted at index 1 of the result.
        """
        if x.ndim == 1:
            x = x[:, None]
        assert len(x.shape) == 2
        batch_size, n_cols = x.shape
        assert self.num_features is None or n_cols == self.num_features

        # Flatten so every scalar is embedded on its own, then regroup per row.
        flat_emb = self.timestep(rearrange(x, "b d -> (b d)"))
        joined = rearrange(
            flat_emb, "(b d) d2 -> b (d d2)", b=batch_size, d=n_cols, d2=self.outdim
        )
        return joined[:, None] if self.add_sequence_dim else joined
class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
    """Encode conditioning frames, optionally after adding sampled noise.

    The encoder, the optional sigma sampler and the optional sigma
    conditioner are all built with ``instantiate_from_config``.
    """

    def __init__(
            self,
            n_cond_frames: int,
            n_copies: int,
            encoder_config: dict,
            sigma_sampler_config: Optional[dict] = None,
            sigma_cond_config: Optional[dict] = None,
            is_ae: bool = False,
            scale_factor: float = 1.0,
            disable_encoder_autocast: bool = False,
            en_and_decode_n_samples_a_time: Optional[int] = None
    ):
        super().__init__()
        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.encoder = instantiate_from_config(encoder_config)

        # Both helpers are optional and stay ``None`` without a config.
        if sigma_sampler_config is not None:
            self.sigma_sampler = instantiate_from_config(sigma_sampler_config)
        else:
            self.sigma_sampler = None
        if sigma_cond_config is not None:
            self.sigma_cond = instantiate_from_config(sigma_cond_config)
        else:
            self.sigma_cond = None

        self.is_ae = is_ae
        self.scale_factor = scale_factor
        self.disable_encoder_autocast = disable_encoder_autocast
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
        self.skip_encode = False

    def forward(
            self, vid: torch.Tensor
    ) -> Union[
        torch.Tensor,
        tuple[torch.Tensor, torch.Tensor],
        tuple[torch.Tensor, dict],
        tuple[tuple[torch.Tensor, torch.Tensor], dict]
    ]:
        """Encode ``vid`` and tile the result ``n_copies`` times.

        With ``skip_encode`` set the input is returned untouched. Otherwise
        noise scaled by sampled sigmas is added first (when a sigma sampler is
        configured), the encoder runs in chunks of at most
        ``en_and_decode_n_samples_a_time`` samples, the output is multiplied by
        ``scale_factor``, and the ``n_cond_frames`` frames of each sample are
        folded into the channel dimension. When ``sigma_cond`` is configured
        the result is returned together with the sigma conditioning.
        """
        if self.skip_encode:
            return vid

        if self.sigma_sampler is not None:
            batch_size = vid.shape[0] // self.n_cond_frames
            sigmas = self.sigma_sampler(batch_size).to(vid.device)
            if self.sigma_cond is not None:
                # NOTE(review): sigma_cond is only bound on this path; with a
                # sigma conditioner but no sampler the final return would hit
                # an unbound local -- confirm that combination is never used.
                sigma_cond = self.sigma_cond(sigmas)
                sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
            sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
            noise = torch.randn_like(vid)
            vid = vid + noise * append_dims(sigmas, vid.ndim)

        with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
            n_samples = default(self.en_and_decode_n_samples_a_time, vid.shape[0])
            n_rounds = math.ceil(vid.shape[0] / n_samples)
            encoded_chunks = []
            for round_idx in range(n_rounds):
                chunk = vid[round_idx * n_samples: (round_idx + 1) * n_samples]
                if self.is_ae:
                    encoded_chunks.append(self.encoder.encode(chunk))
                else:
                    encoded_chunks.append(self.encoder(chunk))

        vid = torch.cat(encoded_chunks, dim=0)
        vid *= self.scale_factor

        vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
        vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)

        if self.sigma_cond is None:
            return vid
        return vid, sigma_cond
class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
    """Wrap an OpenCLIP image embedder for multi-frame conditioning.

    The wrapped embedder is built from ``open_clip_embedding_config`` with
    ``instantiate_from_config``.
    """

    def __init__(self, open_clip_embedding_config: dict, n_cond_frames: int, n_copies: int):
        super().__init__()
        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.open_clip = instantiate_from_config(open_clip_embedding_config)

    def forward(self, vid):
        """Embed ``vid``, group frames per sample and tile ``n_copies`` times."""
        frame_emb = self.open_clip(vid)
        # Split the flat leading dimension into (sample, frame).
        per_sample = rearrange(frame_emb, "(b t) d -> b t d", t=self.n_cond_frames)
        return repeat(per_sample, "b t d -> (b s) t d", s=self.n_copies)