Imag / src /videogen_hub /depend /icetk /image_tokenizer.py
Baraaqasem's picture
Upload 49 files
413d4d0 verified
# -*- encoding: utf-8 -*-
'''
@File : image_tokenizer.py
@Time : 2021/12/20 14:19:49
@Author : Ming Ding
@Contact : [email protected]
'''
# here put the import lib
import os
import sys
import math
import random
import torch
import torch.nn.functional as F
from torchvision import transforms
from .vqvae import load_default_HVQVAE, load_ckpt
class ImageTokenizer(object):
    '''Convert images to discrete token ids and back using a pretrained model.

    The model is built by ``load_default_HVQVAE`` and its weights are loaded
    from ``model_path`` by ``load_ckpt``. Inputs are normalized with fixed
    per-channel statistics before encoding, and decoded outputs are mapped
    back with the inverse of that normalization.

    Attributes:
        model: The loaded model, in eval mode on ``device``.
        device: Device the model was moved to at construction time.
        fp16: Whether the model (and encoder inputs) are cast to half.
        num_tokens: Codebook size, read from ``model.quantize.n_embed``.
        tr_normalize: ``transforms.Normalize`` applied to inputs in ``encode``.
    '''

    # Per-channel statistics shared by `encode` (normalize) and `decode`
    # (de-normalize); kept in one place so the two cannot drift apart.
    NORM_MEAN = (0.79093, 0.76271, 0.75340)
    NORM_STD = (0.30379, 0.32279, 0.32800)

    def __init__(self,
                 model_path,
                 device='cuda',
                 fp16=True):
        '''Load the model and prepare it for inference.

        Args:
            model_path: Checkpoint location passed to ``load_ckpt``.
            device: Device the model is moved to.
            fp16: If true, cast the model to half precision.
        '''
        model = load_default_HVQVAE()
        model = load_ckpt(model, model_path)
        model = model.to(device)
        model.eval()
        if fp16:
            model = model.half()

        self.tr_normalize = transforms.Normalize(
            list(self.NORM_MEAN),
            list(self.NORM_STD)
        )
        self.model = model
        self.device = device
        self.fp16 = fp16
        self.num_tokens = model.quantize.n_embed

    def __len__(self):
        '''Return the number of distinct tokens (the codebook size).'''
        return self.num_tokens

    def encode(self, image_torch, l=1):
        '''Convert a batch of images to token ids.

        Args:
            image_torch: Image tensor shaped [b, c, h, w], or a single image
                shaped [c, h, w] (a batch dimension is added automatically).
            l: Passed through unchanged to ``model.single_encode``.

        Returns:
            Tensor of token ids shaped [b, -1] (one flattened row per image).
        '''
        if len(image_torch.shape) == 3:
            image_torch = image_torch.unsqueeze(0)
        img = self.tr_normalize(image_torch).to(self.device)
        if self.fp16:
            img = img.half()
        with torch.no_grad():
            # Only the ids are needed; the first two returned values are unused.
            _, _, ids = self.model.single_encode(img, l)
        return ids.view(img.shape[0], -1)

    def decode(self, codes, l=1):
        '''Convert a batch of token ids back to images.

        Args:
            codes: [b, h, w] or [b, h*w] or [h*w] LongTensor / list. Flattened
                inputs are reshaped to a square [b, s, s] grid.
            l: Passed through unchanged to ``model.single_decode_code``.

        Returns:
            The model output with the input normalization undone
            (``out * std + mean`` per channel).
        '''
        if isinstance(codes, list):
            codes = torch.tensor(codes, dtype=torch.long, device=self.device)
        if len(codes.shape) == 1:
            codes = codes.unsqueeze(0)
        if len(codes.shape) == 2:
            # The side length must come from the per-sample token count
            # (shape[1]); using the total element count breaks for b > 1.
            side = int(math.sqrt(codes.shape[1]) + 1e-5)
            # reshape (not view) so non-contiguous inputs are accepted too.
            codes = codes.reshape(codes.shape[0], side, side)
        with torch.no_grad():
            out = self.model.single_decode_code(codes, l)
            std = torch.tensor(self.NORM_STD, device=out.device).view(1, -1, 1, 1)
            mean = torch.tensor(self.NORM_MEAN, device=out.device).view(1, -1, 1, 1)
            out = out * std + mean
        return out