Spaces:
Sleeping
Sleeping
[Fix] PSNR computation
Browse files- app.py +9 -10
- factories.py +1 -1
app.py
CHANGED
@@ -52,6 +52,7 @@ def generate_imgs_from_user(image,
|
|
52 |
x = torch.cat((x, torch.zeros_like(x)), dim=1)
|
53 |
|
54 |
return generate_imgs(x, physics, use_gen, baseline, model, metrics)
|
|
|
55 |
def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
|
56 |
physics: PhysicsWithGenerator, use_gen: bool,
|
57 |
baseline: BaselineModel, model: EvalModel,
|
@@ -108,26 +109,24 @@ def generate_imgs(x: torch.Tensor,
|
|
108 |
if out_baseline.shape != out.shape:
|
109 |
out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
|
110 |
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
### Metrics
|
112 |
metrics_y = ""
|
113 |
metrics_out = ""
|
114 |
metrics_out_baseline = ""
|
115 |
for metric in metrics:
|
116 |
-
if y.shape == x.shape:
|
117 |
-
|
118 |
metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
|
119 |
metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
|
120 |
metrics_out += f"Inference time = {ram_time:.3f}s"
|
121 |
metrics_out_baseline += f"Inference time = {dpir_time:.3f}s"
|
122 |
|
123 |
-
### Process y when y shape is different from x shape
|
124 |
-
if physics.name == "MRI":
|
125 |
-
y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
|
126 |
-
elif physics.name == "CT":
|
127 |
-
y_plot = physics.physics.A_adjoint(y)
|
128 |
-
else:
|
129 |
-
y_plot = y.clone()
|
130 |
-
|
131 |
### Processing images for plotting :
|
132 |
# - clip value outside of [0,1]
|
133 |
# - shape (1, C, H, W) -> (C, H, W)
|
|
|
52 |
x = torch.cat((x, torch.zeros_like(x)), dim=1)
|
53 |
|
54 |
return generate_imgs(x, physics, use_gen, baseline, model, metrics)
|
55 |
+
|
56 |
def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
|
57 |
physics: PhysicsWithGenerator, use_gen: bool,
|
58 |
baseline: BaselineModel, model: EvalModel,
|
|
|
109 |
if out_baseline.shape != out.shape:
|
110 |
out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
|
111 |
|
112 |
+
### Process y when y shape is different from x shape
|
113 |
+
if physics.name == 'MRI' and physics.name == 'CT':
|
114 |
+
y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
|
115 |
+
else:
|
116 |
+
y_plot = y.clone()
|
117 |
+
|
118 |
### Metrics
|
119 |
metrics_y = ""
|
120 |
metrics_out = ""
|
121 |
metrics_out_baseline = ""
|
122 |
for metric in metrics:
|
123 |
+
#if y.shape == x.shape:
|
124 |
+
metrics_y += f"{metric.name} = {metric(y_plot, x).item():.4f}" + "\n"
|
125 |
metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
|
126 |
metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
|
127 |
metrics_out += f"Inference time = {ram_time:.3f}s"
|
128 |
metrics_out_baseline += f"Inference time = {dpir_time:.3f}s"
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
### Processing images for plotting :
|
131 |
# - clip value outside of [0,1]
|
132 |
# - shape (1, C, H, W) -> (C, H, W)
|
factories.py
CHANGED
@@ -140,7 +140,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
140 |
max_iter=10,
|
141 |
)
|
142 |
self.physics_generator = None
|
143 |
-
self.generator = SigmaGenerator(sigma_min=
|
144 |
self.saved_params = {"updatable_params": {"sigma": 1e-4},
|
145 |
"updatable_params_converter": {"sigma": float},
|
146 |
"fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
|
|
|
140 |
max_iter=10,
|
141 |
)
|
142 |
self.physics_generator = None
|
143 |
+
self.generator = SigmaGenerator(sigma_min=1e-5, sigma_max=1e-4, device=device_str)
|
144 |
self.saved_params = {"updatable_params": {"sigma": 1e-4},
|
145 |
"updatable_params_converter": {"sigma": float},
|
146 |
"fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
|