# Uploaded by EmaadKhwaja ("file upload", commit 5d2263b, 3.33 kB).
import importlib
from math import log, log2, sqrt

import torch
import torch.nn.functional as F
from einops import rearrange
from omegaconf import OmegaConf
from torch import nn
# Helper functions
def load_model(path):
    """Deserialize a ``torch.save`` file from *path*, mapping all storages to the CPU.

    The file is opened in binary mode and handed to ``torch.load`` with a CPU
    ``map_location``.

    NOTE(review): ``torch.load`` unpickles its input; only use on trusted files.
    """
    cpu_device = torch.device("cpu")
    with open(path, "rb") as handle:
        loaded = torch.load(handle, map_location=cpu_device)
    return loaded
def map_pixels(x, eps=0.1):
    """Affinely squeeze values from [0, 1] into [eps, 1 - eps].

    Computes ``(1 - 2 * eps) * x + eps``; works for Python numbers and tensors.
    """
    scale = 1 - 2 * eps
    return scale * x + eps
def unmap_pixels(x, eps=0.1):
    """Undo the ``(1 - 2 * eps) * x + eps`` squeeze and clamp the result to [0, 1].

    *x* must be a tensor, since the result is passed to ``torch.clamp``.
    """
    span = 1 - 2 * eps
    restored = (x - eps) / span
    return torch.clamp(restored, 0, 1)
def make_contiguous(module):
    """Rebind every parameter of *module*, in place, to a contiguous copy of its data.

    The ``set_`` calls run under ``torch.no_grad()`` because autograd rejects
    in-place operations on leaf tensors that require grad.
    """
    with torch.no_grad():
        for tensor in module.parameters():
            dense = tensor.contiguous()
            tensor.set_(dense)
# VQGAN from Taming Transformers paper
# https://arxiv.org/abs/2012.09841
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to the named attribute.

    Args:
        string: fully qualified path; the part after the last ``.`` is the
            attribute name, everything before it is the module to import.
        reload: if true, import the module and ``importlib.reload`` it before
            the attribute lookup.

    Returns:
        The attribute object named by the final path segment.

    Raises:
        ValueError: if *string* contains no ``.``.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    target_module = importlib.import_module(module_name, package=None)
    return getattr(target_module, attr_name)
def instantiate_from_config(config):
    """Build an object described by a ``{"target": ..., "params": {...}}`` mapping.

    Args:
        config: mapping whose ``"target"`` is a dotted path resolved with
            ``get_obj_from_str``; the optional ``"params"`` mapping is passed to
            the resolved callable as keyword arguments (empty when absent).

    Returns:
        Whatever the resolved callable returns.

    Raises:
        KeyError: if *config* has no ``"target"`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)
class VQGanVAE(nn.Module):
    """Wrapper around a Taming-Transformers style VQGAN built from an OmegaConf config.

    Exposes a token-level interface: ``get_codebook_indices`` (image -> token ids)
    and ``decode`` (token ids -> image), plus ``forward`` which delegates to the
    wrapped model's ``training_step``.

    Args:
        vqgan_model_path: optional checkpoint path; when truthy, its ``"state_dict"``
            entry is loaded into the model with ``strict=True``.
        vqgan_config_path: required path to an OmegaConf file whose ``model`` section
            is passed to ``instantiate_from_config``.
        channels: accepted for interface compatibility but not used;
            ``self.channels`` is read from ``config.model.params.ddconfig.in_channels``.

    Raises:
        ValueError: if ``vqgan_config_path`` is None.
    """

    def __init__(self, vqgan_model_path=None, vqgan_config_path=None, channels=1):
        super().__init__()
        # Explicit check rather than ``assert`` so it still fires under ``python -O``.
        if vqgan_config_path is None:
            raise ValueError("vqgan_config_path is required to build VQGanVAE")
        model_path = vqgan_model_path
        config_path = vqgan_config_path

        config = OmegaConf.load(config_path)
        model = instantiate_from_config(config["model"])

        if model_path:
            # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
            state = torch.load(model_path, map_location="cpu")["state_dict"]
            model.load_state_dict(state, strict=True)
            # Only report a load when weights were actually read from disk.
            print(f"Loaded VQGAN from {model_path} and {config_path}")

        self.model = model

        ddconfig = config.model.params.ddconfig
        # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
        f = ddconfig.resolution / ddconfig.attn_resolutions[0]
        # log2 is exact for powers of two, whereas log(f) / log(2) can land just
        # below the integer and be truncated one step too low by int().
        self.num_layers = int(log2(f))
        self.image_size = ddconfig.resolution
        self.num_tokens = config.model.params.n_embed
        # Hard-coded: the isinstance check against GumbelVQ is disabled in this file,
        # so decode/get_codebook_indices always take the non-Gumbel path.
        self.is_gumbel = False
        self.channels = ddconfig.in_channels

    def encode(self, img):
        """Return the wrapped model's ``encode(img)`` result unchanged."""
        return self.model.encode(img)

    def get_codebook_indices(self, img):
        """Return the codebook index of each latent position, shaped ``(batch, tokens)``.

        Unpacks ``self.encode(img)`` as ``(_, _, (_, _, indices))``; in the
        non-Gumbel case ``indices`` is expected flat over ``batch * tokens``.
        No value rescaling is applied to *img* here; do it upstream if needed.
        """
        b = img.shape[0]
        _, _, [_, _, indices] = self.encode(img)
        if self.is_gumbel:
            return rearrange(indices, "b h w -> b (h w)", b=b)
        return rearrange(indices, "(b n) -> b n", b=b)

    def decode(self, img_seq):
        """Map token ids of shape ``(batch, n)`` back through the VQGAN decoder.

        Tokens are laid out on an ``h x (n // h)`` grid with ``h = int(sqrt(n))``,
        so ``n`` is expected to be a perfect square. The decoder output is
        returned as-is (no clamping or rescaling).
        """
        _, n = img_seq.shape
        one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).float()
        # One-hot @ codebook selects the embedding vector for each token id.
        codebook = (
            self.model.quantize.embed.weight
            if self.is_gumbel
            else self.model.quantize.embedding.weight
        )
        z = one_hot_indices @ codebook
        z = rearrange(z, "b (h w) c -> b c h w", h=int(sqrt(n)))
        return self.model.decode(z)

    def forward(self, img, optimizer_idx=1):
        """Delegate to the wrapped model's ``training_step`` with *optimizer_idx*."""
        # NOTE(review): upstream taming VQModel.training_step also takes a batch_idx;
        # confirm the configured target accepts this (img, optimizer_idx=...) call.
        return self.model.training_step(img, optimizer_idx=optimizer_idx)