mterris commited on
Commit
88d3587
·
1 Parent(s): 94e6664
Files changed (2) hide show
  1. app.py +2 -6
  2. factories.py +4 -15
app.py CHANGED
@@ -1,15 +1,11 @@
1
- import json
2
- import os
3
  import random
4
  import time
5
  from functools import partial
6
- from pathlib import Path
7
  from typing import List
8
 
9
  import deepinv as dinv
10
  import gradio as gr
11
  import torch
12
- from PIL import Image
13
  from torchvision import transforms
14
 
15
  from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
@@ -163,9 +159,9 @@ def get_dataset(dataset_name):
163
  physics_name = 'CT'
164
  baseline_name = 'DPIR_CT'
165
  else:
166
- available_physics = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
167
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
168
- physics_name = 'MotionBlur_easy'
169
  baseline_name = 'DPIR'
170
 
171
  dataset = get_dataset_on_DEVICE_STR(dataset_name)
 
 
 
1
  import random
2
  import time
3
  from functools import partial
 
4
  from typing import List
5
 
6
  import deepinv as dinv
7
  import gradio as gr
8
  import torch
 
9
  from torchvision import transforms
10
 
11
  from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
 
159
  physics_name = 'CT'
160
  baseline_name = 'DPIR_CT'
161
  else:
162
+ available_physics = ['MotionBlur_medium', 'MotionBlur_hard',
163
  'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
164
+ physics_name = 'MotionBlur_hard'
165
  baseline_name = 'DPIR'
166
 
167
  dataset = get_dataset_on_DEVICE_STR(dataset_name)
factories.py CHANGED
@@ -37,7 +37,7 @@ DEFAULT_MODEL_PARAMS = {
37
 
38
  class PhysicsWithGenerator(torch.nn.Module):
39
  """Interface between Physics, Generator and Gradio."""
40
- all_physics = ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard",
41
  "GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard",
42
  "MRI", "CT"]
43
 
@@ -48,17 +48,7 @@ class PhysicsWithGenerator(torch.nn.Module):
48
  if self.name not in self.all_physics:
49
  raise ValueError(f"{self.name} is unavailable.")
50
 
51
- if self.name == "MotionBlur_easy":
52
- psf_size = 31
53
- self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.01),
54
- padding="valid", device=device_str)
55
- self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.1, sigma=0.1, device=device_str)
56
- self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
57
- self.saved_params = {"updatable_params": {"sigma": 0.01},
58
- "updatable_params_converter": {"sigma": float},
59
- "fixed_params": {"noise_sigma_min": 0.01, "noise_sigma_max": 0.01,
60
- "psf_size": 31, "motion_gen_l": 0.1, "motion_gen_s": 0.1}}
61
- elif self.name == "MotionBlur_medium":
62
  psf_size = 31
63
  self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05),
64
  padding="valid", device=device_str)
@@ -128,8 +118,7 @@ class PhysicsWithGenerator(torch.nn.Module):
128
  elif self.name == "CT":
129
  acceleration_factor = 10
130
  img_h = 512
131
- angles = int(img_h / acceleration_factor)
132
- # angles = torch.linspace(0, 180, steps=10)
133
  self.physics = dinv.physics.Tomography(
134
  img_width=img_h,
135
  angles=angles,
@@ -188,7 +177,7 @@ class PhysicsWithGenerator(torch.nn.Module):
188
  self.physics.update(**kwargs)
189
 
190
  def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
191
- if self.name in ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard"] and not hasattr(self.physics, "filter"):
192
  use_gen = True
193
  elif self.name in ["MRI"] and not hasattr(self.physics, "mask"):
194
  use_gen = True
 
37
 
38
  class PhysicsWithGenerator(torch.nn.Module):
39
  """Interface between Physics, Generator and Gradio."""
40
+ all_physics = ["MotionBlur_medium", "MotionBlur_hard",
41
  "GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard",
42
  "MRI", "CT"]
43
 
 
48
  if self.name not in self.all_physics:
49
  raise ValueError(f"{self.name} is unavailable.")
50
 
51
+ if self.name == "MotionBlur_medium":
 
 
 
 
 
 
 
 
 
 
52
  psf_size = 31
53
  self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05),
54
  padding="valid", device=device_str)
 
118
  elif self.name == "CT":
119
  acceleration_factor = 10
120
  img_h = 512
121
+ angles = torch.linspace(0, 180, steps=int(img_h / acceleration_factor))
 
122
  self.physics = dinv.physics.Tomography(
123
  img_width=img_h,
124
  angles=angles,
 
177
  self.physics.update(**kwargs)
178
 
179
  def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
180
+ if self.name in ["MotionBlur_medium", "MotionBlur_hard", "GaussianBlur_easy", "GaussianBlur_medium", "GaussianBlur_hard"] and not hasattr(self.physics, "filter"):
181
  use_gen = True
182
  elif self.name in ["MRI"] and not hasattr(self.physics, "mask"):
183
  use_gen = True