msong97 commited on
Commit
6f0291c
·
1 Parent(s): 14f8df1

[Fix] PSNR computation

Browse files
Files changed (2) hide show
  1. app.py +9 -10
  2. factories.py +1 -1
app.py CHANGED
@@ -52,6 +52,7 @@ def generate_imgs_from_user(image,
52
  x = torch.cat((x, torch.zeros_like(x)), dim=1)
53
 
54
  return generate_imgs(x, physics, use_gen, baseline, model, metrics)
 
55
  def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
56
  physics: PhysicsWithGenerator, use_gen: bool,
57
  baseline: BaselineModel, model: EvalModel,
@@ -108,26 +109,24 @@ def generate_imgs(x: torch.Tensor,
108
  if out_baseline.shape != out.shape:
109
  out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
110
 
 
 
 
 
 
 
111
  ### Metrics
112
  metrics_y = ""
113
  metrics_out = ""
114
  metrics_out_baseline = ""
115
  for metric in metrics:
116
- if y.shape == x.shape:
117
- metrics_y += f"{metric.name} = {metric(y, x).item():.4f}" + "\n"
118
  metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
119
  metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
120
  metrics_out += f"Inference time = {ram_time:.3f}s"
121
  metrics_out_baseline += f"Inference time = {dpir_time:.3f}s"
122
 
123
- ### Process y when y shape is different from x shape
124
- if physics.name == "MRI":
125
- y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
126
- elif physics.name == "CT":
127
- y_plot = physics.physics.A_adjoint(y)
128
- else:
129
- y_plot = y.clone()
130
-
131
  ### Processing images for plotting :
132
  # - clip value outside of [0,1]
133
  # - shape (1, C, H, W) -> (C, H, W)
 
52
  x = torch.cat((x, torch.zeros_like(x)), dim=1)
53
 
54
  return generate_imgs(x, physics, use_gen, baseline, model, metrics)
55
+
56
  def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
57
  physics: PhysicsWithGenerator, use_gen: bool,
58
  baseline: BaselineModel, model: EvalModel,
 
109
  if out_baseline.shape != out.shape:
110
  out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
111
 
112
+ ### Process y when y shape is different from x shape
113
+ if physics.name == 'MRI' and physics.name == 'CT':
114
+ y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
115
+ else:
116
+ y_plot = y.clone()
117
+
118
  ### Metrics
119
  metrics_y = ""
120
  metrics_out = ""
121
  metrics_out_baseline = ""
122
  for metric in metrics:
123
+ #if y.shape == x.shape:
124
+ metrics_y += f"{metric.name} = {metric(y_plot, x).item():.4f}" + "\n"
125
  metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
126
  metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
127
  metrics_out += f"Inference time = {ram_time:.3f}s"
128
  metrics_out_baseline += f"Inference time = {dpir_time:.3f}s"
129
 
 
 
 
 
 
 
 
 
130
  ### Processing images for plotting :
131
  # - clip value outside of [0,1]
132
  # - shape (1, C, H, W) -> (C, H, W)
factories.py CHANGED
@@ -140,7 +140,7 @@ class PhysicsWithGenerator(torch.nn.Module):
140
  max_iter=10,
141
  )
142
  self.physics_generator = None
143
- self.generator = SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device_str)
144
  self.saved_params = {"updatable_params": {"sigma": 1e-4},
145
  "updatable_params_converter": {"sigma": float},
146
  "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
 
140
  max_iter=10,
141
  )
142
  self.physics_generator = None
143
+ self.generator = SigmaGenerator(sigma_min=1e-5, sigma_max=1e-4, device=device_str)
144
  self.saved_params = {"updatable_params": {"sigma": 1e-4},
145
  "updatable_params_converter": {"sigma": float},
146
  "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,