Kemsekov commited on
Commit
2ee3330
·
verified ·
1 Parent(s): 3164379

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +23 -27
README.md CHANGED
@@ -27,21 +27,23 @@ from matplotlib import pyplot as plt
27
  import torch
28
  import torchvision.transforms as T
29
 
30
- sample = PIL.Image.open("sample_images/cat.png") # you sample image
31
  sample = T.ToTensor()(sample)[None,:] # add batch dimension
32
- sample = T.Resize((512,512))(sample) # optional, this vqvae works fine with any input image size
33
 
34
- vqvae=torch.jit.load("model.pt")
35
 
36
- # rec is reconstruction
37
- # z is list of latent space tensors
38
- # z_q is quantized list of latent space tensors
39
- # ind is list of encoded indices of quantized elements in latent space
 
 
 
40
 
41
- rec, z, z_q,ind = vqvae.eval().cpu()(sample)
42
- rec_ind = vqvae.decode_from_ind(ind)
43
- rec=rec.sigmoid()
44
- rec_ind=rec_ind.sigmoid()
45
 
46
  print("Original image shape",list(sample.shape[1:]))
47
  print("ind shapes",[list(v.shape[1:]) for v in ind])
@@ -52,6 +54,7 @@ plt.imshow(T.ToPILImage()(sample[0]).resize((256,256)))
52
  plt.title("original")
53
  plt.axis('off')
54
 
 
55
  plt.subplot(1,3,2)
56
  plt.imshow(T.ToPILImage()(rec[0]).resize((256,256)))
57
  plt.title("reconstruction")
@@ -64,6 +67,7 @@ plt.title("reconstruction from ind")
64
  plt.axis('off')
65
  plt.show()
66
 
 
67
  plt.figure(figsize=(18,6))
68
  plt.subplot(1,3,1)
69
  plt.imshow(T.ToPILImage()(ind[0]/512).resize((256,256)))
@@ -82,7 +86,7 @@ plt.axis('off')
82
  plt.show()
83
 
84
  print("latent space render")
85
- for z_ in z:
86
  dims = len(z_[0])
87
  dims_sqrt = int(dims**0.5)
88
  plt.figure(figsize=(10,10))
@@ -98,20 +102,12 @@ for z_ in z:
98
  ```
99
 
100
  ```
101
- Original image shape [3, 512, 512]
102
- ind shapes [[128, 128], [64, 64], [32, 32]]
103
  ```
104
 
105
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/d3PSfPu9tkKZkdMv8UJSV.png)
106
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/pDOPnZtAh05UXfkFaklkq.png)
107
-
108
- And it have following latent space
109
-
110
- Bottom
111
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/RkRVxY6uly59c8yumMTpv.png)
112
- Mid
113
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/CwR8o--prVLmR6TdL4Jt7.png)
114
- Top
115
- ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/uF95lUigW-NOYIV2EhD8h.png)
116
-
117
- As you can see, it properly handles different image aspects at different scales
 
27
  import torch
28
  import torchvision.transforms as T
29
 
30
+ sample = PIL.Image.open("image.png") # you sample image
31
  sample = T.ToTensor()(sample)[None,:] # add batch dimension
32
+ sample = T.RandomCrop((256,256))(sample) # this vqvae works fine with any input image size that is divisible by 8
33
 
34
+ vqvae=torch.jit.load("model_v3.pt")
35
 
36
+ # rec, rec_ind is reconstructions
37
+ # rec is reconstruction from latent space values z
38
+ # rec_ind is reconstruction from model predicted vector indices
39
+ # z latent space tensor with 64 channels and 4x smaller than input image
40
+ # z_layers is list of latent space tensors at different scales
41
+ # z_q_layers is quantized list of latent space tensors
42
+ # ind is list of encoded indices of quantized elements in latent space for each scale
43
 
44
+ z, z_layers,z_q_layers, ind = vqvae.encode(sample)
45
+ rec_ind = vqvae.decode_from_ind(ind).sigmoid()
46
+ rec = vqvae.decode(z).sigmoid()
 
47
 
48
  print("Original image shape",list(sample.shape[1:]))
49
  print("ind shapes",[list(v.shape[1:]) for v in ind])
 
54
  plt.title("original")
55
  plt.axis('off')
56
 
57
+ # these two must look the same
58
  plt.subplot(1,3,2)
59
  plt.imshow(T.ToPILImage()(rec[0]).resize((256,256)))
60
  plt.title("reconstruction")
 
67
  plt.axis('off')
68
  plt.show()
69
 
70
+ # this must look like a pile of mess
71
  plt.figure(figsize=(18,6))
72
  plt.subplot(1,3,1)
73
  plt.imshow(T.ToPILImage()(ind[0]/512).resize((256,256)))
 
86
  plt.show()
87
 
88
  print("latent space render")
89
+ for z_ in z_layers:
90
  dims = len(z_[0])
91
  dims_sqrt = int(dims**0.5)
92
  plt.figure(figsize=(10,10))
 
102
  ```
103
 
104
  ```
105
+ Original image shape [3, 256, 256]
106
+ ind shapes [[64, 64], [32, 32], [16, 16]]
107
  ```
108
 
109
+ Here is some examples at 256x256 resolution
110
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/-EEovEr-dxpp03YIloWSJ.png)
111
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/fPrS1L-aBN9yMYaTBjhUa.png)
112
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/jx4B0NfChsr4AzDh8XWl3.png)
113
+ ![image/png](https://cdn-uploads.huggingface.co/production/uploads/633b160acbdbadd99c094172/01Lsf-Zj_U4ULdMNnjGIj.png)