jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
import torch
def consolidate_masks(
    masks: torch.Tensor,
    num_latents: int,
    method: str,
    frames_per_latent: int = 4,
) -> torch.Tensor:
    """Reduce a per-frame mask sequence to one mask per latent frame.

    The expected input length is ``(num_latents - 1) * frames_per_latent + 1``
    frames (``num_latents * 4 - 3`` for the default of 4). Surplus trailing
    frames are dropped. The remaining frames are padded at the end by repeating
    the last frame ``frames_per_latent - 1`` times, split into ``num_latents``
    consecutive groups of ``frames_per_latent`` frames, and each group is
    reduced to a single mask according to ``method``.

    If exactly one frame is supplied, or ``method == "first_only"``, the first
    frame is simply repeated for every latent and no frame-count check is made.

    Mask values are kept as-is (floats such as 0.0/1.0); they are never
    converted to booleans.

    Args:
        masks: Tensor whose first dimension indexes frames, e.g.
            ``(n_frames, H, W)``. Any number of trailing dimensions is allowed.
        num_latents: Number of latent masks to produce; must be >= 1.
        method: Group reduction:
            ``"select_first"`` - first frame of each group;
            ``"select_last"`` - last frame of each group;
            ``"union"`` - element-wise maximum over each group (for binary
            masks: 1 wherever any frame in the group is 1);
            ``"first_only"`` - repeat the very first input frame.
        frames_per_latent: Number of frames grouped into one latent
            (default 4); must be >= 1.

    Returns:
        Tensor of shape ``(num_latents, *masks.shape[1:])``.

    Raises:
        ValueError: If ``method`` is unknown, ``num_latents`` or
            ``frames_per_latent`` is < 1, or there are fewer frames than
            expected (and the single-frame / ``"first_only"`` shortcut does
            not apply).
    """
    valid_methods = ("select_first", "select_last", "union", "first_only")
    if method not in valid_methods:
        raise ValueError(f"Unknown consolidation method: {method}")
    if num_latents < 1:
        raise ValueError(f"num_latents must be >= 1, got {num_latents}")
    if frames_per_latent < 1:
        raise ValueError(f"frames_per_latent must be >= 1, got {frames_per_latent}")

    n_frames = masks.shape[0]
    # One repeat factor of 1 per non-frame dimension, so repeats work for any
    # mask rank rather than only (n_frames, H, W).
    trailing_ones = (1,) * (masks.dim() - 1)

    # Broadcast the first frame to every latent. Requires at least one frame;
    # an empty input falls through to the frame-count error below.
    if n_frames == 1 or (method == "first_only" and n_frames > 0):
        return masks[:1].repeat(num_latents, *trailing_ones)

    expected_frames = (num_latents - 1) * frames_per_latent + 1
    if n_frames < expected_frames:
        raise ValueError(
            f"For {num_latents} latents, expected {expected_frames} frames "
            f"((num_latents * {frames_per_latent}) - {frames_per_latent - 1}), "
            f"but got {n_frames} frames"
        )

    # Drop surplus trailing frames (no-op when the count already matches).
    masks = masks[:expected_frames]

    # Pad with copies of the last frame so the total is exactly
    # num_latents * frames_per_latent, allowing an even split into groups.
    # NOTE(review): the expected count (num_latents - 1) * F + 1 suggests the
    # first frame may map to a latent of its own, but padding is appended at
    # the END, so group k covers frames [k*F, (k+1)*F). Confirm this alignment
    # matches the encoder's temporal grouping.
    pad = frames_per_latent - 1
    if pad:
        masks = torch.cat([masks, masks[-1:].repeat(pad, *trailing_ones)], dim=0)

    grouped = masks.reshape(num_latents, frames_per_latent, *masks.shape[1:])

    if method == "select_first":
        return grouped[:, 0]
    if method == "select_last":
        return grouped[:, -1]
    # "union": element-wise max. For non-negative binary masks this equals
    # clamp(sum, max=1), and it is not inflated by the duplicated padding
    # frames when masks are soft.
    return grouped.amax(dim=1)