vobecant committed
Commit 2762176 · 1 Parent(s): 435cc18

Initial commit.

Files changed (1):
  1. app.py +8 -5
app.py CHANGED
@@ -14,9 +14,10 @@ from segmenter_model.utils import colorize_one, map2cs
 # WEIGHTS = './weights/segmenter.pth'
 WEIGHTS = './weights/segmenter_nusc.pth'
 FULL = True
+ALPHA = 0.5
 
 
-def blend_images(bg, fg, alpha=0.3):
+def blend_images(bg, fg, alpha=ALPHA):
     fg = fg.convert('RGBA')
     bg = bg.convert('RGBA')
     blended = Image.blend(bg, fg, alpha=alpha)
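In this hunk, the default overlay opacity moves from a hard-coded 0.3 to a module-level ALPHA = 0.5, so the segmentation layer shows up more strongly over the photo. A minimal standalone sketch of the blending step; the usage at the bottom is hypothetical and 'seg.png' is a placeholder name:

from PIL import Image

ALPHA = 0.5  # foreground weight: 0.0 keeps only bg, 1.0 keeps only fg

def blend_images(bg, fg, alpha=ALPHA):
    # Image.blend needs both images in the same mode and size,
    # hence the RGBA conversion before blending.
    fg = fg.convert('RGBA')
    bg = bg.convert('RGBA')
    # Per-pixel result: bg * (1 - alpha) + fg * alpha
    return Image.blend(bg, fg, alpha=alpha)

# Hypothetical usage: overlay a colorized segmentation map on the input photo.
photo = Image.open('examples/img5.jpeg')
seg = Image.open('seg.png').resize(photo.size)  # placeholder segmentation image
blend_images(photo, seg).convert('RGB').save('overlay.jpeg')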
@@ -135,10 +136,12 @@ download_weights()
 model, window_size, window_stride, im_size = create_model()
 
 
-def get_transformations():
+def get_transformations(input_img):
     trans_list = [transforms.ToTensor()]
 
-    if im_size != 1024:
+    shorter_input_size = min(input_img.size)
+
+    if im_size != 1024 or shorter_input_size < im_size:
         trans_list.append(transforms.Resize(im_size))
 
     trans_list.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
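This hunk makes the resizing step input-aware: PIL's Image.size is (width, height), so min(input_img.size) is the shorter edge, and torchvision's Resize with an integer argument scales exactly that edge to im_size while preserving the aspect ratio. A small sketch of that behaviour, assuming torchvision >= 0.8 (where Resize also accepts tensors) and the 1024 that app.py compares against:

import torch
from torchvision import transforms

im_size = 1024  # the value the condition in app.py checks against

# Dummy 3-channel image, 300 px tall and 500 px wide: shorter edge 300 < 1024.
img = torch.rand(3, 300, 500)

resize = transforms.Resize(im_size)  # int argument: scale the shorter edge
out = resize(img)

print(out.shape)  # torch.Size([3, 1024, 1706]): aspect ratio preserved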
@@ -148,7 +151,7 @@ def get_transformations():
 
 def predict(input_img, cs_mapping):
     input_img_pil = Image.open(input_img)
-    transform = get_transformations()
+    transform = get_transformations(input_img_pil)
     input_img = transform(input_img_pil)
     input_img = torch.unsqueeze(input_img, 0)
 
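Note that at this call site input_img is still the file path (the Gradio input uses type='filepath'), so the opened PIL image input_img_pil is what get_transformations must receive for the min(input_img.size) check to work. The surrounding context lines show the usual single-image inference pattern: transform to a (C, H, W) tensor, then unsqueeze a leading batch axis before the tensor reaches the model. Illustratively:

import torch

chw = torch.rand(3, 1024, 1706)   # transformed image, (C, H, W)
batch = torch.unsqueeze(chw, 0)   # (1, C, H, W): the model expects a batch axis
print(batch.shape)                # torch.Size([1, 3, 1024, 1706])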
@@ -186,7 +189,7 @@ examples = [['examples/img5.jpeg', True], ['examples/100.jpeg', True], ['example
 
 iface = gr.Interface(predict, [gr.inputs.Image(type='filepath'), gr.inputs.Checkbox(label="Cityscapes mapping")],
                      "image", title=title, description=description,
-                     examples=examples)
+                     examples=examples, allow_screenshot=True)
 # iface = gr.Interface(predict, gr.inputs.Image(type='filepath'),
 #                      "image", title=title, description=description,
 #                      examples=examples)
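Finally, allow_screenshot belongs to the legacy Gradio 2.x Interface API, as do the gr.inputs.* components used here; in Gradio 3+ these were removed or renamed (gr.inputs.Image became gr.Image). A minimal sketch of serving an interface like this one, assuming an old gradio install and a hypothetical stand-in for the real predict:

import gradio as gr

def predict(input_img, cs_mapping):
    # Hypothetical stand-in for the real predict() in app.py:
    # simply echo the uploaded image back.
    return input_img

iface = gr.Interface(predict,
                     [gr.inputs.Image(type='filepath'),
                      gr.inputs.Checkbox(label="Cityscapes mapping")],
                     "image",
                     allow_screenshot=True)  # legacy Gradio 2.x flag

if __name__ == '__main__':
    iface.launch()  # serves the demo, by default on http://127.0.0.1:7860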