# Spaces:
# Build error
# Build error
from __future__ import annotations | |
import math | |
from contextlib import nullcontext | |
from typing import Optional, Union | |
import kornia | |
import numpy as np | |
import open_clip | |
import torch | |
import torch.nn as nn | |
from einops import rearrange, repeat | |
from omegaconf import ListConfig | |
from transformers import CLIPTextModel, CLIPTokenizer | |
from ...util import ( | |
append_dims, | |
autocast, | |
count_params, | |
default, | |
disabled_train, | |
expand_dims_like, | |
instantiate_from_config, | |
) | |
from ..diffusionmodules.openaimodel import Timestep | |
class AbstractEmbModel(nn.Module):
    """Base class for conditioning embedders.

    Stores three per-embedder settings in private attributes and exposes each
    one as a property with a getter, a setter and a deleter:

    * ``is_trainable`` -- backed by ``_is_trainable``
    * ``ucg_rate``     -- backed by ``_ucg_rate``
    * ``input_key``    -- backed by ``_input_key``

    All three start out as ``None``.
    """

    def __init__(self):
        super().__init__()
        self._is_trainable = None
        self._ucg_rate = None
        self._input_key = None

    # The accessor triplets below must be declared through the property
    # protocol.  As plain methods that share a name, each later definition
    # replaces the earlier one, so reading e.g. ``model.input_key`` would
    # yield a bound method instead of the stored value.

    @property
    def is_trainable(self) -> bool:
        """Return the stored ``_is_trainable`` value."""
        return self._is_trainable

    @property
    def ucg_rate(self) -> Union[float, torch.Tensor]:
        """Return the stored ``_ucg_rate`` value."""
        return self._ucg_rate

    @property
    def input_key(self) -> str:
        """Return the stored ``_input_key`` value."""
        return self._input_key

    @is_trainable.setter
    def is_trainable(self, value: bool):
        self._is_trainable = value

    @ucg_rate.setter
    def ucg_rate(self, value: Union[float, torch.Tensor]):
        self._ucg_rate = value

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @is_trainable.deleter
    def is_trainable(self):
        del self._is_trainable

    @ucg_rate.deleter
    def ucg_rate(self):
        del self._ucg_rate

    @input_key.deleter
    def input_key(self):
        del self._input_key
class GeneralConditioner(nn.Module):
    """Run a list of embedders over a batch and merge their outputs.

    Every embedder output tensor is filed under ``"vector"``, ``"crossattn"``
    or ``"concat"`` depending on its number of dimensions (``OUTPUT_DIM2KEYS``).
    Outputs that land on the same key are concatenated along the dimension
    given by ``KEY2CATDIM`` (with one special case, see ``forward``).
    """

    # tensor.dim() -> output dictionary key
    OUTPUT_DIM2KEYS = {2: "vector", 3: "crossattn", 4: "concat", 5: "concat"}
    # output dictionary key -> dimension used when concatenating
    KEY2CATDIM = {"vector": 1, "crossattn": 2, "concat": 1}

    def __init__(self, emb_models: Union[list, ListConfig]):
        """Instantiate and configure one embedder per entry of ``emb_models``.

        Each config entry is passed to ``instantiate_from_config`` and may carry
        ``is_trainable`` (default ``False``), ``ucg_rate`` (default ``0.0``),
        ``legacy_ucg_value`` (default ``None``) and must carry either
        ``input_key`` or ``input_keys``.

        Raises:
            AssertionError: if an instantiated embedder is not an
                ``AbstractEmbModel``.
            KeyError: if a config has neither ``input_key`` nor ``input_keys``.
        """
        super().__init__()
        embedders = list()
        for n, embconfig in enumerate(emb_models):
            embedder = instantiate_from_config(embconfig)
            assert isinstance(
                embedder, AbstractEmbModel
            ), f"Embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.is_trainable = embconfig.get("is_trainable", False)
            embedder.ucg_rate = embconfig.get("ucg_rate", 0.0)
            if not embedder.is_trainable:
                # Frozen embedder: swap out ``train`` for ``disabled_train``,
                # switch off gradients and put the module in eval mode.
                embedder.train = disabled_train
                for param in embedder.parameters():
                    param.requires_grad = False
                embedder.eval()
            print(
                f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )
            if "input_key" in embconfig:
                embedder.input_key = embconfig["input_key"]
            elif "input_keys" in embconfig:
                embedder.input_keys = embconfig["input_keys"]
            else:
                raise KeyError(f"Need either `input_key` or `input_keys` for embedder {embedder.__class__.__name__}")
            embedder.legacy_ucg_val = embconfig.get("legacy_ucg_value", None)
            if embedder.legacy_ucg_val is not None:
                # Only the legacy path needs its own NumPy random generator.
                embedder.ucg_prng = np.random.RandomState()
            embedders.append(embedder)
        self.embedders = nn.ModuleList(embedders)

    def possibly_get_ucg_val(self, embedder: AbstractEmbModel, batch: dict) -> dict:
        """Randomly overwrite entries of ``batch[embedder.input_key]`` in place.

        Each entry is replaced by ``embedder.legacy_ucg_val`` with probability
        ``embedder.ucg_rate``, drawn from ``embedder.ucg_prng``.  The (mutated)
        batch is returned.
        """
        assert embedder.legacy_ucg_val is not None
        p = embedder.ucg_rate
        val = embedder.legacy_ucg_val
        for i in range(len(batch[embedder.input_key])):
            if embedder.ucg_prng.choice(2, p=[1 - p, p]):
                batch[embedder.input_key][i] = val
        return batch

    def forward(self, batch: dict, force_zero_embeddings: Optional[list] = None) -> dict:
        """Embed ``batch`` with every embedder and collect the results.

        Args:
            batch: mapping from input keys to the embedders' inputs.
            force_zero_embeddings: input keys whose embeddings are replaced by
                zeros of the same shape.

        Returns:
            Dict with a subset of the keys ``"vector"``, ``"crossattn"`` and
            ``"concat"``.
        """
        output = dict()
        force_zero_embeddings = default(force_zero_embeddings, list())
        for embedder in self.embedders:
            # Frozen embedders run without gradient tracking.
            embedding_context = nullcontext if embedder.is_trainable else torch.no_grad
            with embedding_context():
                if hasattr(embedder, "input_key") and embedder.input_key is not None:
                    if embedder.legacy_ucg_val is not None:
                        batch = self.possibly_get_ucg_val(embedder, batch)
                    if embedder.input_key in batch:
                        # The input is embedded one sample at a time and the
                        # per-sample results are joined along dim 0.
                        # TODO this should be a parameter
                        inputs = batch[embedder.input_key]
                        emb_out = torch.concat(
                            [embedder(inputs[i].unsqueeze(0)) for i in range(inputs.shape[0])],
                            0,
                        )
                    elif embedder.add_sequence_dim:  # concatenation
                        # Key missing from the batch: substitute an all-zero
                        # embedding sized from batch["cond_aug"].
                        # NOTE(review): relies on the embedder exposing
                        # ``add_sequence_dim``, ``num_features`` and ``outdim``.
                        emb_dim = embedder.num_features * embedder.outdim
                        emb_out = torch.zeros(
                            (batch["cond_aug"].shape[0], 1, emb_dim),
                            device=batch["cond_aug"].device,
                        )
                    else:  # addition
                        continue
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(*[batch[k] for k in embedder.input_keys])
            assert isinstance(
                emb_out, (torch.Tensor, list, tuple)
            ), f"Encoder outputs must be tensors or a sequence, but got {type(emb_out)}"
            if not isinstance(emb_out, (list, tuple)):
                emb_out = [emb_out]
            for emb in emb_out:
                out_key = self.OUTPUT_DIM2KEYS[emb.dim()]
                if embedder.ucg_rate > 0.0 and embedder.legacy_ucg_val is None:
                    # Per-sample dropout: keep each sample with probability
                    # 1 - ucg_rate, zero it otherwise.
                    keep_mask = torch.bernoulli(
                        (1.0 - embedder.ucg_rate) * torch.ones(emb.shape[0], device=emb.device)
                    )
                    emb = expand_dims_like(keep_mask, emb) * emb
                if hasattr(embedder, "input_key") and embedder.input_key in force_zero_embeddings:
                    emb = torch.zeros_like(emb)
                if out_key in output:
                    if emb.shape[-1] == 768 and out_key == "vector":
                        # Special case: 768-wide vectors are summed (in place)
                        # into the existing entry instead of being concatenated.
                        output[out_key] += emb
                    else:
                        output[out_key] = torch.cat((output[out_key], emb), self.KEY2CATDIM[out_key])
                else:
                    output[out_key] = emb
        return output

    def get_unconditional_conditioning(
        self,
        batch_c: dict,
        batch_uc: Optional[dict] = None,
        force_cond_zero_embeddings: Optional[list[str]] = None,
        force_uc_zero_embeddings: Optional[list[str]] = None
    ):
        """Compute conditional and unconditional embeddings with dropout off.

        Every embedder's ``ucg_rate`` is set to ``0.0`` for the duration of the
        two forward passes and restored afterwards, including when a forward
        pass raises.

        Args:
            batch_c: batch for the conditional pass.
            batch_uc: batch for the unconditional pass; ``batch_c`` is reused
                when this is ``None``.
            force_cond_zero_embeddings: forwarded to the conditional pass.
            force_uc_zero_embeddings: forwarded to the unconditional pass.

        Returns:
            Tuple ``(c, uc)`` of the two output dicts.
        """
        ucg_rates = [embedder.ucg_rate for embedder in self.embedders]
        for embedder in self.embedders:
            embedder.ucg_rate = 0.0
        try:
            c = self(batch_c, force_cond_zero_embeddings)
            uc = self(batch_c if batch_uc is None else batch_uc, force_uc_zero_embeddings)
        finally:
            # Always put the configured rates back so a failed pass cannot
            # leave dropout disabled for later calls.
            for embedder, rate in zip(self.embedders, ucg_rates):
                embedder.ucg_rate = rate
        return c, uc
class FrozenCLIPEmbedder(AbstractEmbModel):
    """Text embedder built on the Hugging Face CLIP tokenizer and text model.

    ``version`` is handed to ``from_pretrained`` of both the tokenizer and the
    text model.  ``layer`` selects which output is returned by ``forward``:
    ``"last"`` (last hidden state), ``"pooled"`` (pooler output with an extra
    axis) or ``"hidden"`` (hidden state number ``layer_idx``).
    """

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
        layer="last",
        layer_idx=None,
        always_return_pooled=False
    ):
        super().__init__()
        assert layer in ("last", "pooled", "hidden")

        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()

        self.layer = layer
        self.layer_idx = layer_idx
        self.return_pooled = always_return_pooled
        if layer == "hidden":
            # A concrete hidden-state index in [-12, 12] is required.
            assert layer_idx is not None
            assert 0 <= abs(layer_idx) <= 12

    def freeze(self):
        """Put the transformer in eval mode and disable all gradients."""
        self.transformer = self.transformer.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and return the selected transformer output.

        Returns ``z`` alone, or ``(z, pooler_output)`` when the embedder was
        built with ``always_return_pooled=True``.
        """
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=True,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt"
        )
        input_ids = encoded["input_ids"].to(self.device)
        # Hidden states are only requested when they are going to be used.
        want_hidden_states = self.layer == "hidden"
        result = self.transformer(
            input_ids=input_ids,
            output_hidden_states=want_hidden_states
        )

        if self.layer == "last":
            z = result.last_hidden_state
        elif self.layer == "pooled":
            z = result.pooler_output[:, None]
        else:
            z = result.hidden_states[self.layer_idx]

        return (z, result.pooler_output) if self.return_pooled else z

    def encode(self, text):
        """Alias for calling the module on ``text``."""
        return self(text)
class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
    """Image embedder built on the visual tower of an OpenCLIP model.

    The model is created with ``open_clip.create_model_and_transforms`` and its
    ``transformer`` attribute is deleted right after creation; only
    ``model.visual`` is called by this class.
    """

    def __init__(
        self,
        arch="ViT-H-14",
        version="laion2b_s32b_b79k",
        device="cuda",
        max_length=77,
        freeze=True,
        antialias=True,
        ucg_rate=0.0,
        unsqueeze_dim=False,
        repeat_to_max_len=False,
        num_image_crops=0,
        output_tokens=False,
        init_device=None
    ):
        super().__init__()
        # ``version`` is forwarded as the ``pretrained`` argument; the model is
        # built on ``init_device`` when given, otherwise on the CPU.
        build_device = torch.device(default(init_device, "cpu"))
        clip_model, _, _ = open_clip.create_model_and_transforms(
            arch,
            device=build_device,
            pretrained=version
        )
        del clip_model.transformer
        self.model = clip_model

        self.max_crops = num_image_crops
        self.pad_to_max_len = self.max_crops > 0
        # Padding takes precedence over repeating.
        self.repeat_to_max_len = repeat_to_max_len and (not self.pad_to_max_len)
        self.device = device
        self.max_length = max_length
        if freeze:
            self.freeze()
        self.antialias = antialias

        # Per-channel statistics used by ``preprocess``; not saved in state dicts.
        self.register_buffer("mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer("std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

        self.ucg_rate = ucg_rate
        self.unsqueeze_dim = unsqueeze_dim
        self.stored_batch = None
        self.model.visual.output_tokens = output_tokens
        self.output_tokens = output_tokens

    def preprocess(self, x):
        """Resize to 224x224, shift by ``(x + 1) / 2`` and normalize."""
        resized = kornia.geometry.resize(
            x,
            (224, 224),
            interpolation="bicubic",
            align_corners=True,
            antialias=self.antialias
        )
        # presumably the input lives in [-1, 1] and is mapped to [0, 1] here
        shifted = (resized + 1.0) / 2.0
        # renormalize with the stored mean / std buffers
        return kornia.enhance.normalize(shifted, self.mean, self.std)

    def freeze(self):
        """Put the model in eval mode and disable all gradients."""
        self.model = self.model.eval()
        for weight in self.parameters():
            weight.requires_grad = False

    def forward(self, image, no_dropout=False):
        """Embed ``image`` and post-process according to the configured flags.

        Return value by configuration:

        * ``output_tokens``      -> ``(tokens, z)``
        * ``repeat_to_max_len``  -> ``(z repeated to max_length, z)``
        * ``pad_to_max_len``     -> ``(z zero-padded to max_length, first row)``
        * otherwise              -> ``z``
        """
        encoded = self.encode_with_vision_transformer(image)
        if self.output_tokens:
            z, tokens = encoded[0], encoded[1]
        else:
            z, tokens = encoded, None
        z = z.to(image.dtype)

        # Per-sample dropout is skipped when crops are in use (max_crops > 0).
        if self.ucg_rate > 0.0 and not no_dropout and not (self.max_crops > 0):
            z_keep = torch.bernoulli(
                (1.0 - self.ucg_rate) * torch.ones(z.shape[0], device=z.device)
            )
            z = z_keep[:, None] * z
            if tokens is not None:
                # The token mask is drawn separately from the mask used for z.
                token_keep = torch.bernoulli(
                    (1.0 - self.ucg_rate) * torch.ones(tokens.shape[0], device=tokens.device)
                )
                tokens = expand_dims_like(token_keep, tokens) * tokens

        if self.unsqueeze_dim:
            z = z[:, None]

        if self.output_tokens:
            assert not self.repeat_to_max_len
            assert not self.pad_to_max_len
            return tokens, z

        if self.repeat_to_max_len:
            z_seq = z[:, None] if z.dim() == 2 else z
            return repeat(z_seq, "b 1 d -> b n d", n=self.max_length), z

        if self.pad_to_max_len:
            assert z.dim() == 3
            filler = torch.zeros(
                z.shape[0], self.max_length - z.shape[1], z.shape[2], device=z.device
            )
            z_pad = torch.cat((z, filler), 1)
            return z_pad, z_pad[:, 0, ...]

        return z

    def encode_with_vision_transformer(self, img):
        """Run ``model.visual`` on the preprocessed image(s).

        A 5-D input is treated as ``(b, n, c, h, w)`` with ``n == max_crops``
        and flattened to ``(b * n, c, h, w)`` first.  Returns ``x`` or, when
        ``output_tokens`` is set, ``(x, tokens)``.
        """
        if img.dim() == 5:
            assert self.max_crops == img.shape[1]
            img = rearrange(img, "b n c h w -> (b n) c h w")
        img = self.preprocess(img)

        if self.output_tokens:
            assert self.model.visual.output_tokens
            x, tokens = self.model.visual(img)
        else:
            assert not self.model.visual.output_tokens
            x = self.model.visual(img)
            tokens = None

        if self.max_crops > 0:
            x = rearrange(x, "(b n) d -> b n d", n=self.max_crops)
            # drop out between 0 and all along the sequence axis
            crop_keep = torch.bernoulli(
                (1.0 - self.ucg_rate) * torch.ones(x.shape[0], x.shape[1], 1, device=x.device)
            )
            x = crop_keep * x
            if tokens is not None:
                tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
                print(
                    f"You are running very experimental token-concat in {self.__class__.__name__}. "
                    f"Check what you are doing, and then remove this message"
                )

        if self.output_tokens:
            return x, tokens
        return x

    def encode(self, text):
        """Alias for calling the module on its input."""
        return self(text)
class ConcatTimestepEmbedderND(AbstractEmbModel):
    """Embed every column of the input with ``Timestep`` and concatenate."""

    def __init__(self, outdim, num_features=None, add_sequence_dim=False):
        super().__init__()
        self.timestep = Timestep(outdim)
        self.outdim = outdim
        self.num_features = num_features
        self.add_sequence_dim = add_sequence_dim

    def forward(self, x):
        """Return a ``(b, dims * outdim)`` embedding for a ``(b, dims)`` input.

        A 1-D input is treated as ``(b, 1)``.  With ``add_sequence_dim`` the
        result gets an extra axis at position 1.
        """
        if x.ndim == 1:
            x = x[:, None]
        assert x.ndim == 2
        batch_size, n_columns = x.shape
        assert self.num_features is None or n_columns == self.num_features

        # Flatten so each scalar is embedded on its own, then regroup the
        # embeddings belonging to one batch row side by side.
        flat = rearrange(x, "b d -> (b d)")
        flat_emb = self.timestep(flat)
        emb = rearrange(
            flat_emb, "(b d) d2 -> b (d d2)", b=batch_size, d=n_columns, d2=self.outdim
        )
        return emb[:, None] if self.add_sequence_dim else emb
class VideoPredictionEmbedderWithEncoder(AbstractEmbModel):
    """Encode conditioning frames with a configurable encoder.

    Optionally adds sampled noise to the frames before encoding and, when a
    ``sigma_cond`` module is configured, also returns an embedding of the
    sampled noise levels.
    """

    def __init__(
        self,
        n_cond_frames: int,
        n_copies: int,
        encoder_config: dict,
        sigma_sampler_config: Optional[dict] = None,
        sigma_cond_config: Optional[dict] = None,
        is_ae: bool = False,
        scale_factor: float = 1.0,
        disable_encoder_autocast: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None
    ):
        """Build the encoder and the optional sigma sampler / sigma embedder.

        Args:
            n_cond_frames: number of frames grouped per sample.
            n_copies: how many times each encoded sample is repeated.
            encoder_config: config passed to ``instantiate_from_config``.
            sigma_sampler_config: optional config for the noise-level sampler.
            sigma_cond_config: optional config for the noise-level embedder.
            is_ae: call ``encoder.encode(...)`` instead of ``encoder(...)``.
            scale_factor: multiplier applied to the encoder output.
            disable_encoder_autocast: turn off autocast around the encoder.
            en_and_decode_n_samples_a_time: encode in chunks of this many
                frames; the whole input at once when ``None``.
        """
        super().__init__()
        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.encoder = instantiate_from_config(encoder_config)
        self.sigma_sampler = (
            instantiate_from_config(sigma_sampler_config)
            if sigma_sampler_config is not None
            else None
        )
        self.sigma_cond = (
            instantiate_from_config(sigma_cond_config)
            if sigma_cond_config is not None
            else None
        )
        self.is_ae = is_ae
        self.scale_factor = scale_factor
        self.disable_encoder_autocast = disable_encoder_autocast
        self.en_and_decode_n_samples_a_time = en_and_decode_n_samples_a_time
        self.skip_encode = False

    def forward(
        self, vid: torch.Tensor
    ) -> Union[
        torch.Tensor,
        tuple[torch.Tensor, torch.Tensor],
        tuple[torch.Tensor, dict],
        tuple[tuple[torch.Tensor, torch.Tensor], dict]
    ]:
        """Encode ``vid`` and return the result, optionally with ``sigma_cond``.

        When ``skip_encode`` is set the input is returned untouched.

        Returns:
            The encoded tensor, or ``(encoded, sigma_cond)`` when a
            ``sigma_cond`` module is configured.

        Raises:
            ValueError: if ``sigma_cond`` is configured without a
                ``sigma_sampler``; there would be no noise levels to embed.
        """
        if self.skip_encode:
            return vid

        if self.sigma_cond is not None and self.sigma_sampler is None:
            # Fail before the encode pass: ``sigma_cond`` is only computed
            # from sampled sigmas, so the tuple return below could never be
            # built without a sampler.
            raise ValueError(
                "sigma_cond is configured but sigma_sampler is not; "
                "a sigma sampler is required to compute the sigma conditioning"
            )

        sigma_cond = None
        if self.sigma_sampler is not None:
            # One noise level per group of n_cond_frames frames.
            bs = vid.shape[0] // self.n_cond_frames
            sigmas = self.sigma_sampler(bs).to(vid.device)
            if self.sigma_cond is not None:
                sigma_cond = self.sigma_cond(sigmas)
                sigma_cond = repeat(sigma_cond, "b d -> (b t) d", t=self.n_copies)
            sigmas = repeat(sigmas, "b -> (b t)", t=self.n_cond_frames)
            noise = torch.randn_like(vid)
            vid = vid + noise * append_dims(sigmas, vid.ndim)

        # NOTE(review): autocast is hard-wired to the "cuda" device type.
        with torch.autocast("cuda", enabled=not self.disable_encoder_autocast):
            n_samples = default(self.en_and_decode_n_samples_a_time, vid.shape[0])
            n_rounds = math.ceil(vid.shape[0] / n_samples)
            all_out = list()
            for n in range(n_rounds):
                chunk = vid[n * n_samples: (n + 1) * n_samples]
                if self.is_ae:
                    out = self.encoder.encode(chunk)
                else:
                    out = self.encoder(chunk)
                all_out.append(out)

        vid = torch.cat(all_out, dim=0)
        vid *= self.scale_factor
        # Fold the frame axis into channels, then repeat each sample n_copies
        # times along the batch axis.
        vid = rearrange(vid, "(b t) c h w -> b () (t c) h w", t=self.n_cond_frames)
        vid = repeat(vid, "b 1 c h w -> (b t) c h w", t=self.n_copies)

        if self.sigma_cond is not None:
            return vid, sigma_cond
        return vid
class FrozenOpenCLIPImagePredictionEmbedder(AbstractEmbModel):
    """Wrap an embedder built from config and reshape its per-frame output."""

    def __init__(self, open_clip_embedding_config: dict, n_cond_frames: int, n_copies: int):
        super().__init__()
        self.n_cond_frames = n_cond_frames
        self.n_copies = n_copies
        self.open_clip = instantiate_from_config(open_clip_embedding_config)

    def forward(self, vid):
        """Embed ``vid``, group by ``n_cond_frames`` and repeat ``n_copies`` times."""
        frame_emb = self.open_clip(vid)
        # (b * t, d) -> (b, t, d), then every sample is repeated s times
        # along the batch axis.
        grouped = rearrange(frame_emb, "(b t) d -> b t d", t=self.n_cond_frames)
        return repeat(grouped, "b t d -> (b s) t d", s=self.n_copies)