msong97 commited on
Commit
92c07f0
·
1 Parent(s): 1501900

[Feat] Apply methods on user image

Browse files
Files changed (1) hide show
  1. app.py +39 -18
app.py CHANGED
@@ -18,14 +18,31 @@ DEVICE_STR = 'cuda'
18
 
19
 
20
  ### Gradio Utils
21
- def generate_imgs(dataset: EvalDataset, idx: int,
22
- model: EvalModel, baseline: BaselineModel,
23
- physics: PhysicsWithGenerator, use_gen: bool,
24
- metrics: List[Metric]):
 
25
  ### Load 1 image
26
  x = dataset[idx] # shape : (3, 256, 256)
27
  x = x.unsqueeze(0) # shape : (1, 3, 256, 256)
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  with torch.no_grad():
30
  ### Compute y
31
  y = physics(x, use_gen) # possible reduction in img shape due to Blurring
@@ -71,23 +88,18 @@ def generate_imgs(dataset: EvalDataset, idx: int,
71
  out = to_pil(process_img(out)[0].to('cpu'))
72
  out_baseline = to_pil(process_img(out_baseline)[0].to('cpu'))
73
 
74
-
75
  return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
76
 
77
- def update_random_idx_and_generate_imgs(dataset: EvalDataset,
78
  model: EvalModel,
79
  baseline: BaselineModel,
80
  physics: PhysicsWithGenerator,
81
  use_gen: bool,
82
  metrics: List[Metric]):
83
  idx = random.randint(0, len(dataset)-1)
84
- x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(dataset,
85
- idx,
86
- model,
87
- baseline,
88
- physics,
89
- use_gen,
90
- metrics)
91
  return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
92
 
93
 
@@ -151,7 +163,7 @@ with gr.Blocks(title=title, css=custom_css) as interface:
151
  with gr.Column():
152
  with gr.Row():
153
  with gr.Column():
154
- clean = gr.Image(label=f"{dataset_name} IMAGE", interactive=False)
155
  physics_params = gr.Textbox(label="Physics parameters", elem_classes=["fixed-textbox"], value=physics.display_saved_params())
156
  with gr.Column():
157
  y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
@@ -189,9 +201,10 @@ with gr.Blocks(title=title, css=custom_css) as interface:
189
  value=metric_names,
190
  label="Choose metrics you are interested")
191
  use_generator_button = gr.Checkbox(label="Generate valid physics parameters", scale=1)
 
192
  with gr.Column(scale=1):
193
- load_button = gr.Button("Load images...")
194
- load_random_button = gr.Button("Load randomly...")
195
 
196
  ### Event listeners
197
  choose_dataset.change(fn=get_dataset,
@@ -204,7 +217,15 @@ with gr.Blocks(title=title, css=custom_css) as interface:
204
  choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
205
  inputs=choose_metrics,
206
  outputs=metrics_placeholder)
207
- load_button.click(fn=generate_imgs,
 
 
 
 
 
 
 
 
208
  inputs=[dataset_placeholder,
209
  idx_slider,
210
  model_a_placeholder,
@@ -213,7 +234,7 @@ with gr.Blocks(title=title, css=custom_css) as interface:
213
  use_generator_button,
214
  metrics_placeholder],
215
  outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
216
- load_random_button.click(fn=update_random_idx_and_generate_imgs,
217
  inputs=[dataset_placeholder,
218
  model_a_placeholder,
219
  model_b_placeholder,
 
18
 
19
 
20
  ### Gradio Utils
21
+
22
+ def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
23
+ model: EvalModel, baseline: BaselineModel,
24
+ physics: PhysicsWithGenerator, use_gen: bool,
25
+ metrics: List[Metric]):
26
  ### Load 1 image
27
  x = dataset[idx] # shape : (3, 256, 256)
28
  x = x.unsqueeze(0) # shape : (1, 3, 256, 256)
29
 
30
+ return generate_imgs(x, model, baseline, physics, use_gen, metrics)
31
+
32
+ def generate_imgs_from_user(image,
33
+ model: EvalModel, baseline: BaselineModel,
34
+ physics: PhysicsWithGenerator, use_gen: bool,
35
+ metrics: List[Metric]):
36
+ # PIL image -> torch.Tensor
37
+ x = transforms.ToTensor()(image).unsqueeze(0)
38
+
39
+ return generate_imgs(x, model, baseline, physics, use_gen, metrics)
40
+
41
+ def generate_imgs(x: torch.Tensor,
42
+ model: EvalModel, baseline: BaselineModel,
43
+ physics: PhysicsWithGenerator, use_gen: bool,
44
+ metrics: List[Metric]):
45
+
46
  with torch.no_grad():
47
  ### Compute y
48
  y = physics(x, use_gen) # possible reduction in img shape due to Blurring
 
88
  out = to_pil(process_img(out)[0].to('cpu'))
89
  out_baseline = to_pil(process_img(out_baseline)[0].to('cpu'))
90
 
 
91
  return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
92
 
93
+ def generate_random_imgs_from_dataset(dataset: EvalDataset,
94
  model: EvalModel,
95
  baseline: BaselineModel,
96
  physics: PhysicsWithGenerator,
97
  use_gen: bool,
98
  metrics: List[Metric]):
99
  idx = random.randint(0, len(dataset)-1)
100
+ x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
101
+ dataset, idx, model, baseline, physics, use_gen, metrics
102
+ )
 
 
 
 
103
  return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
104
 
105
 
 
163
  with gr.Column():
164
  with gr.Row():
165
  with gr.Column():
166
+ clean = gr.Image(label=f"{dataset_name} IMAGE", interactive=True)
167
  physics_params = gr.Textbox(label="Physics parameters", elem_classes=["fixed-textbox"], value=physics.display_saved_params())
168
  with gr.Column():
169
  y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
 
201
  value=metric_names,
202
  label="Choose metrics you are interested")
203
  use_generator_button = gr.Checkbox(label="Generate valid physics parameters", scale=1)
204
+ run_button = gr.Button("Run current image")
205
  with gr.Column(scale=1):
206
+ load_button = gr.Button("Load images from dataset...")
207
+ load_random_button = gr.Button("Load randomly from dataset...")
208
 
209
  ### Event listeners
210
  choose_dataset.change(fn=get_dataset,
 
217
  choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
218
  inputs=choose_metrics,
219
  outputs=metrics_placeholder)
220
+ run_button.click(fn=generate_imgs_from_user,
221
+ inputs=[clean,
222
+ model_a_placeholder,
223
+ model_b_placeholder,
224
+ physics_placeholder,
225
+ use_generator_button,
226
+ metrics_placeholder],
227
+ outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
228
+ load_button.click(fn=generate_imgs_from_dataset,
229
  inputs=[dataset_placeholder,
230
  idx_slider,
231
  model_a_placeholder,
 
234
  use_generator_button,
235
  metrics_placeholder],
236
  outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
237
+ load_random_button.click(fn=generate_random_imgs_from_dataset,
238
  inputs=[dataset_placeholder,
239
  model_a_placeholder,
240
  model_b_placeholder,