add option n_view
Browse files- unet/mv_unet.py +1 -0
unet/mv_unet.py
CHANGED
@@ -159,6 +159,7 @@ class UnifieldWrappedUNet(UNet2DConditionModel):
|
|
159 |
num_modalities = num_modalities,
|
160 |
base_img_size = latent_size,
|
161 |
chain_pos = multiview_chain_pose,
|
|
|
162 |
)
|
163 |
|
164 |
switch_multiview_processor(self, enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"))
|
|
|
159 |
num_modalities = num_modalities,
|
160 |
base_img_size = latent_size,
|
161 |
chain_pos = multiview_chain_pose,
|
162 |
+
views=n_views
|
163 |
)
|
164 |
|
165 |
switch_multiview_processor(self, enable_filter=lambda name: name.endswith(f"{multiview_attn_position}.processor"))
|