mterris commited on
Commit
ee6f900
·
1 Parent(s): e0e7789

added inpainting

Browse files
Files changed (2) hide show
  1. app.py +3 -3
  2. factories.py +21 -5
app.py CHANGED
@@ -159,7 +159,7 @@ def get_dataset(dataset_name):
159
  physics_name = 'CT'
160
  baseline_name = 'DPIR_CT'
161
  else:
162
- available_physics = ['MotionBlur_medium', 'MotionBlur_hard',
163
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
164
  physics_name = 'MotionBlur_hard'
165
  baseline_name = 'DPIR'
@@ -192,11 +192,11 @@ with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
192
 
193
  ### USER-SPECIFIC VARIABLES
194
  dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
195
- available_physics_placeholder = gr.State(['MotionBlur_medium', 'MotionBlur_hard',
196
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
197
  # Issue giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
198
  # Solution: using lambda expression
199
- physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_medium"))
200
  model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
201
 
202
  print(f"[Render] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
 
159
  physics_name = 'CT'
160
  baseline_name = 'DPIR_CT'
161
  else:
162
+ available_physics = ['Inpainting', 'MotionBlur_medium', 'MotionBlur_hard',
163
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
164
  physics_name = 'MotionBlur_hard'
165
  baseline_name = 'DPIR'
 
192
 
193
  ### USER-SPECIFIC VARIABLES
194
  dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
195
+ available_physics_placeholder = gr.State(['Inpainting', 'MotionBlur_medium', 'MotionBlur_hard',
196
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
197
  # Issue giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
198
  # Solution: using lambda expression
199
+ physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_hard"))
200
  model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
201
 
202
  print(f"[Render] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
factories.py CHANGED
@@ -13,7 +13,7 @@ from physics.blur_generator import GaussianBlurGenerator
13
 
14
  class PhysicsWithGenerator(torch.nn.Module):
15
  """Interface between Physics, Generator and Gradio."""
16
- all_physics = ["MotionBlur_medium", "MotionBlur_hard",
17
  "GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard",
18
  "MRI", "CT"]
19
 
@@ -83,6 +83,22 @@ class PhysicsWithGenerator(torch.nn.Module):
83
  "updatable_params_converter": {"sigma": float},
84
  "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
85
  "blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  elif self.name == "MRI":
87
  self.physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(sigma=.01).to(device_str),
88
  img_size=(640, 320), device=device_str)
@@ -101,14 +117,14 @@ class PhysicsWithGenerator(torch.nn.Module):
101
  circle=False,
102
  normalize=True,
103
  device=device_str,
104
- noise_model=dinv.physics.GaussianNoise(sigma=1e-4).to(device_str),
105
  max_iter=10,
106
  )
107
- self.physics_generator = SigmaGenerator(sigma_min=1e-4, sigma_max=1e-4, device=device_str)
108
- self.generator = SigmaGenerator(sigma_min=1e-4, sigma_max=1e-4, device=device_str)
109
  self.saved_params = {"updatable_params": {"sigma": 0.1},
110
  "updatable_params_converter": {"sigma": float},
111
- "fixed_params": {"noise_sigma_min": 1e-4, "noise_sigma_max": 1e-4,
112
  "angles": angles, "max_iter": 10}}
113
 
114
  def display_saved_params(self) -> str:
 
13
 
14
  class PhysicsWithGenerator(torch.nn.Module):
15
  """Interface between Physics, Generator and Gradio."""
16
+ all_physics = ["Inpainting", "MotionBlur_medium", "MotionBlur_hard",
17
  "GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard",
18
  "MRI", "CT"]
19
 
 
83
  "updatable_params_converter": {"sigma": float},
84
  "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
85
  "blur_sigma": 4.0, "psf_size": 31, "num_channels": 1}}
86
+ elif self.name == "Inpainting":
87
+ self.physics = dinv.physics.Inpainting(tensor_size=(256, 256), mask=split_ratio,
88
+ noise_model=dinv.physics.GaussianNoise(sigma=sigma),
89
+ device=device_str)
90
+ self.physics_generator = dinv.physics.generator.BernoulliSplittingMaskGenerator((3, 256, 256),
91
+ split_ratio=split_ratio, pixelwise=pixelwise,
92
+ random_split_ratio=True, min_split_ratio=split_ratio,
93
+ max_split_ratio=split_ratio, device=device_str)
94
+ self.generator = dinv.physics.generator.BernoulliSplittingMaskGenerator((3, 256, 256),
95
+ split_ratio=split_ratio, pixelwise=pixelwise,
96
+ random_split_ratio=True, min_split_ratio=split_ratio,
97
+ max_split_ratio=split_ratio, device=device_str)
98
+
99
+ self.saved_params = {"updatable_params": {},
100
+ "updatable_params_converter": {"sigma": float},
101
+ "fixed_params": {"sigma": sigma}}
102
  elif self.name == "MRI":
103
  self.physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(sigma=.01).to(device_str),
104
  img_size=(640, 320), device=device_str)
 
117
  circle=False,
118
  normalize=True,
119
  device=device_str,
120
+ noise_model=dinv.physics.GaussianNoise(sigma=1e-3).to(device_str),
121
  max_iter=10,
122
  )
123
+ self.physics_generator = SigmaGenerator(sigma_min=1e-3, sigma_max=1e-3, device=device_str)
124
+ self.generator = SigmaGenerator(sigma_min=1e-3, sigma_max=1e-3, device=device_str)
125
  self.saved_params = {"updatable_params": {"sigma": 0.1},
126
  "updatable_params_converter": {"sigma": float},
127
+ "fixed_params": {"noise_sigma_min": 1e-3, "noise_sigma_max": 1e-3,
128
  "angles": angles, "max_iter": 10}}
129
 
130
  def display_saved_params(self) -> str: