msong97 commited on
Commit
1501900
·
1 Parent(s): 07bbfe7

Fix syntax error and add dynamic change of available physics depending on dataset

Browse files
Files changed (1) hide show
  1. app.py +5 -45
app.py CHANGED
@@ -90,39 +90,6 @@ def update_random_idx_and_generate_imgs(dataset: EvalDataset,
90
  metrics)
91
  return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
92
 
93
- def save_imgs(dataset: EvalDataset, idx: int, physics: PhysicsWithGenerator,
94
- model_a: EvalModel | BaselineModel, model_b: EvalModel | BaselineModel,
95
- x: Image.Image, y: Image.Image,
96
- out_a: Image.Image, out_b: Image.Image,
97
- y_metrics_str: str,
98
- out_a_metric_str : str, out_b_metric_str: str) -> None:
99
-
100
- ### PROCESSES STR
101
- physics_params_str = ""
102
- for param_name, param_value in physics.saved_params["updatable_params"].items():
103
- physics_params_str += f"{param_name}_{param_value}-"
104
- physics_params_str = physics_params_str[:-1] if physics_params_str.endswith("-") else physics_params_str
105
- y_metrics_str = y_metrics_str.replace(" = ", "_").replace("\n", "-")
106
- y_metrics_str = y_metrics_str[:-1] if y_metrics_str.endswith("-") else y_metrics_str
107
- out_a_metric_str = out_a_metric_str.replace(" = ", "_").replace("\n", "-")
108
- out_a_metric_str = out_a_metric_str[:-1] if out_a_metric_str.endswith("-") else out_a_metric_str
109
- out_b_metric_str = out_b_metric_str.replace(" = ", "_").replace("\n", "-")
110
- out_b_metric_str = out_b_metric_str[:-1] if out_b_metric_str.endswith("-") else out_b_metric_str
111
-
112
- save_path = SAVE_IMG_DIR / f"{dataset.name}+{idx}+{physics.name}+{physics_params_str}+{y_metrics_str}+{model_a.name}+{out_a_metric_str}+{model_b.name}+{out_b_metric_str}.png"
113
- titles = [f"{dataset.name}[{idx}]",
114
- f"y = {physics.name}(x)",
115
- f"{model_a.name}",
116
- f"{model_b.name}"]
117
-
118
- # Pil object -> torch.Tensor
119
- to_tensor = transforms.ToTensor()
120
- x = to_tensor(x)
121
- y = to_tensor(y)
122
- out_a = to_tensor(out_a)
123
- out_b = to_tensor(out_b)
124
-
125
- dinv.utils.plot([x, y, out_a, out_b], titles=titles, show=False, save_fn=save_path)
126
 
127
  get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
128
  get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
@@ -139,19 +106,12 @@ def get_physics(physics_name):
139
  baseline = get_baseline_model_on_DEVICE_STR('DPIR')
140
  return get_physics_on_DEVICE_STR(physics_name), baseline
141
 
142
- def get_model(model_name, ckpt_pth):
143
- if model_name in BaselineModel.all_baselines:
144
- return get_baseline_model_on_DEVICE_STR(model_name)
145
- else:
146
- return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth)
147
-
148
  AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
149
-
150
  def get_dataset(dataset_name):
151
  global AVAILABLE_PHYSICS
152
- if dataset_name = 'MRI':
153
  AVAILABLE_PHYSICS = ['MRI']
154
- elif dataset_name = 'CT':
155
  AVAILABLE_PHYSICS = ['CT']
156
  else:
157
  AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
@@ -197,7 +157,7 @@ with gr.Blocks(title=title, css=custom_css) as interface:
197
  y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
198
  y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"],)
199
 
200
- choose_physics = gr.Radio(choices=PhysicsWithGenerator.all_physics,
201
  label="List of PhysicsWithGenerator",
202
  value=physics_name)
203
  with gr.Row():
@@ -231,10 +191,10 @@ with gr.Blocks(title=title, css=custom_css) as interface:
231
  use_generator_button = gr.Checkbox(label="Generate valid physics parameters", scale=1)
232
  with gr.Column(scale=1):
233
  load_button = gr.Button("Load images...")
234
- load_random_button = gr.Button("Load randomly...")
235
 
236
  ### Event listeners
237
- choose_dataset.change(fn=get_dataset_on_DEVICE_STR,
238
  inputs=choose_dataset,
239
  outputs=dataset_placeholder)
240
  choose_physics.change(fn=get_physics,
 
90
  metrics)
91
  return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
95
  get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
 
106
  baseline = get_baseline_model_on_DEVICE_STR('DPIR')
107
  return get_physics_on_DEVICE_STR(physics_name), baseline
108
 
 
 
 
 
 
 
109
  AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
 
110
  def get_dataset(dataset_name):
111
  global AVAILABLE_PHYSICS
112
+ if dataset_name == 'MRI':
113
  AVAILABLE_PHYSICS = ['MRI']
114
+ elif dataset_name == 'CT':
115
  AVAILABLE_PHYSICS = ['CT']
116
  else:
117
  AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
 
157
  y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
158
  y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"],)
159
 
160
+ choose_physics = gr.Radio(choices=AVAILABLE_PHYSICS,
161
  label="List of PhysicsWithGenerator",
162
  value=physics_name)
163
  with gr.Row():
 
191
  use_generator_button = gr.Checkbox(label="Generate valid physics parameters", scale=1)
192
  with gr.Column(scale=1):
193
  load_button = gr.Button("Load images...")
194
+ load_random_button = gr.Button("Load randomly...")
195
 
196
  ### Event listeners
197
+ choose_dataset.change(fn=get_dataset,
198
  inputs=choose_dataset,
199
  outputs=dataset_placeholder)
200
  choose_physics.change(fn=get_physics,