Luffuly commited on
Commit
cc10fb7
·
1 Parent(s): c94c385

fix params save bug

Browse files
Files changed (1) hide show
  1. unet/mv_unet.py +2 -2
unet/mv_unet.py CHANGED
@@ -150,9 +150,9 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
150
  ):
151
  super().__init__(**{
152
  k: v for k, v in locals().items() if k not in
153
- ["self", "kwargs", "__class__"]
154
  })
155
-
156
  add_multiview_processor(
157
  model = self,
158
  enable_filter = lambda name: name.endswith(f"{multiview_attn_position}.processor"),
 
150
  ):
151
  super().__init__(**{
152
  k: v for k, v in locals().items() if k not in
153
+ ["self", "kwargs", "__class__", "n_views", "num_modalities", "latent_size", "multiview_chain_pose"]
154
  })
155
+ self.n_views = n_views
156
  add_multiview_processor(
157
  model = self,
158
  enable_filter = lambda name: name.endswith(f"{multiview_attn_position}.processor"),