Spaces:
Sleeping
Sleeping
refactor
Browse files- app.py +2 -6
- factories.py +4 -15
app.py
CHANGED
@@ -1,15 +1,11 @@
|
|
1 |
-
import json
|
2 |
-
import os
|
3 |
import random
|
4 |
import time
|
5 |
from functools import partial
|
6 |
-
from pathlib import Path
|
7 |
from typing import List
|
8 |
|
9 |
import deepinv as dinv
|
10 |
import gradio as gr
|
11 |
import torch
|
12 |
-
from PIL import Image
|
13 |
from torchvision import transforms
|
14 |
|
15 |
from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
|
@@ -163,9 +159,9 @@ def get_dataset(dataset_name):
|
|
163 |
physics_name = 'CT'
|
164 |
baseline_name = 'DPIR_CT'
|
165 |
else:
|
166 |
-
available_physics = ['
|
167 |
'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
168 |
-
physics_name = '
|
169 |
baseline_name = 'DPIR'
|
170 |
|
171 |
dataset = get_dataset_on_DEVICE_STR(dataset_name)
|
|
|
|
|
|
|
1 |
import random
|
2 |
import time
|
3 |
from functools import partial
|
|
|
4 |
from typing import List
|
5 |
|
6 |
import deepinv as dinv
|
7 |
import gradio as gr
|
8 |
import torch
|
|
|
9 |
from torchvision import transforms
|
10 |
|
11 |
from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
|
|
|
159 |
physics_name = 'CT'
|
160 |
baseline_name = 'DPIR_CT'
|
161 |
else:
|
162 |
+
available_physics = ['MotionBlur_medium', 'MotionBlur_hard',
|
163 |
'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
|
164 |
+
physics_name = 'MotionBlur_hard'
|
165 |
baseline_name = 'DPIR'
|
166 |
|
167 |
dataset = get_dataset_on_DEVICE_STR(dataset_name)
|
factories.py
CHANGED
@@ -37,7 +37,7 @@ DEFAULT_MODEL_PARAMS = {
|
|
37 |
|
38 |
class PhysicsWithGenerator(torch.nn.Module):
|
39 |
"""Interface between Physics, Generator and Gradio."""
|
40 |
-
all_physics = ["
|
41 |
"GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard",
|
42 |
"MRI", "CT"]
|
43 |
|
@@ -48,17 +48,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
48 |
if self.name not in self.all_physics:
|
49 |
raise ValueError(f"{self.name} is unavailable.")
|
50 |
|
51 |
-
if self.name == "
|
52 |
-
psf_size = 31
|
53 |
-
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.01),
|
54 |
-
padding="valid", device=device_str)
|
55 |
-
self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.1, sigma=0.1, device=device_str)
|
56 |
-
self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
|
57 |
-
self.saved_params = {"updatable_params": {"sigma": 0.01},
|
58 |
-
"updatable_params_converter": {"sigma": float},
|
59 |
-
"fixed_params": {"noise_sigma_min": 0.01, "noise_sigma_max": 0.01,
|
60 |
-
"psf_size": 31, "motion_gen_l": 0.1, "motion_gen_s": 0.1}}
|
61 |
-
elif self.name == "MotionBlur_medium":
|
62 |
psf_size = 31
|
63 |
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05),
|
64 |
padding="valid", device=device_str)
|
@@ -128,8 +118,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
128 |
elif self.name == "CT":
|
129 |
acceleration_factor = 10
|
130 |
img_h = 512
|
131 |
-
angles = int(img_h / acceleration_factor)
|
132 |
-
# angles = torch.linspace(0, 180, steps=10)
|
133 |
self.physics = dinv.physics.Tomography(
|
134 |
img_width=img_h,
|
135 |
angles=angles,
|
@@ -188,7 +177,7 @@ class PhysicsWithGenerator(torch.nn.Module):
|
|
188 |
self.physics.update(**kwargs)
|
189 |
|
190 |
def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
|
191 |
-
if self.name in ["
|
192 |
use_gen = True
|
193 |
elif self.name in ["MRI"] and not hasattr(self.physics, "mask"):
|
194 |
use_gen = True
|
|
|
37 |
|
38 |
class PhysicsWithGenerator(torch.nn.Module):
|
39 |
"""Interface between Physics, Generator and Gradio."""
|
40 |
+
all_physics = ["MotionBlur_medium", "MotionBlur_hard",
|
41 |
"GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard",
|
42 |
"MRI", "CT"]
|
43 |
|
|
|
48 |
if self.name not in self.all_physics:
|
49 |
raise ValueError(f"{self.name} is unavailable.")
|
50 |
|
51 |
+
if self.name == "MotionBlur_medium":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
psf_size = 31
|
53 |
self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05),
|
54 |
padding="valid", device=device_str)
|
|
|
118 |
elif self.name == "CT":
|
119 |
acceleration_factor = 10
|
120 |
img_h = 512
|
121 |
+
angles = torch.linspace(0, 180, steps=int(img_h / acceleration_factor))
|
|
|
122 |
self.physics = dinv.physics.Tomography(
|
123 |
img_width=img_h,
|
124 |
angles=angles,
|
|
|
177 |
self.physics.update(**kwargs)
|
178 |
|
179 |
def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
|
180 |
+
if self.name in ["MotionBlur_medium", "MotionBlur_hard", "GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard"] and not hasattr(self.physics, "filter"):
|
181 |
use_gen = True
|
182 |
elif self.name in ["MRI"] and not hasattr(self.physics, "mask"):
|
183 |
use_gen = True
|