csuhan committed (verified)
Commit 3092e43 · Parent: 71299d8

Upload folder using huggingface_hub

Files changed (5)
  1. .gitattributes +2 -0
  2. kl16.safetensors +3 -0
  3. rec.png +3 -0
  4. rec_tanh.png +3 -0
  5. vae.py +546 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+rec.png filter=lfs diff=lfs merge=lfs -text
+rec_tanh.png filter=lfs diff=lfs merge=lfs -text
kl16.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ae6e66b66cce64248cace5aa6cf8bdb3e834947801cbf7e059028592e2b1de31
+size 265856244
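
kl16.safetensors holds the KL-16 VAE weights that vae.py (below) loads through safetensors. A minimal loading sketch, assuming the checkpoint sits next to vae.py; the hyperparameters mirror the __main__ block at the end of that file:

from vae import AutoencoderKL

# embed_dim=16 and ch_mult=(1, 1, 2, 2, 4) match the __main__ block in vae.py
vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path='kl16.safetensors')
vae = vae.eval()  # move to GPU with .to('cuda') when available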
rec.png ADDED

Git LFS Details

  • SHA256: 7b2d6606ee09136dac682615de3248b9b4ef6b16be0178444fa10695cf336c92
  • Pointer size: 132 Bytes
  • Size of remote file: 3.41 MB
rec_tanh.png ADDED

Git LFS Details

  • SHA256: 3058a4748ba9e74a41c87972be748346cb0e4e3a6ac117c395bc3a56ecb9c059
  • Pointer size: 132 Bytes
  • Size of remote file: 3.39 MB
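
rec.png and rec_tanh.png are reconstruction grids written by the __main__ block at the end of vae.py: the input image, the plain VAE reconstruction, and reconstructions after FSQ-style rounding of the latent at L = 2, 4, 16, 32, 64, 128 and 256 levels. The two files presumably differ only in the squashing function passed to fsq_quantize. A minimal sketch of the rounding applied per latent value (fsq_round is just an illustrative name; the straight-through estimator the file uses for gradients is omitted here):

import torch

def fsq_round(z, L):
    # what fsq_quantize does with mode='sigmoid' (rec.png); mode='tanh'
    # (rec_tanh.png, presumably) squashes with (tanh(z) + 1) / 2 and inverts with atanh
    z_q = torch.round(torch.sigmoid(z) * (L - 1)) / (L - 1)  # snap to L uniform levels in [0, 1]
    return torch.logit(z_q.clamp(1e-6, 1 - 1e-6))            # map back to the unbounded latent range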
vae.py ADDED
@@ -0,0 +1,546 @@
+# Adopted from LDM's KL-VAE: https://github.com/CompVis/latent-diffusion
+import torch
+import torch.nn as nn
+from safetensors.torch import load_file
+
+import numpy as np
+
+
+def nonlinearity(x):
+    # swish
+    return x * torch.sigmoid(x)
+
+
+def Normalize(in_channels, num_groups=32):
+    return torch.nn.GroupNorm(
+        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
+    )
+
+
+class Upsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=3, stride=1, padding=1
+            )
+
+    def forward(self, x):
+        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        if self.with_conv:
+            x = self.conv(x)
+        return x
+
+
+class Downsample(nn.Module):
+    def __init__(self, in_channels, with_conv):
+        super().__init__()
+        self.with_conv = with_conv
+        if self.with_conv:
+            # no asymmetric padding in torch conv, must do it ourselves
+            self.conv = torch.nn.Conv2d(
+                in_channels, in_channels, kernel_size=3, stride=2, padding=0
+            )
+
+    def forward(self, x):
+        if self.with_conv:
+            pad = (0, 1, 0, 1)
+            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
+            x = self.conv(x)
+        else:
+            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
+        return x
+
+
+class ResnetBlock(nn.Module):
+    def __init__(
+        self,
+        *,
+        in_channels,
+        out_channels=None,
+        conv_shortcut=False,
+        dropout,
+        temb_channels=512,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+        self.use_conv_shortcut = conv_shortcut
+
+        self.norm1 = Normalize(in_channels)
+        self.conv1 = torch.nn.Conv2d(
+            in_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if temb_channels > 0:
+            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
+        self.norm2 = Normalize(out_channels)
+        self.dropout = torch.nn.Dropout(dropout)
+        self.conv2 = torch.nn.Conv2d(
+            out_channels, out_channels, kernel_size=3, stride=1, padding=1
+        )
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                self.conv_shortcut = torch.nn.Conv2d(
+                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
+                )
+            else:
+                self.nin_shortcut = torch.nn.Conv2d(
+                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
+                )
+
+    def forward(self, x, temb):
+        h = x
+        h = self.norm1(h)
+        h = nonlinearity(h)
+        h = self.conv1(h)
+
+        if temb is not None:
+            h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
+
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
+
+        if self.in_channels != self.out_channels:
+            if self.use_conv_shortcut:
+                x = self.conv_shortcut(x)
+            else:
+                x = self.nin_shortcut(x)
+
+        return x + h
+
+
+class AttnBlock(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.in_channels = in_channels
+
+        self.norm = Normalize(in_channels)
+        self.q = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.k = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.v = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+        self.proj_out = torch.nn.Conv2d(
+            in_channels, in_channels, kernel_size=1, stride=1, padding=0
+        )
+
+    def forward(self, x):
+        h_ = x
+        h_ = self.norm(h_)
+        q = self.q(h_)
+        k = self.k(h_)
+        v = self.v(h_)
+
+        # compute attention
+        b, c, h, w = q.shape
+        q = q.reshape(b, c, h * w)
+        q = q.permute(0, 2, 1)  # b,hw,c
+        k = k.reshape(b, c, h * w)  # b,c,hw
+        w_ = torch.bmm(q, k)  # b,hw,hw   w[b,i,j] = sum_c q[b,i,c] k[b,c,j]
+        w_ = w_ * (int(c) ** (-0.5))
+        w_ = torch.nn.functional.softmax(w_, dim=2)
+
+        # attend to values
+        v = v.reshape(b, c, h * w)
+        w_ = w_.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
+        h_ = torch.bmm(v, w_)  # b,c,hw (hw of q)   h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
+        h_ = h_.reshape(b, c, h, w)
+
+        h_ = self.proj_out(h_)
+
+        return x + h_
+
+
+class Encoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        ch=128,
+        out_ch=3,
+        ch_mult=(1, 1, 2, 2, 4),
+        num_res_blocks=2,
+        attn_resolutions=(16,),
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels=3,
+        resolution=256,
+        z_channels=16,
+        double_z=True,
+        **ignore_kwargs,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+
+        # downsampling
+        self.conv_in = torch.nn.Conv2d(
+            in_channels, self.ch, kernel_size=3, stride=1, padding=1
+        )
+
+        curr_res = resolution
+        in_ch_mult = (1,) + tuple(ch_mult)
+        self.down = nn.ModuleList()
+        for i_level in range(self.num_resolutions):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_in = ch * in_ch_mult[i_level]
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks):
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            down = nn.Module()
+            down.block = block
+            down.attn = attn
+            if i_level != self.num_resolutions - 1:
+                down.downsample = Downsample(block_in, resamp_with_conv)
+                curr_res = curr_res // 2
+            self.down.append(down)
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(
+            block_in,
+            2 * z_channels if double_z else z_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+        )
+
+    def forward(self, x):
+        # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution)
+
+        # timestep embedding
+        temb = None
+
+        # downsampling
+        hs = [self.conv_in(x)]
+        for i_level in range(self.num_resolutions):
+            for i_block in range(self.num_res_blocks):
+                h = self.down[i_level].block[i_block](hs[-1], temb)
+                if len(self.down[i_level].attn) > 0:
+                    h = self.down[i_level].attn[i_block](h)
+                hs.append(h)
+            if i_level != self.num_resolutions - 1:
+                hs.append(self.down[i_level].downsample(hs[-1]))
+
+        # middle
+        h = hs[-1]
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # end
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class Decoder(nn.Module):
+    def __init__(
+        self,
+        *,
+        ch=128,
+        out_ch=3,
+        ch_mult=(1, 1, 2, 2, 4),
+        num_res_blocks=2,
+        attn_resolutions=(),
+        dropout=0.0,
+        resamp_with_conv=True,
+        in_channels=3,
+        resolution=256,
+        z_channels=16,
+        give_pre_end=False,
+        **ignore_kwargs,
+    ):
+        super().__init__()
+        self.ch = ch
+        self.temb_ch = 0
+        self.num_resolutions = len(ch_mult)
+        self.num_res_blocks = num_res_blocks
+        self.resolution = resolution
+        self.in_channels = in_channels
+        self.give_pre_end = give_pre_end
+
+        # compute in_ch_mult, block_in and curr_res at lowest res
+        in_ch_mult = (1,) + tuple(ch_mult)
+        block_in = ch * ch_mult[self.num_resolutions - 1]
+        curr_res = resolution // 2 ** (self.num_resolutions - 1)
+        self.z_shape = (1, z_channels, curr_res, curr_res)
+        print(
+            "Working with z of shape {} = {} dimensions.".format(
+                self.z_shape, np.prod(self.z_shape)
+            )
+        )
+
+        # z to block_in
+        self.conv_in = torch.nn.Conv2d(
+            z_channels, block_in, kernel_size=3, stride=1, padding=1
+        )
+
+        # middle
+        self.mid = nn.Module()
+        self.mid.block_1 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+        self.mid.attn_1 = AttnBlock(block_in)
+        self.mid.block_2 = ResnetBlock(
+            in_channels=block_in,
+            out_channels=block_in,
+            temb_channels=self.temb_ch,
+            dropout=dropout,
+        )
+
+        # upsampling
+        self.up = nn.ModuleList()
+        for i_level in reversed(range(self.num_resolutions)):
+            block = nn.ModuleList()
+            attn = nn.ModuleList()
+            block_out = ch * ch_mult[i_level]
+            for i_block in range(self.num_res_blocks + 1):
+                block.append(
+                    ResnetBlock(
+                        in_channels=block_in,
+                        out_channels=block_out,
+                        temb_channels=self.temb_ch,
+                        dropout=dropout,
+                    )
+                )
+                block_in = block_out
+                if curr_res in attn_resolutions:
+                    attn.append(AttnBlock(block_in))
+            up = nn.Module()
+            up.block = block
+            up.attn = attn
+            if i_level != 0:
+                up.upsample = Upsample(block_in, resamp_with_conv)
+                curr_res = curr_res * 2
+            self.up.insert(0, up)  # prepend to get consistent order
+
+        # end
+        self.norm_out = Normalize(block_in)
+        self.conv_out = torch.nn.Conv2d(
+            block_in, out_ch, kernel_size=3, stride=1, padding=1
+        )
+
+    def forward(self, z):
+        # assert z.shape[1:] == self.z_shape[1:]
+        self.last_z_shape = z.shape
+
+        # timestep embedding
+        temb = None
+
+        # z to block_in
+        h = self.conv_in(z)
+
+        # middle
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
+
+        # upsampling
+        for i_level in reversed(range(self.num_resolutions)):
+            for i_block in range(self.num_res_blocks + 1):
+                h = self.up[i_level].block[i_block](h, temb)
+                if len(self.up[i_level].attn) > 0:
+                    h = self.up[i_level].attn[i_block](h)
+            if i_level != 0:
+                h = self.up[i_level].upsample(h)
+
+        # end
+        if self.give_pre_end:
+            return h
+
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
+        return h
+
+
+class DiagonalGaussianDistribution(object):
+    def __init__(self, parameters, deterministic=False):
+        self.parameters = parameters
+        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
+        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
+        self.deterministic = deterministic
+        self.std = torch.exp(0.5 * self.logvar)
+        self.var = torch.exp(self.logvar)
+        if self.deterministic:
+            self.var = self.std = torch.zeros_like(self.mean).to(
+                device=self.parameters.device
+            )
+
+    def sample(self):
+        x = self.mean + self.std * torch.randn(self.mean.shape).to(
+            device=self.parameters.device
+        )
+        return x
+
+    def kl(self, other=None):
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        else:
+            if other is None:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
+                    dim=[1, 2, 3],
+                )
+            else:
+                return 0.5 * torch.sum(
+                    torch.pow(self.mean - other.mean, 2) / other.var
+                    + self.var / other.var
+                    - 1.0
+                    - self.logvar
+                    + other.logvar,
+                    dim=[1, 2, 3],
+                )
+
+    def nll(self, sample, dims=[1, 2, 3]):
+        if self.deterministic:
+            return torch.Tensor([0.0])
+        logtwopi = np.log(2.0 * np.pi)
+        return 0.5 * torch.sum(
+            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
+            dim=dims,
+        )
+
+    def mode(self):
+        return self.mean
+
+
+class AutoencoderKL(nn.Module):
+    def __init__(self, embed_dim, ch_mult, use_variational=True, ckpt_path=None):
+        super().__init__()
+        self.encoder = Encoder(ch_mult=ch_mult, z_channels=embed_dim)
+        self.decoder = Decoder(ch_mult=ch_mult, z_channels=embed_dim)
+        self.use_variational = use_variational
+        mult = 2 if self.use_variational else 1
+        self.quant_conv = torch.nn.Conv2d(2 * embed_dim, mult * embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(embed_dim, embed_dim, 1)
+        self.embed_dim = embed_dim
+        if ckpt_path is not None:
+            self.init_from_ckpt(ckpt_path)
+
+    def init_from_ckpt(self, path):
+        sd = load_file(path)
+        msg = self.load_state_dict(sd, strict=False)
+        print("Loading pre-trained KL-VAE")
+        print("Missing keys:")
+        print(msg.missing_keys)
+        print("Unexpected keys:")
+        print(msg.unexpected_keys)
+        print(f"Restored from {path}")
+
+    def encode(self, x):
+        h = self.encoder(x)
+        moments = self.quant_conv(h)
+        if not self.use_variational:
+            moments = torch.cat((moments, torch.ones_like(moments)), 1)
+        posterior = DiagonalGaussianDistribution(moments)
+        return posterior
+
+    def decode(self, z):
+        z = self.post_quant_conv(z)
+        dec = self.decoder(z)
+        return dec
+
+    def forward(self, inputs, disable=True, train=True, optimizer_idx=0):
+        # NOTE: training_step / validation_step are not defined in this standalone file,
+        # so forward() only works with the original training code; the __main__ block
+        # below calls encode() / decode() directly instead.
+        if train:
+            return self.training_step(inputs, disable, optimizer_idx)
+        else:
+            return self.validation_step(inputs, disable)
+
+
+if __name__ == '__main__':
+    from PIL import Image
+    import numpy as np
+    from torchvision.utils import save_image
+
+    import torch
+    import torch.nn as nn
+    import torch.nn.functional as F
+
+    def fsq_quantize(z, L=16, mode='sigmoid', eps=1e-6):
+        # FSQ-style rounding: squash the latent to [0, 1], snap it to L uniform levels
+        # with a straight-through estimator, then map back to the unbounded latent range.
+        if mode == 'sigmoid':
+            z_norm = torch.sigmoid(z)
+        elif mode == 'tanh':
+            z_norm = (torch.tanh(z) + 1) / 2
+
+        z_scaled = z_norm * (L - 1)
+        z_rounded = torch.round(z_scaled)
+        z_quant = z_scaled + (z_rounded - z_scaled).detach()
+        z_quant = z_quant / (L - 1)
+
+        z_quant = torch.clamp(z_quant, eps, 1 - eps)
+
+        if mode == 'sigmoid':
+            z_rev = torch.logit(z_quant)
+        elif mode == 'tanh':
+            z_rev = torch.atanh((z_quant * 2) - 1)
+        return z_rev
+
+    vae_checkpoint_path = 'kl16.safetensors'
+    vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path=vae_checkpoint_path)
+    vae = vae.to('cuda').eval()
+    img_path = '/opt/tiger/LARP/jdb/data_schnell_step4_run1/000/0a0ef78f-11c2-4eaf-8f93-cd85e7d12b4d.jpg.jpg'
+    raw_img = Image.open(img_path).convert('RGB')
+    img = (np.array(raw_img) / 255. - 0.5) / 0.5
+    img = torch.tensor(img).cuda().float().permute(2, 0, 1).unsqueeze(0)
+
+    posterior = vae.encode(img)
+    x = posterior.sample().mul_(0.2325)  # 0.2325 is the latent scaling factor used with this VAE
+    dec = vae.decode(x / 0.2325)  # undo the scaling before decoding
+
+    dec_quants = []
+    for L in [2, 4, 16, 32, 64, 128, 256]:
+        with torch.no_grad():
+            x_quant = fsq_quantize(x, L, mode='sigmoid')
+            # x_quant = fsq_quantize(x, L, mode='tanh')
+            dec_quant = vae.decode(x_quant / 0.2325)
+            dec_quants.append(dec_quant)
+
+    out_img = torch.cat([img, dec] + dec_quants, dim=0)
+    save_image(out_img, 'rec.png', nrow=3, normalize=True, value_range=(-1, 1))
+
+    import pdb; pdb.set_trace()
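
The script above also documents the latent scaling convention for this checkpoint: sampled latents are multiplied by 0.2325 and divided by it again before decoding. A minimal round-trip sketch, assuming vae is the loaded AutoencoderKL and img is a (-1, 1)-normalized image batch on the same device:

with torch.no_grad():
    posterior = vae.encode(img)      # DiagonalGaussianDistribution over the 16-channel latent
    z = posterior.sample() * 0.2325  # scaled latent, as in the script above
    rec = vae.decode(z / 0.2325)     # undo the scaling before decoding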