msong97 commited on
Commit
14f8df1
·
1 Parent(s): 4e6590f

adapt mri mask to image size

Browse files
Files changed (1) hide show
  1. factories.py +4 -0
factories.py CHANGED
@@ -194,6 +194,10 @@ class PhysicsWithGenerator(torch.nn.Module):
194
  use_gen = True
195
 
196
  if use_gen:
 
 
 
 
197
  kwargs = self.generator.step(batch_size=x.shape[0]) # generate a set of params for each sample
198
  self.update_saved_params_and_physics(**kwargs)
199
 
 
194
  use_gen = True
195
 
196
  if use_gen:
197
+ if self.name == 'MRI': # RandomMaskGenerator deoends on image size
198
+ _, _, h, w = x.shape
199
+ self.generator = dinv.physics.generator.RandomMaskGenerator((2, h, w), acceleration_factor=4)
200
+
201
  kwargs = self.generator.step(batch_size=x.shape[0]) # generate a set of params for each sample
202
  self.update_saved_params_and_physics(**kwargs)
203