import os

import cv2
import numpy as np
import torch
import torchvision.transforms.v2.functional as transforms_f
from diffusers import AutoencoderKLTemporalDecoder
from einops import rearrange
from transformers import T5Tokenizer, T5Model

from magvit2.config import VQConfig
from magvit2.models.lfqgan import VQModel

vision_model = None
global_language_model = None
t5_tok = None


def get_image_encoder(encoder_type: str, encoder_name_or_path: str):
    encoder_type = encoder_type.lower()
    if encoder_type == "magvit":
        return VQModel(VQConfig(), ckpt_path=encoder_name_or_path)
    elif encoder_type == "temporalvae":
        return AutoencoderKLTemporalDecoder.from_pretrained(encoder_name_or_path, subfolder="vae")
    else:
        raise NotImplementedError(f"{encoder_type=}")


def set_seed(seed):
    # set seed for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)


def mkdir_if_missing(dst_dir):
    """make destination folder if it's missing"""
    os.makedirs(dst_dir, exist_ok=True)


def resize_image(image, resize=True):
    MAX_RES = 1024

    # convert to array
    image = np.asarray(image)
    h, w = image.shape[:2]
    if h > MAX_RES or w > MAX_RES:
        # downscale so the longer side becomes MAX_RES, preserving aspect ratio
        if h < w:
            new_h, new_w = int(MAX_RES * h / w), MAX_RES
        else:
            new_h, new_w = MAX_RES, int(MAX_RES * w / h)
        image = cv2.resize(image, (new_w, new_h))

    if resize:
        # resize the shorter side to 256 and then do a center crop
        h, w = image.shape[:2]
        if h < w:
            new_h, new_w = 256, int(256 * w / h)
        else:
            new_h, new_w = int(256 * h / w), 256
        image = cv2.resize(image, (new_w, new_h))

        h, w = image.shape[:2]
        crop_h, crop_w = 256, 256
        start_h = (h - crop_h) // 2
        start_w = (w - crop_w) // 2
        image = image[start_h:start_h + crop_h, start_w:start_w + crop_w]
    return image

def normalize_image(image, resize=True):
    """
    H x W x 3(uint8) -> imagenet normalized 3 x H x W

    Normalizes image to [-1, 1].
    Resizes the image if resize=True or if the image resolution > MAX_RES
    """
    image = resize_image(image, resize=resize)
    # normalize between -1 and 1
    image = image / 255.0
    image = (image * 2 - 1.)
    return torch.from_numpy(image.transpose(2, 0, 1))


def unnormalize_image(magvit_output):
    """
    [-1, 1] -> [0, 255]

    Important: clip to [0, 255]
    """
    rescaled_output = ((magvit_output.detach().cpu() + 1) * 127.5)
    clipped_output = torch.clamp(rescaled_output, 0, 255).to(dtype=torch.uint8)
    return clipped_output

@torch.inference_mode()
@torch.no_grad()
def get_quantized_image_embeddings(
    image,
    encoder_type,
    encoder_name_or_path,
    keep_res=False,
    device="cuda",
):
    """
    image: (h, w, 3)
    """
    global vision_model
    DEBUG = False
    dtype = torch.bfloat16

    if vision_model is None:
        vision_model = get_image_encoder(encoder_type=encoder_type, encoder_name_or_path=encoder_name_or_path)
        vision_model = vision_model.to(device=device, dtype=dtype)
        vision_model.eval()

    batch = normalize_image(image, resize=not keep_res)[None]
    if not keep_res:
        img_h, img_w = 256, 256
    else:
        img_h, img_w = batch.shape[2:]

    h, w = img_h // 16, img_w // 16
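    # the // 16 assumes the MAGVIT tokenizer's 16x spatial downsampling: tokens form an (img_h / 16) x (img_w / 16) grid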

    with vision_model.ema_scope():
        quant_, _, indices, _ = vision_model.encode(batch.to(device=device, dtype=dtype), flip=True)
    indices = rearrange(indices, "(h w) -> h w", h=h, w=w)

    # alternative way to get indices
    # indices_ = vision_model.quantize.bits_to_indices(quant_.permute(0, 2, 3, 1) > 0).cpu().numpy()
    # indices_ = rearrange(indices_, "(h w) -> h w", h=h, w=w)

    if DEBUG:
        # sanity check: decode and then visualize
        with vision_model.ema_scope():
            indices = indices[None]
            # bit representations
            quant = vision_model.quantize.get_codebook_entry(
                rearrange(indices, "b h w -> b (h w)"),
                bhwc=indices.shape + (vision_model.quantize.codebook_dim,),
            ).flip(1)  # unclear why flip(1) is needed for the codebook bits
            decoded_img = unnormalize_image(vision_model.decode(quant.to(device=device, dtype=dtype)))
            transforms_f.to_pil_image(decoded_img[0]).save("decoded.png")
            transforms_f.to_pil_image(image).save("original.png") # show()

    # 18 x h x w bit codes in {-1., 1.} -> h x w grid of uint32 token indices
    indices = indices.type(torch.int32)
    indices = indices.detach().cpu().numpy().astype(np.uint32)
    return indices


@torch.inference_mode()
@torch.no_grad()
def get_vae_image_embeddings(
    image,
    encoder_type,
    encoder_name_or_path,
    keep_res: bool = False,
    device="cuda",
):
    """
    image: (h, w, 3), in [-1, 1]
    use SD VAE to encode and decode the images.
    """
    global vision_model
    DEBUG = False
    dtype = torch.bfloat16

    if vision_model is None:
        vision_model = get_image_encoder(encoder_type, encoder_name_or_path)
        vision_model = vision_model.to(device=device, dtype=dtype)
        vision_model.eval()

    # https://github.com/bytedance/IRASim/blob/main/sample/sample_autoregressive.py#L151
    # if args.use_temporal_decoder:
    #     vae = AutoencoderKLTemporalDecoder.from_pretrained(args.vae_model_path, subfolder="t2v_required_models/vae_temporal_decoder").to(device)
    # else:
    #     vae = AutoencoderKL.from_pretrained(args.vae_model_path, subfolder="vae").to(device)
    #  x = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor) ?

    batch = normalize_image(image, resize=not keep_res)[None]

    if isinstance(vision_model, AutoencoderKLTemporalDecoder):
        # SVD's VAE expects images in [-1, 1], so no extra rescaling should be needed here
        # https://github.com/Stability-AI/generative-models/blob/1659a1c09b0953ad9cc0d480f42e4526c5575b37/scripts/demo/video_sampling.py#L182
        # https://github.com/Stability-AI/generative-models/blob/1659a1c09b0953ad9cc0d480f42e4526c5575b37/scripts/demo/streamlit_helpers.py#L894
        z = vision_model.encode(batch.to(device=device, dtype=dtype)).latent_dist.mean
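        # note: uses the posterior mean (no sampling) and does not apply vae.config.scaling_factor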
    elif isinstance(vision_model, VQModel):  # vision_model should be VQModel
        # with vision_model.ema_scope():  # doesn't matter due to bugged VQModel ckpt_path arg
        z = vision_model.encode_without_quantize(batch.to(device=device, dtype=dtype))
    else:
        raise NotImplementedError(f"{vision_model=}")

    if DEBUG:
        decoded_img = unnormalize_image(vision_model.decode(z.to(device=device, dtype=dtype)))
        transforms_f.to_pil_image(decoded_img[0]).save("decoded_unquant.png")
        transforms_f.to_pil_image(image).save("original.png")

    return z[0].detach().cpu().float().numpy().astype(np.float16)

    # switch to VAE in SD
    # https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae
    # https://github.com/bytedance/IRASim/blob/main/sample/sample_autoregressive.py#L151
    # from diffusers.models import AutoencoderKL,AutoencoderKLTemporalDecoder
    # vae_model_path = 'pretrained_models/stabilityai/stable-diffusion-xl-base-1.0'
    # if args.use_temporal_decoder:
    #     vae = AutoencoderKLTemporalDecoder.from_pretrained(vae_model_path, subfolder="t2v_required_models/vae_temporal_decoder").to(device)
    # else:
    #     vae = AutoencoderKL.from_pretrained(vae_model_path, subfolder="vae").to(device)
    #  z = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor)
    # if DEBUG:
    #     decoded_img = unnormalize_image(vae.decode(z.to(device=device, dtype=dtype) / vae.config.scaling_factor))
    #     transforms_f.to_pil_image(decoded_img[0]).save("decoded_unquant.png")
    #     transforms_f.to_pil_image(image).save("original.png")


@torch.no_grad()
def get_t5_embeddings(language, per_token=True, max_length=16, device="cpu"):
    """Get T5 embedding"""
    global global_language_model, t5_tok
    if global_language_model is None:
        try:
            t5_model = T5Model.from_pretrained("t5-base")
            t5_tok = T5Tokenizer.from_pretrained("t5-base")
        except Exception:
            # fall back to a locally cached copy when the hub is unreachable
            t5_model = T5Model.from_pretrained("t5-base", local_files_only=True)
            t5_tok = T5Tokenizer.from_pretrained("t5-base", local_files_only=True)
        t5_model = t5_model.to(device)
        global_language_model = t5_model
        global_language_model.eval()

    # forward pass through encoder only
    enc = t5_tok(
        [language],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=max_length,
    ).to(device)

    output = global_language_model.encoder(
        input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], return_dict=True
    )
    torch.cuda.empty_cache()
    if per_token:
        return output.last_hidden_state[0].detach().cpu().numpy()
    else:
        # get the final hidden states. average across tokens.
        emb = output.last_hidden_state[0].mean(dim=0).detach().cpu().numpy()
        return emb
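

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original pipeline): encode a dummy image with the
    # temporal VAE and embed a text prompt with T5. The HF repo id below is a placeholder
    # assumption; point encoder_name_or_path at whichever SVD-style VAE or MAGVIT checkpoint
    # you actually use.
    set_seed(0)
    dummy_image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)

    latents = get_vae_image_embeddings(
        dummy_image,
        encoder_type="temporalvae",
        encoder_name_or_path="stabilityai/stable-video-diffusion-img2vid",  # placeholder checkpoint
    )
    print("VAE latent shape:", latents.shape)  # e.g. (4, 32, 32) for the 256 x 256 center crop

    text_emb = get_t5_embeddings("pick up the red block", per_token=False)
    print("T5 embedding shape:", text_emb.shape)  # (768,) for t5-base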