HaoFeng2019 commited on
Commit
1b147a3
·
1 Parent(s): 980de13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -66,8 +66,6 @@ class GeoTr_Seg(nn.Module):
66
  return bm
67
 
68
 
69
-
70
-
71
  # Initialize models
72
  GeoTr_Seg_model = GeoTr_Seg()
73
  #IllTr_model = IllTr()
@@ -83,27 +81,27 @@ GeoTr_Seg_model = torch.compile(GeoTr_Seg_model)
83
 
84
  def process_image(input_image):
85
  GeoTr_Seg_model.eval()
86
- #IllTr_model.eval()
87
 
88
  im_ori = np.array(input_image)[:, :, :3] / 255.
89
  h, w, _ = im_ori.shape
90
- im = cv2.resize(im_ori, (288, 288))
 
91
  im = im.transpose(2, 0, 1)
92
  im = torch.from_numpy(im).float().unsqueeze(0)
93
 
94
  with torch.no_grad():
95
  bm = GeoTr_Seg_model(im)
96
  bm = bm.cpu()
97
- bm0 = cv2.resize(bm[0, 0].numpy(), (w, h))
98
- bm1 = cv2.resize(bm[0, 1].numpy(), (w, h))
99
  bm0 = cv2.blur(bm0, (3, 3))
100
  bm1 = cv2.blur(bm1, (3, 3))
101
  lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
102
 
103
  out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
104
  img_geo = ((out[0] * 255).permute(1, 2, 0).numpy()).astype(np.uint8)
105
-
106
- ill_rec=False
107
 
108
  if ill_rec:
109
  img_ill = rec_ill(IllTr_model, img_geo)
@@ -111,6 +109,7 @@ def process_image(input_image):
111
  else:
112
  return Image.fromarray(img_geo)
113
 
 
114
  # Define Gradio interface
115
  input_image = gr.inputs.Image()
116
  output_image = gr.outputs.Image(type='pil')
 
66
  return bm
67
 
68
 
 
 
69
  # Initialize models
70
  GeoTr_Seg_model = GeoTr_Seg()
71
  #IllTr_model = IllTr()
 
81
 
82
  def process_image(input_image):
83
  GeoTr_Seg_model.eval()
 
84
 
85
  im_ori = np.array(input_image)[:, :, :3] / 255.
86
  h, w, _ = im_ori.shape
87
+ new_height = int(h * (288 / w))
88
+ im = cv2.resize(im_ori, (288, new_height))
89
  im = im.transpose(2, 0, 1)
90
  im = torch.from_numpy(im).float().unsqueeze(0)
91
 
92
  with torch.no_grad():
93
  bm = GeoTr_Seg_model(im)
94
  bm = bm.cpu()
95
+ bm0 = cv2.resize(bm[0, 0].numpy(), (288, new_height))
96
+ bm1 = cv2.resize(bm[0, 1].numpy(), (288, new_height))
97
  bm0 = cv2.blur(bm0, (3, 3))
98
  bm1 = cv2.blur(bm1, (3, 3))
99
  lbl = torch.from_numpy(np.stack([bm0, bm1], axis=2)).unsqueeze(0)
100
 
101
  out = F.grid_sample(torch.from_numpy(im_ori).permute(2, 0, 1).unsqueeze(0).float(), lbl, align_corners=True)
102
  img_geo = ((out[0] * 255).permute(1, 2, 0).numpy()).astype(np.uint8)
103
+
104
+ ill_rec = False
105
 
106
  if ill_rec:
107
  img_ill = rec_ill(IllTr_model, img_geo)
 
109
  else:
110
  return Image.fromarray(img_geo)
111
 
112
+
113
  # Define Gradio interface
114
  input_image = gr.inputs.Image()
115
  output_image = gr.outputs.Image(type='pil')