artelabsuper commited on
Commit
053b94b
·
1 Parent(s): 4216279

input scale selectable

Browse files
Files changed (1) hide show
  1. app.py +8 -5
app.py CHANGED
@@ -11,6 +11,7 @@ from models.modelNetB import Generator as GB
11
  from models.modelNetC import Generator as GC
12
 
13
  scale_size = 128
 
14
  # load model
15
  modeltype2path = {
16
  'ModelA': 'DTM_exp_train10%_model_a/g-best.pth',
@@ -33,8 +34,9 @@ preprocess = transforms.Compose([
33
  transforms.ToTensor()
34
  ])
35
 
36
- def predict(input_image, model_name):
37
  pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
 
38
  # transform image to torch and do preprocessing
39
  torch_img = preprocess(pil_image).to(DEVICE).unsqueeze(0).to(DEVICE)
40
  torch_img = (torch_img - torch.min(torch_img)) / (torch.max(torch_img) - torch.min(torch_img))
@@ -61,8 +63,9 @@ def predict(input_image, model_name):
61
  iface = gr.Interface(
62
  fn=predict,
63
  inputs=[
64
- gr.Image(shape=(scale_size,scale_size)),
65
- gr.inputs.Radio(MODELS_TYPE)
 
66
  ],
67
  outputs=[
68
  gr.Text(label='Model info'),
@@ -70,9 +73,9 @@ iface = gr.Interface(
70
  gr.Image(label='DTM')
71
  ],
72
  examples=[
73
- [f"demo_imgs/{name}", MODELS_TYPE[0]] for name in os.listdir('demo_imgs')
74
  ],
75
  title="Super Resolution and DTM Estimation",
76
- description=f"This demo predict Super Resolution and (Super Resolution) DTM from a Grayscale image (if RGB we convert it, for demo reason input is scale to {scale_size}x{scale_size})."
77
  )
78
  iface.launch()
 
11
  from models.modelNetC import Generator as GC
12
 
13
  scale_size = 128
14
+ scale_sizes = [128, 256, 512]
15
  # load model
16
  modeltype2path = {
17
  'ModelA': 'DTM_exp_train10%_model_a/g-best.pth',
 
34
  transforms.ToTensor()
35
  ])
36
 
37
+ def predict(input_image, model_name, input_scale_factor):
38
  pil_image = Image.fromarray(input_image.astype('uint8'), 'RGB')
39
+ pil_image = transforms.Resize((input_scale_factor, input_scale_factor))(pil_image)
40
  # transform image to torch and do preprocessing
41
  torch_img = preprocess(pil_image).to(DEVICE).unsqueeze(0).to(DEVICE)
42
  torch_img = (torch_img - torch.min(torch_img)) / (torch.max(torch_img) - torch.min(torch_img))
 
63
  iface = gr.Interface(
64
  fn=predict,
65
  inputs=[
66
+ gr.Image(),
67
+ gr.inputs.Radio(MODELS_TYPE),
68
+ gr.inputs.Radio(scale_sizes)
69
  ],
70
  outputs=[
71
  gr.Text(label='Model info'),
 
73
  gr.Image(label='DTM')
74
  ],
75
  examples=[
76
+ [f"demo_imgs/{name}", MODELS_TYPE[0], 128] for name in os.listdir('demo_imgs')
77
  ],
78
  title="Super Resolution and DTM Estimation",
79
+ description=f"This demo predict Super Resolution and (Super Resolution) DTM from a Grayscale image (if RGB we convert it)."
80
  )
81
  iface.launch()