A custom, hand-made 3-scale VQ-VAE trained on a private dataset of about 4k pixel-art images. The model's source code can be found here.
It achieved an R² of 0.987 on image reconstruction after 500 epochs of training on 256x256 image crops.
Because it was trained on crops, the model also works with larger and smaller images.
The model has one codebook per scale:
- 512 entries (bottom)
- 512 entries (mid)
- 256 entries (top)
This gives the model enough capacity to achieve good reconstruction metrics.
Here is a code example showing how to use it:
import random
import PIL.Image
from matplotlib import pyplot as plt
import torch
import torchvision.transforms as T
# Load your sample image. convert("RGB") matters for PNGs: pixel-art files are
# often palette ("P") or RGBA images, for which ToTensor would produce 1 or 4
# channels instead of the 3-channel input shown in the example output below.
sample = PIL.Image.open("image.png").convert("RGB")
sample = T.ToTensor()(sample)[None, :]  # add batch dimension -> (1, 3, H, W)
# RandomCrop raises if either side of the image is smaller than 256.
# NOTE(review): the printed index shapes show the coarsest scale is 16x smaller
# than the input (16x16 for a 256x256 crop), so sides divisible by 16 are the
# safe choice; the original note said "divisible by 8" -- confirm.
sample = T.RandomCrop((256, 256))(sample)
# The example keeps every tensor on the CPU, so load the weights there too
# (otherwise a GPU-saved checkpoint fails to load or mismatches `sample`'s device).
vqvae = torch.jit.load("model_v3.pt", map_location="cpu")
vqvae.eval()  # inference behavior; a no-op if the model was saved in eval mode
# rec, rec_ind is reconstructions
# rec is reconstruction from latent space values z
# rec_ind is reconstruction from model predicted vector indices
# z latent space tensor with 64 channels and 4x smaller than input image
# z_layers is list of latent space tensors at different scales
# z_q_layers is quantized list of latent space tensors
# ind is list of encoded indices of quantized elements in latent space for each scale
z, z_layers,z_q_layers, ind = vqvae.encode(sample)
rec_ind = vqvae.decode_from_ind(ind).sigmoid()
rec = vqvae.decode(z).sigmoid()
print("Original image shape",list(sample.shape[1:]))
print("ind shapes",[list(v.shape[1:]) for v in ind])
# Show the original next to both reconstructions; all three should look the same.
# NEAREST resampling keeps pixel-art edges crisp if a resize actually happens
# (Pillow's default resize filter is a smoothing one in recent versions).
plt.figure(figsize=(18, 6))
panels = [
    (sample, "original"),
    (rec, "reconstruction"),
    (rec_ind, "reconstruction from ind"),
]
for pos, (img, title) in enumerate(panels, start=1):
    plt.subplot(1, 3, pos)
    plt.imshow(T.ToPILImage()(img[0]).resize((256, 256), PIL.Image.NEAREST))
    plt.title(title)
    plt.axis('off')
plt.show()
# Codebook-index maps for each scale, in the order encode() returned them.
# Dividing by each codebook's size (512, 512, 256) maps the indices into [0, 1)
# for display. These are expected to look like noise, not like the image.
# NEAREST resizing avoids blending neighboring indices, which are unrelated values.
plt.figure(figsize=(18, 6))
codebook_sizes = (512, 512, 256)
for pos, (ind_map, codebook_size) in enumerate(zip(ind, codebook_sizes), start=1):
    plt.subplot(1, 3, pos)
    ind_img = T.ToPILImage()(ind_map / codebook_size)
    plt.imshow(ind_img.resize((256, 256), PIL.Image.NEAREST))
    plt.title(f"ind{pos - 1}")
    plt.axis('off')
plt.show()
# Render every channel of every latent scale as its own grayscale tile.
print("latent space render")
for z_ in z_layers:
    # Take the first batch item; detach + cpu so .numpy() works even when
    # autograd is tracking the latents.
    channels = z_[0].detach().cpu()
    dims = len(channels)
    # Smallest square grid that fits all channels. A plain int(sqrt) would
    # silently drop channels whenever dims is not a perfect square.
    grid = int(dims ** 0.5)
    if grid * grid < dims:
        grid += 1
    plt.figure(figsize=(10, 10))
    for slice_ind, channel in enumerate(channels):
        plt.subplot(grid, grid, slice_ind + 1)
        # Latents are unbounded floats. ToPILImage would cast them with
        # mul(255).byte(), wrapping anything outside [0, 1]. imshow min-max
        # scales each channel instead.
        plt.imshow(channel.numpy())
        plt.axis('off')
    plt.show()
Original image shape [3, 256, 256]
ind shapes [[64, 64], [32, 32], [16, 16]]
Inference Providers
NEW
This model isn't deployed by any Inference Provider.
๐
Ask for provider support
HF Inference deployability: The model has no library tag.