mterris commited on
Commit
3ac1bb3
·
1 Parent(s): dd93c31
Files changed (3) hide show
  1. app.py +3 -5
  2. factories.py +4 -4
  3. requirements.txt +1 -1
app.py CHANGED
@@ -80,16 +80,14 @@ def generate_imgs(x: torch.Tensor,
80
  print(f"[Before inference] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
81
  print(f"[Before inference] CUDA max reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")
82
 
83
- if hasattr(physics.physics, 'imsize'):
84
- physics.physics.imsize = x.shape[1:]
85
- elif hasattr(physics.physics, 'img_size'):
86
- physics.physics.img_size = x.shape[1:]
87
- elif hasattr(physics.physics, 'tensor_size'):
88
  physics.physics.tensor_size = x.shape[1:]
89
 
90
  if physics.physics_generator is not None: # we only change physic params but not noise levels
91
  if hasattr(physics.physics_generator, 'tensor_size'):
92
  physics.physics_generator.tensor_size = x.shape[1:]
 
93
 
94
  ### Compute y
95
  with torch.no_grad():
 
80
  print(f"[Before inference] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
81
  print(f"[Before inference] CUDA max reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")
82
 
83
+
84
+ if hasattr(physics.physics, 'tensor_size'):
 
 
 
85
  physics.physics.tensor_size = x.shape[1:]
86
 
87
  if physics.physics_generator is not None: # we only change physic params but not noise levels
88
  if hasattr(physics.physics_generator, 'tensor_size'):
89
  physics.physics_generator.tensor_size = x.shape[1:]
90
+ physics.generator.tensor_size = x.shape[1:]
91
 
92
  ### Compute y
93
  with torch.no_grad():
factories.py CHANGED
@@ -86,16 +86,16 @@ class PhysicsWithGenerator(torch.nn.Module):
86
  elif self.name == "Inpainting":
87
  sigma = 0.05
88
  split_ratio = 0.3
89
- pixelwise = False
90
- self.physics = dinv.physics.Inpainting(tensor_size=(256, 256), mask=split_ratio,
91
  noise_model=dinv.physics.GaussianNoise(sigma=sigma),
92
  device=device_str)
93
  self.physics_generator = dinv.physics.generator.BernoulliSplittingMaskGenerator((3, 256, 256),
94
  split_ratio=split_ratio, pixelwise=pixelwise,
95
- random_split_ratio=True, device=device_str)
96
  self.generator = dinv.physics.generator.BernoulliSplittingMaskGenerator((3, 256, 256),
97
  split_ratio=split_ratio, pixelwise=pixelwise,
98
- random_split_ratio=True, device=device_str)
99
 
100
  self.saved_params = {"updatable_params": {},
101
  "updatable_params_converter": {"sigma": float},
 
86
  elif self.name == "Inpainting":
87
  sigma = 0.05
88
  split_ratio = 0.3
89
+ pixelwise = True
90
+ self.physics = dinv.physics.Inpainting(tensor_size=(3, 256, 256), mask=split_ratio,
91
  noise_model=dinv.physics.GaussianNoise(sigma=sigma),
92
  device=device_str)
93
  self.physics_generator = dinv.physics.generator.BernoulliSplittingMaskGenerator((3, 256, 256),
94
  split_ratio=split_ratio, pixelwise=pixelwise,
95
+ random_split_ratio=False, device=device_str)
96
  self.generator = dinv.physics.generator.BernoulliSplittingMaskGenerator((3, 256, 256),
97
  split_ratio=split_ratio, pixelwise=pixelwise,
98
+ random_split_ratio=False, device=device_str)
99
 
100
  self.saved_params = {"updatable_params": {},
101
  "updatable_params_converter": {"sigma": float},
requirements.txt CHANGED
@@ -1,2 +1,2 @@
1
- git+https://github.com/deepinv/deepinv.git@fixes#egg=deepinv
2
  timm
 
1
+ git+https://github.com/deepinv/deepinv.git
2
  timm