Kemsekov commited on
Commit
541ccc4
·
verified ·
1 Parent(s): 0a7cbdb

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +104 -3
README.md CHANGED
@@ -1,3 +1,104 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ Custom hand-made 3-scale VQVAE trained on private dataset that consists of about 4k images pixelart images.
6
+ Source code for model can be found [here](https://github.com/Kemsekov/kemsekov_torch/tree/main/vqvae).
7
+
8
+
9
+ It acrhived 0.987 r2 metric on image reconstruction.
10
+
11
+ Model have codebook:
12
+ * 512 bottom
13
+ * 512 mid
14
+ * 256 top
15
+
16
+ This provides enough space for model to achieve good metrics.
17
+
18
+ Here is code example how to use it.
19
+
20
+
21
+ ```py
22
+ import random
23
+ import PIL.Image
24
+ from matplotlib import pyplot as plt
25
+ import torch
26
+ import torchvision.transforms as T
27
+
28
+ sample = PIL.Image.open("sample_images/cat.png") # you sample image
29
+ sample = T.ToTensor()(sample)[None,:] # add batch dimension
30
+ sample = T.Resize((512,512))(sample) # optional, this vqvae works fine with any input image size
31
+
32
+ vqvae=torch.jit.load("model.pt")
33
+
34
+ # rec is reconstruction
35
+ # z is list of latent space tensors
36
+ # z_q is quantized list of latent space tensors
37
+ # ind is list of encoded indices of quantized elements in latent space
38
+
39
+ rec, z, z_q,ind = vqvae.eval().cpu()(sample)
40
+ rec_ind = vqvae.decode_from_ind(ind)
41
+ rec=rec.sigmoid()
42
+ rec_ind=rec_ind.sigmoid()
43
+
44
+ print("Original image shape",list(sample.shape[1:]))
45
+ print("ind shapes",[list(v.shape[1:]) for v in ind])
46
+
47
+ plt.figure(figsize=(18,6))
48
+ plt.subplot(1,3,1)
49
+ plt.imshow(T.ToPILImage()(sample[0]).resize((256,256)))
50
+ plt.title("original")
51
+ plt.axis('off')
52
+
53
+ plt.subplot(1,3,2)
54
+ plt.imshow(T.ToPILImage()(rec[0]).resize((256,256)))
55
+ plt.title("reconstruction")
56
+ plt.axis('off')
57
+
58
+
59
+ plt.subplot(1,3,3)
60
+ plt.imshow(T.ToPILImage()(rec_ind[0]).resize((256,256)))
61
+ plt.title("reconstruction from ind")
62
+ plt.axis('off')
63
+ plt.show()
64
+
65
+ plt.figure(figsize=(18,6))
66
+ plt.subplot(1,3,1)
67
+ plt.imshow(T.ToPILImage()(ind[0]/512).resize((256,256)))
68
+ plt.title("ind0")
69
+ plt.axis('off')
70
+
71
+ plt.subplot(1,3,2)
72
+ plt.imshow(T.ToPILImage()(ind[1]/512).resize((256,256)))
73
+ plt.title("ind1")
74
+ plt.axis('off')
75
+
76
+ plt.subplot(1,3,3)
77
+ plt.imshow(T.ToPILImage()(ind[2]/256).resize((256,256)))
78
+ plt.title("ind2")
79
+ plt.axis('off')
80
+ plt.show()
81
+
82
+ print("latent space render")
83
+ for z_ in z:
84
+ dims = len(z_[0])
85
+ dims_sqrt = int(dims**0.5)
86
+ plt.figure(figsize=(10,10))
87
+ plt.axis('off')
88
+ for i in range(dims_sqrt):
89
+ for j in range(dims_sqrt):
90
+ slice_ind = i*dims_sqrt+j
91
+ slice_ind_end = slice_ind+1
92
+ plt.subplot(dims_sqrt,dims_sqrt,slice_ind+1)
93
+ plt.imshow(T.ToPILImage()(z_[0][slice_ind:slice_ind_end]))
94
+ plt.axis('off')
95
+ plt.show()
96
+ ```
97
+
98
+ ```
99
+ Original image shape [3, 512, 512]
100
+ ind shapes [[128, 128], [64, 64], [32, 32]]
101
+ ```
102
+
103
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/d3PSfPu9tkKZkdMv8UJSV.png)
104
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/pDOPnZtAh05UXfkFaklkq.png)