from math import sqrt, log
from omegaconf import OmegaConf
import importlib

import torch
from torch import nn
import torch.nn.functional as F

from einops import rearrange

# helper methods


def load_model(path):
    # load a serialized checkpoint onto the CPU so it can be used without a GPU
    with open(path, "rb") as f:
        return torch.load(f, map_location=torch.device("cpu"))


def map_pixels(x, eps=0.1):
    # squeeze pixel values from [0, 1] into [eps, 1 - eps]
    return (1 - 2 * eps) * x + eps


def unmap_pixels(x, eps=0.1):
    # invert map_pixels, clamping the result back to [0, 1]
    return torch.clamp((x - eps) / (1 - 2 * eps), 0, 1)


def make_contiguous(module):
    # rewrite every parameter in place so its tensor is contiguous in memory
    with torch.no_grad():
        for param in module.parameters():
            param.set_(param.contiguous())


# VQGAN from Taming Transformers paper
# https://arxiv.org/abs/2012.09841


def get_obj_from_str(string, reload=False):
    # resolve a dotted "package.module.ClassName" string to the object it names
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    # build the class named by config["target"], passing config["params"] as kwargs
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))
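

# Illustrative sketch of the config structure `instantiate_from_config` (and the
# VQGanVAE wrapper below) expects; the keys mirror taming-transformers YAML files
# and the values are placeholders, not defaults:
#
#   model:
#     target: taming.models.vqgan.VQModel
#     params:
#       n_embed: 1024
#       ddconfig:
#         resolution: 256
#         attn_resolutions: [16]
#         in_channels: 3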


class VQGanVAE(nn.Module):
    def __init__(self, vqgan_model_path=None, vqgan_config_path=None, channels=1):
        super().__init__()

        assert vqgan_config_path is not None

        model_path = vqgan_model_path
        config_path = vqgan_config_path

        config = OmegaConf.load(config_path)

        model = instantiate_from_config(config["model"])

        if vqgan_model_path:
            state = torch.load(model_path, map_location="cpu")["state_dict"]
            model.load_state_dict(state, strict=True)
            print(f"Loaded VQGAN from {model_path} and {config_path}")
        else:
            print(f"Initialized VQGAN from config {config_path} (no checkpoint weights loaded)")

        self.model = model

        # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
        f = (
            config.model.params.ddconfig.resolution
            / config.model.params.ddconfig.attn_resolutions[0]
        )
        self.num_layers = int(log(f) / log(2))  # number of 2x downsampling steps implied by f
        self.image_size = config.model.params.ddconfig.resolution
        self.num_tokens = config.model.params.n_embed
        # self.is_gumbel = isinstance(self.model, GumbelVQ)
        self.is_gumbel = False
        self.channels = config.model.params.ddconfig.in_channels

    def encode(self, img):
        return self.model.encode(img)

    def get_codebook_indices(self, img):
        # encode a batch of images into flattened discrete codebook indices of shape (b, n)
        b = img.shape[0]
        # img = (2 * img) - 1
        _, _, [_, _, indices] = self.encode(img)
        if self.is_gumbel:
            return rearrange(indices, "b h w -> b (h w)", b=b)
        return rearrange(indices, "(b n) -> b n", b=b)

    def decode(self, img_seq):
        # look up codebook embeddings for each token, fold the flat sequence back
        # into a square spatial grid, and run the VQGAN decoder
        b, n = img_seq.shape
        one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).float()
        z = (
            one_hot_indices @ self.model.quantize.embed.weight
            if self.is_gumbel
            else (one_hot_indices @ self.model.quantize.embedding.weight)
        )

        z = rearrange(z, "b (h w) c -> b c h w", h=int(sqrt(n)))
        img = self.model.decode(z)

        # img = (img.clamp(-1.0, 1.0) + 1) * 0.5
        return img

    def forward(self, img, optimizer_idx=1):
        # delegate to the wrapped model's GAN training step; in taming-transformers,
        # optimizer_idx selects the loss branch (0 = autoencoder, 1 = discriminator)
        return self.model.training_step(img, optimizer_idx=optimizer_idx)
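

# Minimal usage sketch (not part of the original module). The config path and
# shapes below are placeholders; it assumes a taming-transformers style YAML
# whose quantizer exposes `embedding.weight`, as `decode` above expects.
if __name__ == "__main__":
    vae = VQGanVAE(
        vqgan_model_path=None,                  # optional checkpoint (placeholder: none loaded)
        vqgan_config_path="vqgan_config.yaml",  # placeholder path to the model YAML
    )

    # random tensor standing in for a preprocessed image batch
    dummy = torch.randn(1, vae.channels, vae.image_size, vae.image_size)

    indices = vae.get_codebook_indices(dummy)  # (b, n) discrete token ids
    recon = vae.decode(indices)                # (b, c, h, w) reconstruction
    print(indices.shape, recon.shape)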