msong97 commited on
Commit
5c8c5d6
·
1 Parent(s): e35108e

Load automatically the right baseline and default physics when choosing a dataset

Browse files
Files changed (1) hide show
  1. app.py +11 -13
app.py CHANGED
@@ -112,25 +112,23 @@ get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
112
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
113
  get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
114
 
115
- def get_physics(physics_name):
116
- if physics_name == 'MRI':
117
- baseline = get_baseline_model_on_DEVICE_STR('DPIR_MRI')
118
- elif physics_name == 'CT':
119
- baseline = get_baseline_model_on_DEVICE_STR('DPIR_CT')
120
- else:
121
- baseline = get_baseline_model_on_DEVICE_STR('DPIR')
122
- return get_physics_on_DEVICE_STR(physics_name), baseline
123
-
124
  AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
125
  def get_dataset(dataset_name):
126
  global AVAILABLE_PHYSICS
127
  if dataset_name == 'MRI':
128
  AVAILABLE_PHYSICS = ['MRI']
 
 
129
  elif dataset_name == 'CT':
130
  AVAILABLE_PHYSICS = ['CT']
 
 
131
  else:
132
  AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
133
- return get_dataset_on_DEVICE_STR(dataset_name)
 
 
 
134
 
135
  ### Gradio Blocks interface
136
 
@@ -212,10 +210,10 @@ with gr.Blocks(title=title, css=custom_css) as interface:
212
  ### Event listeners
213
  choose_dataset.change(fn=get_dataset,
214
  inputs=choose_dataset,
215
- outputs=dataset_placeholder)
216
- choose_physics.change(fn=get_physics,
217
  inputs=choose_physics,
218
- outputs=[physics_placeholder, model_b_placeholder])
219
  update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
220
  choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
221
  inputs=choose_metrics,
 
112
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
113
  get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
114
 
 
 
 
 
 
 
 
 
 
115
  AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
116
  def get_dataset(dataset_name):
117
  global AVAILABLE_PHYSICS
118
  if dataset_name == 'MRI':
119
  AVAILABLE_PHYSICS = ['MRI']
120
+ baseline_name = 'DPIR_MRI'
121
+ physics_name = 'MRI'
122
  elif dataset_name == 'CT':
123
  AVAILABLE_PHYSICS = ['CT']
124
+ baseline_name = 'DPIR_CT'
125
+ physics_name = 'CT'
126
  else:
127
  AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
128
+ baseline_name = 'DPIR'
129
+ physics_name = 'MotionBlur_easy'
130
+ return get_dataset_on_DEVICE_STR(dataset_name), get_physics_on_DEVICE_STR(physics_name), get_baseline_model_on_DEVICE_STR(baseline_name)
131
+
132
 
133
  ### Gradio Blocks interface
134
 
 
210
  ### Event listeners
211
  choose_dataset.change(fn=get_dataset,
212
  inputs=choose_dataset,
213
+ outputs=[dataset_placeholder, physics_placeholder, model_b_placeholder])
214
+ choose_physics.change(fn=get_physics_on_DEVICE_STR,
215
  inputs=choose_physics,
216
+ outputs=[physics_placeholder])
217
  update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
218
  choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
219
  inputs=choose_metrics,