javierabad01 commited on
Commit
9e2ef13
·
verified ·
1 Parent(s): 53d646f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -0
app.py CHANGED
@@ -3,6 +3,7 @@ from PIL import Image
3
  import torch
4
  import torchvision.transforms as transforms
5
  import numpy as np
 
6
 
7
  from archs.model import UNet
8
 
@@ -28,14 +29,26 @@ def load_img (filename):
28
  img_tensor = pil_to_tensor(img)
29
  return img_tensor
30
 
 
 
 
 
 
 
 
 
 
31
  def process_img(image):
32
  img = np.array(image)
33
  img = img / 255.
34
  img = img.astype(np.float32)
35
  y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
 
 
36
 
37
  with torch.no_grad():
38
  x_hat = model(y)
 
39
 
40
  restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
41
  restored_img = np.clip(restored_img, 0. , 1.)
 
3
  import torch
4
  import torchvision.transforms as transforms
5
  import numpy as np
6
+ import torch.nn.functional as F
7
 
8
  from archs.model import UNet
9
 
 
29
  img_tensor = pil_to_tensor(img)
30
  return img_tensor
31
 
32
+
33
+ def check_image_size(x):
34
+ _, _, h, w = x.size()
35
+ mod_pad_h = (32 - h % 32) % 32
36
+ mod_pad_w = (32 - w % 32) % 32
37
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), value = 0)
38
+ return x
39
+
40
+
41
  def process_img(image):
42
  img = np.array(image)
43
  img = img / 255.
44
  img = img.astype(np.float32)
45
  y = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
46
+ _, _, H, W = y.shape
47
+ y= check_image_size(y)
48
 
49
  with torch.no_grad():
50
  x_hat = model(y)
51
+ x_hat = x_hat[:, :, :H, :W]
52
 
53
  restored_img = x_hat.squeeze().permute(1,2,0).clamp_(0, 1).cpu().detach().numpy()
54
  restored_img = np.clip(restored_img, 0. , 1.)