mterris commited on
Commit
4b3f23c
·
1 Parent(s): b2d938e
Files changed (2) hide show
  1. app.py +2 -2
  2. factories.py +16 -1
app.py CHANGED
@@ -168,7 +168,7 @@ def get_dataset(dataset_name):
168
  physics_name = 'CT'
169
  baseline_name = 'DPIR_CT'
170
  else:
171
- available_physics = ['Inpainting', 'MotionBlur_medium', 'MotionBlur_hard',
172
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
173
  physics_name = 'MotionBlur_hard'
174
  baseline_name = 'DPIR'
@@ -201,7 +201,7 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
201
 
202
  ### USER-SPECIFIC VARIABLES
203
  dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
204
- available_physics_placeholder = gr.State(['Inpainting', 'MotionBlur_medium', 'MotionBlur_hard',
205
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
206
  # Issue giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
207
  # Solution: using lambda expression
 
168
  physics_name = 'CT'
169
  baseline_name = 'DPIR_CT'
170
  else:
171
+ available_physics = ['Inpainting', 'SR' ,'MotionBlur_medium', 'MotionBlur_hard',
172
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
173
  physics_name = 'MotionBlur_hard'
174
  baseline_name = 'DPIR'
 
201
 
202
  ### USER-SPECIFIC VARIABLES
203
  dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
204
+ available_physics_placeholder = gr.State(['Inpainting', 'SR', 'MotionBlur_medium', 'MotionBlur_hard',
205
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
206
  # Issue giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
207
  # Solution: using lambda expression
factories.py CHANGED
@@ -15,7 +15,7 @@ class PhysicsWithGenerator(torch.nn.Module):
15
  """Interface between Physics, Generator and Gradio."""
16
  all_physics = ["Inpainting", "MotionBlur_medium", "MotionBlur_hard",
17
  "GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard",
18
- "MRI", "CT"]
19
 
20
  def __init__(self, physics_name: str, device_str: str = "cpu") -> None:
21
  super().__init__()
@@ -83,6 +83,21 @@ class PhysicsWithGenerator(torch.nn.Module):
83
  "updatable_params_converter": {"sigma": float},
84
  "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
85
  "blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  elif self.name == "Inpainting":
87
  sigma = 0.05
88
  split_ratio = 0.3
 
15
  """Interface between Physics, Generator and Gradio."""
16
  all_physics = ["Inpainting", "MotionBlur_medium", "MotionBlur_hard",
17
  "GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard",
18
+ "MRI", "CT", "SR"]
19
 
20
  def __init__(self, physics_name: str, device_str: str = "cpu") -> None:
21
  super().__init__()
 
83
  "updatable_params_converter": {"sigma": float},
84
  "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
85
  "blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
86
+ elif self.name == 'SR':
87
+ self.physics = dinv.physics.Downsampling(img_size=(3, 256, 256), normalize=False, filter="bicubic",
88
+ factor=4,
89
+ padding='constant',
90
+ noise_model=dinv.physics.GaussianNoise(sigma=0.1),
91
+ device=device_str)
92
+ list_filters = ["bicubic"]
93
+ list_factors = [2, 4]
94
+ self.physics_generator = dinv.physics.generator.DownsamplingGenerator(filters=list_filters, factors=list_factors,
95
+ device=device_str) + SigmaGenerator(sigma_min=0.0, sigma_max=0.05, device=device_str)
96
+ self.generator = self.physics_generator # here do not add noise
97
+ self.saved_params = {"updatable_params": {"sigma": 0.0, "factor": 4},
98
+ "updatable_params_converter": {"sigma": float, "factor": int},
99
+ "fixed_params": {"noise_sigma_min": 0., "noise_sigma_max": 0.,
100
+ "list_filters": list_filters, "list_factors": list_factors}}
101
  elif self.name == "Inpainting":
102
  sigma = 0.05
103
  split_ratio = 0.3