Yonuts commited on
Commit
f4e9220
·
1 Parent(s): bc1b3b2

remove model choices

Browse files
Files changed (3) hide show
  1. app.py +45 -62
  2. datasets.py +1 -1
  3. evals.py +48 -139
app.py CHANGED
@@ -20,7 +20,7 @@ DEVICE_STR = 'cuda'
20
  ### Gradio Utils
21
  def generate_imgs(dataset: EvalDataset, idx: int,
22
  model: EvalModel, baseline: BaselineModel,
23
- physics: PhysicsWithGenerator, use_gen: bool,
24
  metrics: List[Metric]):
25
  ### Load 1 image
26
  x = dataset[idx] # shape : (3, 256, 256)
@@ -80,7 +80,7 @@ def update_random_idx_and_generate_imgs(dataset: EvalDataset,
80
  physics: PhysicsWithGenerator,
81
  use_gen: bool,
82
  metrics: List[Metric]):
83
- idx = random.randint(0, len(dataset))
84
  x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(dataset,
85
  idx,
86
  model,
@@ -125,10 +125,19 @@ def save_imgs(dataset: EvalDataset, idx: int, physics: PhysicsWithGenerator,
125
  dinv.utils.plot([x, y, out_a, out_b], titles=titles, show=False, save_fn=save_path)
126
 
127
  get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
128
- get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
129
  get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
 
130
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
131
- get_physics_generator_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
 
 
 
 
 
 
 
 
 
132
 
133
  def get_model(model_name, ckpt_pth):
134
  if model_name in BaselineModel.all_baselines:
@@ -154,18 +163,14 @@ with gr.Blocks(title=title, css=custom_css) as interface:
154
 
155
  # Loading things
156
  model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
157
- model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET")) # lambda expression to instanciate a callable in a gr.State
158
  dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
159
- physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("MotionBlur_easy")) # lambda expression to instanciate a callable in a gr.State
160
  metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
161
 
162
- @gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder])
163
- def dynamic_layout(model_a, model_b, dataset, physics, metrics):
164
  ### LAYOUT
165
- model_a_name = model_a.base_name
166
- model_a_full_name = model_a.name
167
- model_b_name = model_b.base_name
168
- model_b_full_name = model_b.name
169
  dataset_name = dataset.name
170
  physics_name = physics.name
171
  metric_names = [metric.name for metric in metrics]
@@ -180,87 +185,65 @@ with gr.Blocks(title=title, css=custom_css) as interface:
180
  with gr.Column():
181
  y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
182
  y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"],)
183
- with gr.Row():
184
- choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
185
- label="List of EvalDataset",
186
- value=dataset_name,
187
- scale=2)
188
- idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", scale=1)
189
-
190
  choose_physics = gr.Radio(choices=PhysicsWithGenerator.all_physics,
191
  label="List of PhysicsWithGenerator",
192
  value=physics_name)
193
  with gr.Row():
194
- key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
195
  label="Updatable Parameter Key",
196
  scale=2)
197
  value_text = gr.Textbox(label="Update Value", scale=2)
198
- with gr.Column(scale=1):
199
- update_button = gr.Button("Update Param")
200
- use_generator_button = gr.Checkbox(label="Use param generator")
201
-
202
  with gr.Column():
203
- with gr.Row():
204
  with gr.Column():
205
- model_a_out = gr.Image(label=f"{model_a_full_name} OUTPUT", interactive=False)
206
- out_a_metric = gr.Textbox(label="Metrics(model_a(y), x)", elem_classes=["fixed-textbox"])
207
- load_model_a = gr.Button("Load model A...", scale=1)
208
  with gr.Column():
209
- model_b_out = gr.Image(label=f"{model_b_full_name} OUTPUT", interactive=False)
210
- out_b_metric = gr.Textbox(label="Metrics(model_b(y), x)", elem_classes=["fixed-textbox"])
211
- load_model_b = gr.Button("Load model B...", scale=1)
212
- with gr.Row():
213
- choose_model_a = gr.Dropdown(choices=EvalModel.all_models + BaselineModel.all_baselines,
214
- label="List of Model A",
215
- value=model_a_name,
216
- scale=2)
217
- path_a_str = gr.Textbox(value=model_a.ckpt_pth, label="Checkpoint path", scale=3)
218
  with gr.Row():
219
- choose_model_b = gr.Dropdown(choices=EvalModel.all_models + BaselineModel.all_baselines,
220
- label="List of Model B",
221
- value=model_b_name,
222
- scale=2)
223
- path_b_str = gr.Textbox(value=model_b.ckpt_pth, label="Checkpoint path", scale=3)
224
-
225
- # Components: Load Metric + Load/Save Buttons
226
  with gr.Row():
227
- with gr.Column():
228
  choose_metrics = gr.CheckboxGroup(choices=Metric.all_metrics,
229
  value=metric_names,
230
  label="Choose metrics you are interested")
231
- with gr.Column():
 
232
  load_button = gr.Button("Load images...")
233
- load_random_button = gr.Button("Load randomly...")
234
- save_button = gr.Button("Save images...")
235
 
236
  ### Event listeners
237
  choose_dataset.change(fn=get_dataset_on_DEVICE_STR,
238
  inputs=choose_dataset,
239
  outputs=dataset_placeholder)
240
- choose_physics.change(fn=get_physics_generator_on_DEVICE_STR,
241
  inputs=choose_physics,
242
- outputs=physics_placeholder)
243
  update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
244
- load_model_a.click(fn=get_model,
245
- inputs=[choose_model_a, path_a_str],
246
- outputs=model_a_placeholder)
247
- load_model_b.click(fn=get_model,
248
- inputs=[choose_model_b, path_b_str],
249
- outputs=model_b_placeholder)
250
  choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
251
  inputs=choose_metrics,
252
  outputs=metrics_placeholder)
253
- load_button.click(fn=generate_imgs,
254
- inputs=[dataset_placeholder,
255
- idx_slider,
256
  model_a_placeholder,
257
  model_b_placeholder,
258
  physics_placeholder,
259
  use_generator_button,
260
- metrics_placeholder],
261
  outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
262
- load_random_button.click(fn=update_random_idx_and_generate_imgs,
263
- inputs=[dataset_placeholder,
264
  model_a_placeholder,
265
  model_b_placeholder,
266
  physics_placeholder,
 
20
  ### Gradio Utils
21
  def generate_imgs(dataset: EvalDataset, idx: int,
22
  model: EvalModel, baseline: BaselineModel,
23
+ physics: PhysicsWithGenerator, use_gen: bool,
24
  metrics: List[Metric]):
25
  ### Load 1 image
26
  x = dataset[idx] # shape : (3, 256, 256)
 
80
  physics: PhysicsWithGenerator,
81
  use_gen: bool,
82
  metrics: List[Metric]):
83
+ idx = random.randint(0, len(dataset)-1)
84
  x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(dataset,
85
  idx,
86
  model,
 
125
  dinv.utils.plot([x, y, out_a, out_b], titles=titles, show=False, save_fn=save_path)
126
 
127
  get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
 
128
  get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
129
+ get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
130
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
131
+ get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
132
+
133
+ def get_physics(physics_name):
134
+ if physics_name == 'MRI':
135
+ baseline = get_baseline_model_on_DEVICE_STR('DPIR_MRI')
136
+ elif physics_name == 'CT':
137
+ baseline = get_baseline_model_on_DEVICE_STR('DPIR_CT')
138
+ else:
139
+ baseline = get_baseline_model_on_DEVICE_STR('DPIR')
140
+ return get_physics_on_DEVICE_STR(physics_name), baseline
141
 
142
  def get_model(model_name, ckpt_pth):
143
  if model_name in BaselineModel.all_baselines:
 
163
 
164
  # Loading things
165
  model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
166
+ model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR")) # lambda expression to instanciate a callable in a gr.State
167
  dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
168
+ physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy")) # lambda expression to instanciate a callable in a gr.State
169
  metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
170
 
171
+ @gr.render(inputs=[dataset_placeholder, physics_placeholder, metrics_placeholder])
172
+ def dynamic_layout(dataset, physics, metrics):
173
  ### LAYOUT
 
 
 
 
174
  dataset_name = dataset.name
175
  physics_name = physics.name
176
  metric_names = [metric.name for metric in metrics]
 
185
  with gr.Column():
186
  y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
187
  y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"],)
188
+
 
 
 
 
 
 
189
  choose_physics = gr.Radio(choices=PhysicsWithGenerator.all_physics,
190
  label="List of PhysicsWithGenerator",
191
  value=physics_name)
192
  with gr.Row():
193
+ key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
194
  label="Updatable Parameter Key",
195
  scale=2)
196
  value_text = gr.Textbox(label="Update Value", scale=2)
197
+ update_button = gr.Button("Manually update parameter value", scale=1)
198
+
 
 
199
  with gr.Column():
200
+ with gr.Row():
201
  with gr.Column():
202
+ model_a_out = gr.Image(label="RAM OUTPUT", interactive=False)
203
+ out_a_metric = gr.Textbox(label="Metrics(RAM(y, physics), x)", elem_classes=["fixed-textbox"])
 
204
  with gr.Column():
205
+ model_b_out = gr.Image(label="DPIR OUTPUT", interactive=False)
206
+ out_b_metric = gr.Textbox(label="Metrics(DPIR(y, physics), x)", elem_classes=["fixed-textbox"])
 
 
 
 
 
 
 
207
  with gr.Row():
208
+ choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
209
+ label="List of EvalDataset",
210
+ value=dataset_name,
211
+ scale=2)
212
+ idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", scale=1)
213
+
214
+ # Components: Load Metric + Load image Buttons
215
  with gr.Row():
216
+ with gr.Column(scale=2):
217
  choose_metrics = gr.CheckboxGroup(choices=Metric.all_metrics,
218
  value=metric_names,
219
  label="Choose metrics you are interested")
220
+ use_generator_button = gr.Checkbox(label="Generate valid physics parameters", scale=1)
221
+ with gr.Column(scale=1):
222
  load_button = gr.Button("Load images...")
223
+ load_random_button = gr.Button("Load randomly...")
 
224
 
225
  ### Event listeners
226
  choose_dataset.change(fn=get_dataset_on_DEVICE_STR,
227
  inputs=choose_dataset,
228
  outputs=dataset_placeholder)
229
+ choose_physics.change(fn=get_physics,
230
  inputs=choose_physics,
231
+ outputs=[physics_placeholder, model_b_placeholder])
232
  update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
 
 
 
 
 
 
233
  choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
234
  inputs=choose_metrics,
235
  outputs=metrics_placeholder)
236
+ load_button.click(fn=generate_imgs,
237
+ inputs=[dataset_placeholder,
238
+ idx_slider,
239
  model_a_placeholder,
240
  model_b_placeholder,
241
  physics_placeholder,
242
  use_generator_button,
243
+ metrics_placeholder],
244
  outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
245
+ load_random_button.click(fn=update_random_idx_and_generate_imgs,
246
+ inputs=[dataset_placeholder,
247
  model_a_placeholder,
248
  model_b_placeholder,
249
  physics_placeholder,
datasets.py CHANGED
@@ -93,7 +93,7 @@ class LsdirMiniDataset(torch.utils.data.Dataset):
93
  transform: Optional[Callable] = None,
94
  ) -> None:
95
  self.root = root
96
- self.image_files = [f for f in os.listdir(self.root) if f.lower().endswith(('.png', '.JPEG'))]
97
  self.transform = transform
98
 
99
  def __len__(self) -> int:
 
93
  transform: Optional[Callable] = None,
94
  ) -> None:
95
  self.root = root
96
+ self.image_files = [f for f in os.listdir(self.root) if f.lower().endswith(('.png', '.jpeg'))]
97
  self.transform = transform
98
 
99
  def __len__(self) -> int:
evals.py CHANGED
@@ -47,22 +47,21 @@ class PhysicsWithGenerator(torch.nn.Module):
47
  if self.name not in self.all_physics:
48
  raise ValueError(f"{self.name} is unavailable.")
49
 
50
- self.sigma_generator = SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device_str)
51
  if self.name == "MotionBlur_easy":
52
  psf_size = 31
53
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.01), padding="valid",
54
- device=device_str)
55
- self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.1, sigma=0.1, device=device_str) + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
56
  self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
57
- self.saved_params = {"updatable_params": {"sigma": 0.05},
58
  "updatable_params_converter": {"sigma": float},
59
  "fixed_params": {"noise_sigma_min": 0.01, "noise_sigma_max": 0.01,
60
  "psf_size": 31, "motion_gen_l": 0.1, "motion_gen_s": 0.1}}
61
  elif self.name == "MotionBlur_medium":
62
  psf_size = 31
63
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05), padding="valid",
64
- device=device_str)
65
- self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str) + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
66
  self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
67
  self.saved_params = {"updatable_params": {"sigma": 0.05},
68
  "updatable_params_converter": {"sigma": float},
@@ -70,62 +69,61 @@ class PhysicsWithGenerator(torch.nn.Module):
70
  "psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}}
71
  elif self.name == "MotionBlur_hard":
72
  psf_size = 31
73
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1), padding="valid",
74
- device=device_str)
75
- self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str) + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
76
  self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
77
- self.saved_params = {"updatable_params": {"sigma": 0.05},
78
  "updatable_params_converter": {"sigma": float},
79
  "fixed_params": {"noise_sigma_min": 0.1, "noise_sigma_max": 0.1,
80
  "psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}}
81
  elif self.name == "GaussianBlur_easy":
82
  psf_size = 31
83
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.01), padding="valid",
84
- device=device_str)
85
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
86
  sigma_min=1.0, sigma_max=1.0,
87
  num_channels=1,
88
- device=device_str) + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
89
- self.generator = self.physics_generator
90
- self.saved_params = {"updatable_params": {},
91
- "updatable_params_converter": {},
92
- "fixed_params": {"noise_sigma": 0.01, "blur_sigma": 1.0,
93
- "psf_size": 31, "num_channels": 1}}
94
  elif self.name == "GaussianBlur_medium":
95
  psf_size = 31
96
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05), padding="valid",
97
- device=device_str)
98
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
99
  sigma_min=2.0, sigma_max=2.0,
100
  num_channels=1,
101
- device=device_str) + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
102
- self.generator = self.physics_generator
103
- self.saved_params = {"updatable_params": {},
104
- "updatable_params_converter": {},
105
- "fixed_params": {"noise_sigma": 0.05, "blur_sigma": 2.0,
106
- "psf_size": 31, "num_channels": 1}}
107
  elif self.name == "GaussianBlur_hard":
108
  psf_size = 31
109
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05), padding="valid",
110
- device=device_str)
111
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
112
  sigma_min=4.0, sigma_max=4.0,
113
  num_channels=1,
114
- device=device_str) + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
115
- self.generator = self.physics_generator
116
- self.saved_params = {"updatable_params": {},
117
- "updatable_params_converter": {},
118
- "fixed_params": {"noise_sigma": 0.1, "blur_sigma": 4.0,
119
- "psf_size": 31, "num_channels": 1}}
120
  elif self.name == "MRI":
121
- self.physics = dinv.physics.MRI(img_size=(640, 320), noise_model=dinv.physics.GaussianNoise(sigma=.01),
122
- device=device_str)
123
  self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4)
124
- self.generator = self.physics_generator # + self.sigma_generator
125
- self.saved_params = {"updatable_params": {"sigma": 0.05},
126
  "updatable_params_converter": {"sigma": float},
127
- "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
128
- "acceleration_factor": 4}}
129
  elif self.name == "CT":
130
  acceleration_factor = 10
131
  img_h = 480
@@ -141,10 +139,10 @@ class PhysicsWithGenerator(torch.nn.Module):
141
  max_iter=10,
142
  )
143
  self.physics_generator = None
144
- self.generator = self.sigma_generator
145
- self.saved_params = {"updatable_params": {"sigma": 0.1},
146
  "updatable_params_converter": {"sigma": float},
147
- "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.,
148
  "angles": angles, "max_iter": 10}}
149
 
150
  def display_saved_params(self) -> str:
@@ -189,7 +187,7 @@ class PhysicsWithGenerator(torch.nn.Module):
189
  self.physics.update(**kwargs)
190
 
191
  def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
192
- if self.name in ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur"] and not hasattr(self.physics, "filter"):
193
  use_gen = True
194
  elif self.name in ["MRI"] and not hasattr(self.physics, "mask"):
195
  use_gen = True
@@ -247,8 +245,7 @@ class BaselineModel(torch.nn.Module):
247
  -> BaselineModel should be models that are already trained and will have fixed weights.
248
  -> Eval model will change depending on differents checkpoints.
249
  """
250
- all_baselines = ["DRUNET", "PnP-PGD-DRUNET", "SWINIRx2", "SWINIRx4", "DPIR",
251
- "DPIR_MRI", "DPIR_CT", "PDNET"]
252
 
253
  def __init__(self, model_name: str, device_str: str = "cpu") -> None:
254
  super().__init__()
@@ -257,58 +254,6 @@ class BaselineModel(torch.nn.Module):
257
  self.name = self.base_name
258
  if self.name not in self.all_baselines:
259
  raise ValueError(f"{self.name} is unavailable.")
260
- elif self.name == "DRUNET":
261
- n_channels = 3
262
- ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
263
- self.model = dinv.models.DRUNet(in_channels=n_channels,
264
- out_channels=n_channels,
265
- device=device_str,
266
- pretrained=ckpt_pth)
267
- self.model.eval() # Set the model to evaluation mode
268
- elif self.name == 'PDNET':
269
- ckpt_pth = "ckpt/pdnet.pth.tar"
270
- self.model = get_model(model_name='pdnet',
271
- device=device_str)
272
- self.model.eval()
273
- self.model.load_state_dict(torch.load(ckpt_pth, map_location=lambda storage, loc: storage)['state_dict'])
274
- elif self.name == "SWINIRx2":
275
- n_channels = 3
276
- scale = 2
277
- ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth"
278
- upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle'
279
- self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8,
280
- img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
281
- num_heads=[6, 6, 6, 6, 6, 6],
282
- mlp_ratio=2, upsampler=upsampler, resi_connection='1conv',
283
- pretrained=ckpt_pth)
284
- self.model.to(device_str)
285
- self.model.eval() # Set the model to evaluation mode
286
- elif self.name == "SWINIRx4":
287
- n_channels = 3
288
- scale = 4
289
- ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth"
290
- upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle'
291
- self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8,
292
- img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
293
- num_heads=[6, 6, 6, 6, 6, 6],
294
- mlp_ratio=2, upsampler=upsampler, resi_connection='1conv',
295
- pretrained=ckpt_pth)
296
- self.model.to(device_str)
297
- self.model.eval() # Set the model to evaluation mode
298
-
299
- elif self.name == "PnP-PGD-DRUNET":
300
- n_channels = 3
301
- ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
302
- drunet = dinv.models.DRUNet(in_channels=n_channels,
303
- out_channels=n_channels,
304
- device=device_str,
305
- pretrained=ckpt_pth)
306
- drunet.eval() # Set the model to evaluation mode
307
- self.model = dinv.optim.optim_builder(iteration="PGD",
308
- prior=dinv.optim.PnP(drunet).to(device_str),
309
- data_fidelity=dinv.optim.L2(),
310
- max_iter=20,
311
- params_algo={'stepsize': 1., 'g_param': .05})
312
  elif self.name == "DPIR":
313
  n_channels = 3
314
  ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
@@ -418,11 +363,7 @@ class BaselineModel(torch.nn.Module):
418
  return lamb, list(sigma_denoiser), list(stepsize), max_iter
419
 
420
  def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
421
- if self.name == "DRUNET":
422
- return self.model(y, sigma=physics.noise_model.sigma)
423
- elif self.name == "PnP-PGD-DRUNET":
424
- return self.model(y, physics=physics)
425
- elif self.name == "DPIR":
426
  # Set the DPIR algorithm parameters
427
  sigma_float = physics.noise_model.sigma.item() # sigma should be a single value
428
  max_iter = 8
@@ -460,7 +401,7 @@ class BaselineModel(torch.nn.Module):
460
  params_algo=params_algo,
461
  )
462
  return model(y, physics=physics)
463
- elif self.name == "DPIR_CT":
464
  # Set the DPIR algorithm parameters
465
  sigma_float = physics.noise_model.sigma.item() # sigma should be a single value
466
  lip_const = physics.compute_norm(physics.A_adjoint(y))
@@ -485,42 +426,10 @@ class BaselineModel(torch.nn.Module):
485
  custom_init=custom_init
486
  )
487
  return algo(y, physics=physics)
488
- elif self.name == 'SWINIRx4':
489
- window_size = 8
490
- scale = 4
491
- _, _, h_old, w_old = y.size()
492
- h_pad = (h_old // window_size + 1) * window_size - h_old
493
- w_pad = (w_old // window_size + 1) * window_size - w_old
494
- img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :]
495
- img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
496
- output = self.model(img_lq)
497
- output = output[..., :h_old * scale, :w_old * scale]
498
- output = self.circular_roll(output, -2, -2)
499
- # check shape of adjoint
500
- x_adj = physics.A_adjoint(y)
501
- output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
502
- return output
503
- elif self.name == 'SWINIRx2':
504
- window_size = 8
505
- scale = 2
506
- _, _, h_old, w_old = y.size()
507
- h_pad = (h_old // window_size + 1) * window_size - h_old
508
- w_pad = (w_old // window_size + 1) * window_size - w_old
509
- img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :]
510
- img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
511
- output = self.model(img_lq)
512
- output = output[..., :h_old * scale, :w_old * scale]
513
- output = self.circular_roll(output, -1, -1)
514
- # check shape of adjoint
515
- x_adj = physics.A_adjoint(y)
516
- output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
517
- return output
518
- else:
519
- return self.model(y)
520
 
521
 
522
  class EvalDataset(torch.utils.data.Dataset):
523
- """"""
524
  all_datasets = ["Natural", "MRI", "CT"]
525
 
526
  def __init__(self, dataset_name: str, device_str: str = "cpu") -> None:
 
47
  if self.name not in self.all_physics:
48
  raise ValueError(f"{self.name} is unavailable.")
49
 
 
50
  if self.name == "MotionBlur_easy":
51
  psf_size = 31
52
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.01),
53
+ padding="valid", device=device_str)
54
+ self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.1, sigma=0.1, device=device_str)
55
  self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
56
+ self.saved_params = {"updatable_params": {"sigma": 0.01},
57
  "updatable_params_converter": {"sigma": float},
58
  "fixed_params": {"noise_sigma_min": 0.01, "noise_sigma_max": 0.01,
59
  "psf_size": 31, "motion_gen_l": 0.1, "motion_gen_s": 0.1}}
60
  elif self.name == "MotionBlur_medium":
61
  psf_size = 31
62
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05),
63
+ padding="valid", device=device_str)
64
+ self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str)
65
  self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
66
  self.saved_params = {"updatable_params": {"sigma": 0.05},
67
  "updatable_params_converter": {"sigma": float},
 
69
  "psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}}
70
  elif self.name == "MotionBlur_hard":
71
  psf_size = 31
72
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1),
73
+ padding="valid", device=device_str)
74
+ self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str)
75
  self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
76
+ self.saved_params = {"updatable_params": {"sigma": 0.1},
77
  "updatable_params_converter": {"sigma": float},
78
  "fixed_params": {"noise_sigma_min": 0.1, "noise_sigma_max": 0.1,
79
  "psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}}
80
  elif self.name == "GaussianBlur_easy":
81
  psf_size = 31
82
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.01),
83
+ padding="valid", device=device_str)
84
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
85
  sigma_min=1.0, sigma_max=1.0,
86
  num_channels=1,
87
+ device=device_str)
88
+ self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
89
+ self.saved_params = {"updatable_params": {"sigma": 0.01},
90
+ "updatable_params_converter": {"sigma": float},
91
+ "fixed_params": {"noise_sigma_min": 0.01, "noise_sigma_max": 0.01,
92
+ "blur_sigma": 1.0, "psf_size": 31, "num_channels": 1}}
93
  elif self.name == "GaussianBlur_medium":
94
  psf_size = 31
95
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05),
96
+ padding="valid", device=device_str)
97
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
98
  sigma_min=2.0, sigma_max=2.0,
99
  num_channels=1,
100
+ device=device_str)
101
+ self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
102
+ self.saved_params = {"updatable_params": {"sigma": 0.05},
103
+ "updatable_params_converter": {"sigma": float},
104
+ "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
105
+ "blur_sigma": 2.0, "psf_size": 31, "num_channels": 1}}
106
  elif self.name == "GaussianBlur_hard":
107
  psf_size = 31
108
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05),
109
+ padding="valid", device=device_str)
110
  self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size),
111
  sigma_min=4.0, sigma_max=4.0,
112
  num_channels=1,
113
+ device=device_str)
114
+ self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
115
+ self.saved_params = {"updatable_params": {"sigma": 0.05},
116
+ "updatable_params_converter": {"sigma": float},
117
+ "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
118
+ "blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
119
  elif self.name == "MRI":
120
+ self.physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(sigma=.01),
121
+ img_size=(640, 320), device=device_str)
122
  self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4)
123
+ self.generator = self.physics_generator
124
+ self.saved_params = {"updatable_params": {"sigma": 0.01},
125
  "updatable_params_converter": {"sigma": float},
126
+ "fixed_params": {"acceleration_factor": 4}}
 
127
  elif self.name == "CT":
128
  acceleration_factor = 10
129
  img_h = 480
 
139
  max_iter=10,
140
  )
141
  self.physics_generator = None
142
+ self.generator = SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device_str)
143
+ self.saved_params = {"updatable_params": {"sigma": 1e-4},
144
  "updatable_params_converter": {"sigma": float},
145
+ "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
146
  "angles": angles, "max_iter": 10}}
147
 
148
  def display_saved_params(self) -> str:
 
187
  self.physics.update(**kwargs)
188
 
189
  def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
190
+ if self.name in ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard"] and not hasattr(self.physics, "filter"):
191
  use_gen = True
192
  elif self.name in ["MRI"] and not hasattr(self.physics, "mask"):
193
  use_gen = True
 
245
  -> BaselineModel should be models that are already trained and will have fixed weights.
246
  -> Eval model will change depending on differents checkpoints.
247
  """
248
+ all_baselines = ["DPIR", "DPIR_MRI", "DPIR_CT"]
 
249
 
250
  def __init__(self, model_name: str, device_str: str = "cpu") -> None:
251
  super().__init__()
 
254
  self.name = self.base_name
255
  if self.name not in self.all_baselines:
256
  raise ValueError(f"{self.name} is unavailable.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
  elif self.name == "DPIR":
258
  n_channels = 3
259
  ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
 
363
  return lamb, list(sigma_denoiser), list(stepsize), max_iter
364
 
365
  def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
366
+ if self.name == "DPIR":
 
 
 
 
367
  # Set the DPIR algorithm parameters
368
  sigma_float = physics.noise_model.sigma.item() # sigma should be a single value
369
  max_iter = 8
 
401
  params_algo=params_algo,
402
  )
403
  return model(y, physics=physics)
404
+ else self.name == "DPIR_CT":
405
  # Set the DPIR algorithm parameters
406
  sigma_float = physics.noise_model.sigma.item() # sigma should be a single value
407
  lip_const = physics.compute_norm(physics.A_adjoint(y))
 
426
  custom_init=custom_init
427
  )
428
  return algo(y, physics=physics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
 
431
  class EvalDataset(torch.utils.data.Dataset):
432
+
433
  all_datasets = ["Natural", "MRI", "CT"]
434
 
435
  def __init__(self, dataset_name: str, device_str: str = "cpu") -> None: