---
license: mit
---

A custom, hand-made 3-scale VQ-VAE trained on a private dataset of about 4k pixel-art images.
The model's source code can be found [here](https://github.com/Kemsekov/kemsekov_torch/tree/main/vqvae).


It achieved an R² of 0.987 on image reconstruction after 500 epochs of training on 256x256 image crops.
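The card does not state how R² was computed. Below is a minimal sketch, assuming the standard coefficient of determination taken over all pixel values at once. `r2_score` is a hypothetical helper, not taken from the model's source code:

```python
import torch

def r2_score(target: torch.Tensor, pred: torch.Tensor) -> float:
    """Coefficient of determination R^2 = 1 - SS_res / SS_tot over all elements."""
    target = target.flatten().double()
    pred = pred.flatten().double()
    ss_res = torch.sum((target - pred) ** 2)             # residual sum of squares
    ss_tot = torch.sum((target - target.mean()) ** 2)    # total sum of squares
    return float(1.0 - ss_res / ss_tot)
```

With the tensors from the usage example further down, `r2_score(sample, rec)` gives one score for one image. The reported 0.987 comes from the author's own evaluation, which may aggregate over the dataset differently.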

Because it was trained on crops, the model also works well with larger and smaller images.
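To feed an image of arbitrary size, one option is to center-crop its height and width down to a size the encoder accepts. This is a minimal sketch. `crop_to_multiple` is a hypothetical helper, and the default `multiple=16` is an assumption: it is based on the top-level index map being 16x smaller than the input in the printed shapes below.

```python
import torch

def crop_to_multiple(x: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    """Center-crop the last two dims (H, W) of x down to the nearest multiple."""
    h, w = x.shape[-2], x.shape[-1]
    new_h = (h // multiple) * multiple
    new_w = (w // multiple) * multiple
    if new_h == 0 or new_w == 0:
        raise ValueError(f"image {h}x{w} is smaller than {multiple} pixels")
    top, left = (h - new_h) // 2, (w - new_w) // 2
    return x[..., top:top + new_h, left:left + new_w]
```

For example, a `[1, 3, 300, 517]` tensor becomes `[1, 3, 288, 512]`. Cropping keeps the pixel art unscaled, which avoids the blur that resizing would introduce.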

The model has three codebooks:
* bottom: 512 entries
* mid: 512 entries
* top: 256 entries

This gives the model enough capacity to reach good reconstruction metrics.
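For context, a VQ-VAE codebook is normally used like this: each latent vector is replaced by its nearest codebook entry, and that entry's position is the stored index. This is a minimal sketch of the lookup, assuming Euclidean nearest-neighbour search and a random codebook. `quantize` is a hypothetical helper, not the author's code. The embedding size of 64 follows the 64-channel `z` mentioned in the usage example below and is not confirmed for every scale.

```python
import torch

def quantize(z: torch.Tensor, codebook: torch.Tensor):
    """Nearest-neighbour vector quantization.

    z:        latent map of shape [B, C, H, W]
    codebook: embeddings of shape [K, C]
    returns (z_q of shape [B, C, H, W], ind of shape [B, H, W])
    """
    b, c, h, w = z.shape
    flat = z.permute(0, 2, 3, 1).reshape(-1, c)   # one C-dim vector per spatial position
    dist = torch.cdist(flat, codebook)            # [B*H*W, K] Euclidean distances
    ind = dist.argmin(dim=1)                      # index of the closest codebook entry
    z_q = codebook[ind].reshape(b, h, w, c).permute(0, 3, 1, 2)
    return z_q, ind.reshape(b, h, w)

codebook = torch.randn(512, 64)   # stand-in for a 512-entry codebook (random here)
z = torch.randn(1, 64, 64, 64)    # latent for a 256x256 input at 4x downsampling
z_q, ind = quantize(z, codebook)
print(z_q.shape, ind.shape)       # torch.Size([1, 64, 64, 64]) torch.Size([1, 64, 64])
```

Decoding from indices, as `decode_from_ind` does in the usage example, amounts to the reverse step: look up `codebook[ind]` and run the decoder on the result.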

Here is a code example showing how to use it:


```py
import PIL.Image
from matplotlib import pyplot as plt
import torch
import torchvision.transforms as T

sample = PIL.Image.open("image.png").convert("RGB") # your sample image; convert() drops any alpha channel
sample = T.ToTensor()(sample)[None,:] # add batch dimension
sample = T.RandomCrop((256,256))(sample) # this vqvae works fine with any input image size that is divisible by 8

vqvae = torch.jit.load("model_v3.pt")
vqvae.eval()

# rec and rec_ind are reconstructions:
#   rec is the reconstruction from the latent values z
#   rec_ind is the reconstruction from the vector indices predicted by the model
# z is the latent tensor, with 64 channels and 4x smaller spatially than the input image
# z_layers is a list of latent tensors at the different scales
# z_q_layers is the list of quantized latent tensors
# ind is a list of codebook indices of the quantized latents, one tensor per scale

# no_grad: inference only; tensors that require grad cannot be converted to NumPy/PIL below
with torch.no_grad():
    z, z_layers, z_q_layers, ind = vqvae.encode(sample)
    rec_ind = vqvae.decode_from_ind(ind).sigmoid()
    rec = vqvae.decode(z).sigmoid()

print("Original image shape",list(sample.shape[1:]))
print("ind shapes",[list(v.shape[1:]) for v in ind])

plt.figure(figsize=(18,6))
plt.subplot(1,3,1)
plt.imshow(T.ToPILImage()(sample[0]).resize((256,256)))
plt.title("original")
plt.axis('off')

# these two must look the same
plt.subplot(1,3,2)
plt.imshow(T.ToPILImage()(rec[0]).resize((256,256)))
plt.title("reconstruction")
plt.axis('off')


plt.subplot(1,3,3)
plt.imshow(T.ToPILImage()(rec_ind[0]).resize((256,256)))
plt.title("reconstruction from ind")
plt.axis('off')
plt.show()

# the index maps are expected to look like a pile of mess
plt.figure(figsize=(18,6))
plt.subplot(1,3,1)
plt.imshow(T.ToPILImage()(ind[0]/512).resize((256,256)))
plt.title("ind0")
plt.axis('off')

plt.subplot(1,3,2)
plt.imshow(T.ToPILImage()(ind[1]/512).resize((256,256)))
plt.title("ind1")
plt.axis('off')

plt.subplot(1,3,3)
plt.imshow(T.ToPILImage()(ind[2]/256).resize((256,256)))
plt.title("ind2")
plt.axis('off')
plt.show()

print("latent space render")
for z_ in z_layers:
    dims = len(z_[0])            # number of channels at this scale
    dims_sqrt = int(dims**0.5)   # show the first dims_sqrt**2 channels in a square grid
    plt.figure(figsize=(10,10))
    for i in range(dims_sqrt):
        for j in range(dims_sqrt):
            slice_ind = i*dims_sqrt+j
            plt.subplot(dims_sqrt,dims_sqrt,slice_ind+1)
            # imshow rescales each channel for display; ToPILImage would not show values outside [0, 1] correctly
            plt.imshow(z_[0][slice_ind].detach().cpu().numpy(), cmap="gray")
            plt.axis('off')
    plt.show()
```

```
Original image shape [3, 256, 256]
ind shapes [[64, 64], [32, 32], [16, 16]]
```
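The printed shapes suggest the three index maps are 4x, 8x and 16x smaller than the input. This small sketch predicts the index-map shapes for other sizes and the storage needed for the indices, assuming those factors hold and that each index takes ceil(log2(K)) bits for the codebook sizes listed above. `index_shapes` and `index_bits` are hypothetical helpers. The estimate ignores entropy coding and the decoder weights.

```python
import math

FACTORS = (4, 8, 16)               # bottom, mid, top downsampling (inferred from printed shapes)
CODEBOOK_SIZES = (512, 512, 256)   # bottom, mid, top

def index_shapes(h: int, w: int):
    """Predicted [H, W] of each index map for an h x w input."""
    return [[h // f, w // f] for f in FACTORS]

def index_bits(h: int, w: int) -> int:
    """Bits needed to store all indices at a fixed width per codebook."""
    total = 0
    for (ih, iw), k in zip(index_shapes(h, w), CODEBOOK_SIZES):
        total += ih * iw * math.ceil(math.log2(k))
    return total

bits = index_bits(256, 256)
raw_bits = 256 * 256 * 3 * 8   # 8-bit RGB
print(index_shapes(256, 256))                              # [[64, 64], [32, 32], [16, 16]]
print(bits // 8, "bytes of indices")                       # 6016 bytes of indices
print(round(raw_bits / bits, 1), "x smaller than raw RGB") # 32.7 x smaller than raw RGB
```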

Here are some examples at 256x256 resolution:
![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/-EEovEr-dxpp03YIloWSJ.png)
![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/fPrS1L-aBN9yMYaTBjhUa.png)
![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/jx4B0NfChsr4AzDh8XWl3.png)
![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/01Lsf-Zj_U4ULdMNnjGIj.png)