# hma/datasets/utils.py
import os
import cv2
import numpy as np
import torch
import torchvision.transforms.v2.functional as transforms_f
from diffusers import AutoencoderKLTemporalDecoder
from einops import rearrange
from transformers import T5Tokenizer, T5Model
from magvit2.config import VQConfig
from magvit2.models.lfqgan import VQModel
# module-level caches: the first encoder / tokenizer that gets created is reused
# by subsequent calls
vision_model = None
global_language_model = None
t5_tok = None
def get_image_encoder(encoder_type: str, encoder_name_or_path: str):
encoder_type = encoder_type.lower()
if encoder_type == "magvit":
return VQModel(VQConfig(), ckpt_path=encoder_name_or_path)
elif encoder_type == "temporalvae":
return AutoencoderKLTemporalDecoder.from_pretrained(encoder_name_or_path, subfolder="vae")
else:
raise NotImplementedError(f"{encoder_type=}")
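
# Hedged usage sketch (the checkpoint path and model id below are placeholders):
# enc = get_image_encoder("magvit", "/path/to/magvit2.ckpt")
# enc = get_image_encoder("temporalvae", "stabilityai/stable-video-diffusion-img2vid")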
def set_seed(seed):
# set seed for reproducibility
torch.manual_seed(seed)
np.random.seed(seed)
def mkdir_if_missing(dst_dir):
"""make destination folder if it's missing"""
if not os.path.exists(dst_dir):
os.makedirs(dst_dir)
def resize_image(image, resize=True):
MAX_RES = 1024
# convert to array
image = np.asarray(image)
h, w = image.shape[:2]
if h > MAX_RES or w > MAX_RES:
        if h < w:
            # w is the longer side: cap it at MAX_RES and scale h to preserve the aspect ratio
            new_h, new_w = int(MAX_RES * h / w), MAX_RES
        else:
            new_h, new_w = MAX_RES, int(MAX_RES * w / h)
image = cv2.resize(image, (new_w, new_h))
if resize:
# resize the shorter side to 256 and then do a center crop
h, w = image.shape[:2]
if h < w:
new_h, new_w = 256, int(256 * w / h)
else:
new_h, new_w = int(256 * h / w), 256
image = cv2.resize(image, (new_w, new_h))
h, w = image.shape[:2]
crop_h, crop_w = 256, 256
start_h = (h - crop_h) // 2
start_w = (w - crop_w) // 2
image = image[start_h:start_h + crop_h, start_w:start_w + crop_w]
return image
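
# Worked example (hedged): a 480 x 640 uint8 frame is already below MAX_RES, so with
# resize=True its shorter side is scaled to 256 (-> 256 x 341) and the result is
# center-cropped to 256 x 256; with resize=False it is returned unchanged.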
def normalize_image(image, resize=True):
"""
    H x W x 3 (uint8) -> 3 x H x W float tensor normalized to [-1, 1].
    Resizes to a 256 x 256 center crop if resize=True; the longer side is capped
    at MAX_RES either way (see resize_image).
"""
image = resize_image(image, resize=resize)
# normalize between -1 and 1
image = image / 255.0
image = (image * 2 - 1.)
return torch.from_numpy(image.transpose(2, 0, 1))
def unnormalize_image(magvit_output):
"""
[-1, 1] -> [0, 255]
Important: clip to [0, 255]
"""
rescaled_output = ((magvit_output.detach().cpu() + 1) * 127.5)
clipped_output = torch.clamp(rescaled_output, 0, 255).to(dtype=torch.uint8)
return clipped_output
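
# Hedged sanity-check sketch (helper name and logic are not part of the original API):
def _example_normalize_round_trip(image):
    """Normalize an H x W x 3 uint8 array to [-1, 1] and map it back to uint8.

    Useful for verifying that normalize_image / unnormalize_image are inverses
    up to the resize and center crop.
    """
    normalized = normalize_image(image)             # 3 x 256 x 256 float tensor in [-1, 1]
    restored = unnormalize_image(normalized[None])  # 1 x 3 x 256 x 256 uint8 in [0, 255]
    return restored[0]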
@torch.inference_mode()
@torch.no_grad()
def get_quantized_image_embeddings(
image,
encoder_type,
encoder_name_or_path,
keep_res=False,
device="cuda",
):
"""
    image: (h, w, 3) uint8 array; encoded into a grid of discrete codebook indices
"""
global vision_model
DEBUG = False
dtype = torch.bfloat16
if vision_model is None:
vision_model = get_image_encoder(encoder_type=encoder_type, encoder_name_or_path=encoder_name_or_path)
vision_model = vision_model.to(device=device, dtype=dtype)
vision_model.eval()
batch = normalize_image(image, resize=not keep_res)[None]
if not keep_res:
img_h, img_w = 256, 256
else:
img_h, img_w = batch.shape[2:]
h, w = img_h // 16, img_w // 16
with vision_model.ema_scope():
quant_, _, indices, _ = vision_model.encode(batch.to(device=device, dtype=dtype), flip=True)
indices = rearrange(indices, "(h w) -> h w", h=h, w=w)
# alternative way to get indices
# indices_ = vision_model.quantize.bits_to_indices(quant_.permute(0, 2, 3, 1) > 0).cpu().numpy()
# indices_ = rearrange(indices_, "(h w) -> h w", h=h, w=w)
if DEBUG:
# sanity check: decode and then visualize
with vision_model.ema_scope():
indices = indices[None]
# bit representations
quant = vision_model.quantize.get_codebook_entry(rearrange(indices, "b h w -> b (h w)"),
bhwc=indices.shape + (vision_model.quantize.codebook_dim,)).flip(1)
## why is there a flip(1) needed for the codebook bits?
decoded_img = unnormalize_image(vision_model.decode(quant.to(device=device, dtype=dtype)))
transforms_f.to_pil_image(decoded_img[0]).save("decoded.png")
transforms_f.to_pil_image(image).save("original.png") # show()
    # 18 x 16 x 16 bit codes in [-1., 1.] -> 16 x 16 uint32 token indices
indices = indices.type(torch.int32)
indices = indices.detach().cpu().numpy().astype(np.uint32)
return indices
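
# Hedged usage sketch (the checkpoint path is a placeholder):
# indices = get_quantized_image_embeddings(
#     image, encoder_type="magvit", encoder_name_or_path="/path/to/magvit2.ckpt"
# )
# For a 256 x 256 input this yields a 16 x 16 uint32 grid (one code per 16 x 16 patch).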
@torch.inference_mode()
@torch.no_grad()
def get_vae_image_embeddings(
image,
encoder_type,
encoder_name_or_path,
keep_res: bool = False,
device="cuda",
):
"""
    image: (h, w, 3) uint8 array; normalized to [-1, 1] internally.
    Uses an SD/SVD-style VAE (or the MAGVIT encoder without quantization) to
    encode the image into continuous latents.
"""
global vision_model
DEBUG = False
dtype = torch.bfloat16
if vision_model is None:
vision_model = get_image_encoder(encoder_type, encoder_name_or_path)
vision_model = vision_model.to(device=device, dtype=dtype)
vision_model.eval()
# https://github.com/bytedance/IRASim/blob/main/sample/sample_autoregressive.py#L151
# if args.use_temporal_decoder:
# vae = AutoencoderKLTemporalDecoder.from_pretrained(args.vae_model_path, subfolder="t2v_required_models/vae_temporal_decoder").to(device)
# else:
# vae = AutoencoderKL.from_pretrained(args.vae_model_path, subfolder="vae").to(device)
# x = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor) ?
batch = normalize_image(image, resize=not keep_res)[None]
if isinstance(vision_model, AutoencoderKLTemporalDecoder):
        # SVD's VAE expects inputs in [-1, 1], which normalize_image already provides,
        # so no additional rescaling is needed here.
# https://github.com/Stability-AI/generative-models/blob/1659a1c09b0953ad9cc0d480f42e4526c5575b37/scripts/demo/video_sampling.py#L182
# https://github.com/Stability-AI/generative-models/blob/1659a1c09b0953ad9cc0d480f42e4526c5575b37/scripts/demo/streamlit_helpers.py#L894
z = vision_model.encode(batch.to(device=device, dtype=dtype)).latent_dist.mean
elif isinstance(vision_model, VQModel): # vision_model should be VQModel
# with vision_model.ema_scope(): # doesn't matter due to bugged VQModel ckpt_path arg
z = vision_model.encode_without_quantize(batch.to(device=device, dtype=dtype))
else:
raise NotImplementedError(f"{vision_model=}")
if DEBUG:
decoded_img = unnormalize_image(vision_model.decode(z.to(device=device, dtype=dtype)))
transforms_f.to_pil_image(decoded_img[0]).save("decoded_unquant.png")
transforms_f.to_pil_image(image).save("original.png")
return z[0].detach().cpu().float().numpy().astype(np.float16)
# switch to VAE in SD
# https://huggingface.co/stabilityai/stable-diffusion-3.5-large/tree/main/vae
# https://github.com/bytedance/IRASim/blob/main/sample/sample_autoregressive.py#L151
# from diffusers.models import AutoencoderKL,AutoencoderKLTemporalDecoder
# vae_model_path = 'pretrained_models/stabilityai/stable-diffusion-xl-base-1.0'
# if args.use_temporal_decoder:
# vae = AutoencoderKLTemporalDecoder.from_pretrained(vae_model_path, subfolder="t2v_required_models/vae_temporal_decoder").to(device)
# else:
# vae = AutoencoderKL.from_pretrained(vae_model_path, subfolder="vae").to(device)
# z = vae.encode(x).latent_dist.sample().mul_(vae.config.scaling_factor)
# if DEBUG:
# decoded_img = unnormalize_image(vae.decode(z.to(device=device, dtype=dtype) / vae.config.scaling_factor))
# transforms_f.to_pil_image(decoded_img[0]).save("decoded_unquant.png")
# transforms_f.to_pil_image(image).save("original.png")
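
# Hedged usage sketch (the model id is illustrative; any AutoencoderKLTemporalDecoder
# checkpoint with a "vae" subfolder should work):
# z = get_vae_image_embeddings(
#     image, encoder_type="temporalvae",
#     encoder_name_or_path="stabilityai/stable-video-diffusion-img2vid",
# )
# z is a float16 numpy array of continuous latents (roughly 4 x h/8 x w/8 for the SVD VAE).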
@torch.no_grad()
def get_t5_embeddings(language, per_token=True, max_length=16, device="cpu"):
"""Get T5 embedding"""
global global_language_model, t5_tok
if global_language_model is None:
try:
t5_model = T5Model.from_pretrained("t5-base")
t5_tok = T5Tokenizer.from_pretrained("t5-base")
        except Exception:  # fall back to local files when the hub is unreachable
t5_model = T5Model.from_pretrained("t5-base", local_files_only=True)
t5_tok = T5Tokenizer.from_pretrained("t5-base", local_files_only=True)
t5_model = t5_model.to(device)
global_language_model = t5_model
global_language_model.eval()
# forward pass through encoder only
enc = t5_tok(
[language],
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=max_length,
).to(device)
output = global_language_model.encoder(
input_ids=enc["input_ids"], attention_mask=enc["attention_mask"], return_dict=True
)
torch.cuda.empty_cache()
if per_token:
return output.last_hidden_state[0].detach().cpu().numpy()
else:
# get the final hidden states. average across tokens.
emb = output.last_hidden_state[0].mean(dim=0).detach().cpu().numpy()
return emb
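
# Hedged usage sketch (instruction text is arbitrary; with the default max_length=16
# and t5-base's hidden size of 768):
# per_token = get_t5_embeddings("pick up the red block")                # (16, 768)
# pooled = get_t5_embeddings("pick up the red block", per_token=False)  # (768,)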