msong97 commited on
Commit
499f595
·
1 Parent(s): 6361f4a

Define global variables across all users -> reduce loading time

Browse files
Files changed (1) hide show
  1. app.py +50 -55
app.py CHANGED
@@ -16,50 +16,52 @@ from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDatase
16
 
17
 
18
  ### Config
19
- DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu' # run model inference on NVIDIA gpu
20
- torch.set_grad_enabled(False) # stops tracking values for gradients
 
 
21
 
22
 
23
  ### Gradio Utils
24
  def generate_imgs_from_user(image,
25
- model: EvalModel, baseline: BaselineModel,
26
  physics: PhysicsWithGenerator, use_gen: bool,
 
27
  metrics: List[Metric]):
 
28
  if image is None:
29
  return None, None, None, None, None, None, None, None
30
 
31
- # PIL image -> torch.Tensor
32
  x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
33
 
34
- return generate_imgs(x, model, baseline, physics, use_gen, metrics)
35
 
36
  def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
37
- model: EvalModel, baseline: BaselineModel,
38
  physics: PhysicsWithGenerator, use_gen: bool,
 
39
  metrics: List[Metric]):
40
  ### Load 1 image
41
  x = dataset[idx] # shape : (C, H, W)
42
  x = x.unsqueeze(0) # shape : (1, C, H, W)
43
 
44
- return generate_imgs(x, model, baseline, physics, use_gen, metrics)
45
 
46
  def generate_random_imgs_from_dataset(dataset: EvalDataset,
47
- model: EvalModel,
48
- baseline: BaselineModel,
49
  physics: PhysicsWithGenerator,
50
  use_gen: bool,
 
 
51
  metrics: List[Metric]):
52
  idx = random.randint(0, len(dataset)-1)
53
  x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
54
- dataset, idx, model, baseline, physics, use_gen, metrics
55
  )
56
  return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
57
 
58
  def generate_imgs(x: torch.Tensor,
59
- model: EvalModel, baseline: BaselineModel,
60
  physics: PhysicsWithGenerator, use_gen: bool,
 
61
  metrics: List[Metric]):
62
-
63
  ### Compute y
64
  y = physics(x, use_gen) # possible reduction in img shape due to Blurring
65
 
@@ -114,11 +116,9 @@ def generate_imgs(x: torch.Tensor,
114
  return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
115
 
116
 
117
- get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
118
- get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
119
- get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
120
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
121
  get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
 
122
 
123
  def get_dataset(dataset_name):
124
  if dataset_name == 'MRI':
@@ -142,37 +142,42 @@ def get_dataset(dataset_name):
142
  return dataset, idx, physics, baseline, available_physics
143
 
144
 
 
 
 
 
 
 
 
 
 
145
  ### Gradio Blocks interface
146
 
147
  title = "Inverse problem playground" # displayed on gradio tab and in the gradio page
148
  with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
149
  gr.Markdown("## " + title)
150
 
151
- ### DEFAULT VALUES
152
- # Issue: giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
153
- # Solution: using lambda expression
154
- model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))
155
- model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
156
- metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
157
-
158
  dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
159
- physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
160
  available_physics_placeholder = gr.State(['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
161
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
 
 
 
 
162
 
163
 
164
- ### LAYOUT
165
- # Display images
166
- with gr.Row():
167
- gt_img = gr.Image(label="Ground-truth image", interactive=True)
168
- observed_img = gr.Image(label="Observed image", interactive=False)
169
- model_a_out = gr.Image(label="RAM output", interactive=False)
170
- model_b_out = gr.Image(label="DPIR output", interactive=False)
171
-
172
  @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
173
  def dynamic_layout(dataset, physics, available_physics):
174
  ### LAYOUT
175
 
 
 
 
 
 
 
 
176
  # Manage datasets and display metric values
177
  with gr.Row():
178
  with gr.Column(scale=1, min_width=160):
@@ -180,19 +185,16 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
180
  choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
181
  label="Datasets",
182
  value=dataset.name)
183
- idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=0)
184
  with gr.Row():
185
  load_button = gr.Button("Run on index image from dataset", size='md')
186
  load_random_button = gr.Button("Run on random image from dataset", size='md')
187
  with gr.Column(scale=1, min_width=160):
188
- observed_metrics = gr.Textbox(label="PSNR Observed",
189
- lines=1)
190
  with gr.Column(scale=1, min_width=160):
191
- out_a_metric = gr.Textbox(label="PSNR RAM output",
192
- lines=1)
193
  with gr.Column(scale=1, min_width=160):
194
- out_b_metric = gr.Textbox(label="PSNR DPIR",
195
- lines=1)
196
 
197
  # Manage physics
198
  with gr.Row():
@@ -200,19 +202,18 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
200
  choose_physics = gr.Radio(choices=available_physics,
201
  label="Physics",
202
  value=physics.name)
203
- use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True)
204
  with gr.Column(scale=1):
205
  with gr.Row():
206
  key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
207
- label="Updatable Parameter Key")
208
  value_text = gr.Textbox(label="Update Value")
209
- update_button = gr.Button("Manually update parameter value")
210
  with gr.Column(scale=2):
211
  physics_params = gr.Textbox(label="Physics parameters",
212
  lines=5,
213
  value=physics.display_saved_params())
214
 
215
-
216
  ### Event listeners
217
 
218
  choose_dataset.change(fn=get_dataset,
@@ -223,34 +224,28 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
223
  outputs=[physics_placeholder])
224
  update_button.click(fn=physics.update_and_display_params,
225
  inputs=[key_selector, value_text], outputs=physics_params)
226
- run_button.click(fn=generate_imgs_from_user,
227
  inputs=[gt_img,
228
- model_a_placeholder,
229
- model_b_placeholder,
230
  physics_placeholder,
231
  use_generator_button,
232
- metrics_placeholder],
233
  outputs=[gt_img, observed_img, model_a_out, model_b_out,
234
  physics_params, observed_metrics, out_a_metric, out_b_metric])
235
- load_button.click(fn=generate_imgs_from_dataset,
236
  inputs=[dataset_placeholder,
237
  idx_slider,
238
- model_a_placeholder,
239
- model_b_placeholder,
240
  physics_placeholder,
241
  use_generator_button,
242
- metrics_placeholder],
243
  outputs=[gt_img, observed_img, model_a_out, model_b_out,
244
  physics_params, observed_metrics, out_a_metric, out_b_metric])
245
- load_random_button.click(fn=generate_random_imgs_from_dataset,
246
  inputs=[dataset_placeholder,
247
- model_a_placeholder,
248
- model_b_placeholder,
249
  physics_placeholder,
250
  use_generator_button,
251
- metrics_placeholder],
252
  outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
253
  physics_params, observed_metrics, out_a_metric, out_b_metric])
254
 
255
 
256
- interface.launch()
 
16
 
17
 
18
  ### Config
19
+ # run model inference on NVIDIA gpu if available
20
+ DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'
21
+ # stops tracking values for gradients
22
+ torch.set_grad_enabled(False)
23
 
24
 
25
  ### Gradio Utils
26
  def generate_imgs_from_user(image,
 
27
  physics: PhysicsWithGenerator, use_gen: bool,
28
+ baseline: BaselineModel, model: EvalModel,
29
  metrics: List[Metric]):
30
+ # Happens when user image is missing
31
  if image is None:
32
  return None, None, None, None, None, None, None, None
33
 
34
+ # PIL image -> torch.Tensor / (1, C, H, W) / move to DEVICE_STR
35
  x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
36
 
37
+ return generate_imgs(x, physics, use_gen, baseline, model, metrics)
38
 
39
  def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
 
40
  physics: PhysicsWithGenerator, use_gen: bool,
41
+ baseline: BaselineModel, model: EvalModel,
42
  metrics: List[Metric]):
43
  ### Load 1 image
44
  x = dataset[idx] # shape : (C, H, W)
45
  x = x.unsqueeze(0) # shape : (1, C, H, W)
46
 
47
+ return generate_imgs(x, physics, use_gen, baseline, model, metrics)
48
 
49
  def generate_random_imgs_from_dataset(dataset: EvalDataset,
 
 
50
  physics: PhysicsWithGenerator,
51
  use_gen: bool,
52
+ baseline: BaselineModel,
53
+ model: EvalModel,
54
  metrics: List[Metric]):
55
  idx = random.randint(0, len(dataset)-1)
56
  x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
57
+ dataset, idx, physics, use_gen, baseline, model, metrics
58
  )
59
  return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
60
 
61
  def generate_imgs(x: torch.Tensor,
 
62
  physics: PhysicsWithGenerator, use_gen: bool,
63
+ baseline: BaselineModel, model: EvalModel,
64
  metrics: List[Metric]):
 
65
  ### Compute y
66
  y = physics(x, use_gen) # possible reduction in img shape due to Blurring
67
 
 
116
  return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
117
 
118
 
 
 
 
119
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
120
  get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
121
+ get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
122
 
123
  def get_dataset(dataset_name):
124
  if dataset_name == 'MRI':
 
142
  return dataset, idx, physics, baseline, available_physics
143
 
144
 
145
+ # global variables shared by all users
146
+ ram_model = EvalModel("unext_emb_physics_config_C", device_str=DEVICE_STR)
147
+ psnr = Metric.get_list_metrics(["PSNR"], device_str=DEVICE_STR)
148
+
149
+ generate_imgs_from_user_partial = partial(generate_imgs_from_user, model=ram_model, metrics=psnr)
150
+ generate_imgs_from_dataset_partial = partial(generate_imgs_from_dataset, model=ram_model, metrics=psnr)
151
+ generate_random_imgs_from_dataset_partial = partial(generate_random_imgs_from_dataset, model=ram_model, metrics=psnr)
152
+
153
+
154
  ### Gradio Blocks interface
155
 
156
  title = "Inverse problem playground" # displayed on gradio tab and in the gradio page
157
  with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
158
  gr.Markdown("## " + title)
159
 
160
+ ### USER-SPECIFIC VARIABLES
 
 
 
 
 
 
161
  dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
 
162
  available_physics_placeholder = gr.State(['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
163
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
164
+ # Issue giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
165
+ # Solution: using lambda expression
166
+ physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
167
+ model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
168
 
169
 
 
 
 
 
 
 
 
 
170
  @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
171
  def dynamic_layout(dataset, physics, available_physics):
172
  ### LAYOUT
173
 
174
+ # Display images
175
+ with gr.Row():
176
+ gt_img = gr.Image(label="Ground-truth image", interactive=True, key=0)
177
+ observed_img = gr.Image(label="Observed image", interactive=False, key=1)
178
+ model_a_out = gr.Image(label="RAM output", interactive=False, key=2)
179
+ model_b_out = gr.Image(label="DPIR output", interactive=False, key=3)
180
+
181
  # Manage datasets and display metric values
182
  with gr.Row():
183
  with gr.Column(scale=1, min_width=160):
 
185
  choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
186
  label="Datasets",
187
  value=dataset.name)
188
+ idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=4)
189
  with gr.Row():
190
  load_button = gr.Button("Run on index image from dataset", size='md')
191
  load_random_button = gr.Button("Run on random image from dataset", size='md')
192
  with gr.Column(scale=1, min_width=160):
193
+ observed_metrics = gr.Textbox(label="Observed metric", lines=1, key=5)
 
194
  with gr.Column(scale=1, min_width=160):
195
+ out_a_metric = gr.Textbox(label="RAM output metrics", lines=1, key=6)
 
196
  with gr.Column(scale=1, min_width=160):
197
+ out_b_metric = gr.Textbox(label="DPIR output metrics", lines=1, key=7)
 
198
 
199
  # Manage physics
200
  with gr.Row():
 
202
  choose_physics = gr.Radio(choices=available_physics,
203
  label="Physics",
204
  value=physics.name)
205
+ use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key=8)
206
  with gr.Column(scale=1):
207
  with gr.Row():
208
  key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
209
+ label="Updatable Key")
210
  value_text = gr.Textbox(label="Update Value")
211
+ update_button = gr.Button("Manually update parameter value", size='md')
212
  with gr.Column(scale=2):
213
  physics_params = gr.Textbox(label="Physics parameters",
214
  lines=5,
215
  value=physics.display_saved_params())
216
 
 
217
  ### Event listeners
218
 
219
  choose_dataset.change(fn=get_dataset,
 
224
  outputs=[physics_placeholder])
225
  update_button.click(fn=physics.update_and_display_params,
226
  inputs=[key_selector, value_text], outputs=physics_params)
227
+ run_button.click(fn=generate_imgs_from_user_partial,
228
  inputs=[gt_img,
 
 
229
  physics_placeholder,
230
  use_generator_button,
231
+ model_b_placeholder],
232
  outputs=[gt_img, observed_img, model_a_out, model_b_out,
233
  physics_params, observed_metrics, out_a_metric, out_b_metric])
234
+ load_button.click(fn=generate_imgs_from_dataset_partial,
235
  inputs=[dataset_placeholder,
236
  idx_slider,
 
 
237
  physics_placeholder,
238
  use_generator_button,
239
+ model_b_placeholder],
240
  outputs=[gt_img, observed_img, model_a_out, model_b_out,
241
  physics_params, observed_metrics, out_a_metric, out_b_metric])
242
+ load_random_button.click(fn=generate_random_imgs_from_dataset_partial,
243
  inputs=[dataset_placeholder,
 
 
244
  physics_placeholder,
245
  use_generator_button,
246
+ model_b_placeholder],
247
  outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
248
  physics_params, observed_metrics, out_a_metric, out_b_metric])
249
 
250
 
251
+ interface.launch()