fix params save bug
Browse files- unet/mv_unet.py +2 -2
unet/mv_unet.py
CHANGED
@@ -150,9 +150,9 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
|
|
150 |
):
|
151 |
super().__init__(**{
|
152 |
k: v for k, v in locals().items() if k not in
|
153 |
-
["self", "kwargs", "__class__"]
|
154 |
})
|
155 |
-
|
156 |
add_multiview_processor(
|
157 |
model = self,
|
158 |
enable_filter = lambda name: name.endswith(f"{multiview_attn_position}.processor"),
|
|
|
150 |
):
|
151 |
super().__init__(**{
|
152 |
k: v for k, v in locals().items() if k not in
|
153 |
+
["self", "kwargs", "__class__", "n_views", "num_modalities", "latent_size", "multiview_chain_pose"]
|
154 |
})
|
155 |
+
self.n_views = n_views
|
156 |
add_multiview_processor(
|
157 |
model = self,
|
158 |
enable_filter = lambda name: name.endswith(f"{multiview_attn_position}.processor"),
|