File size: 1,548 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch

def consolidate_masks(masks: torch.Tensor, num_latents: int, method: str) -> torch.Tensor:
    """Collapse per-frame masks into one mask per latent.

    The frame axis (dim 0) is expected to hold ``num_latents * 4 - 3`` frames.
    The last frame is repeated three times so the frames can be split into
    ``num_latents`` consecutive groups of four, and each group is reduced to
    a single mask according to ``method``.

    Args:
        masks: Float mask tensor with frames along dim 0. The ``repeat``
            calls assume a 3-D layout, presumably ``(frames, H, W)``.
            Values are kept as floats (e.g. 0.0 / 1.0), never cast to bool.
        num_latents: Number of latent frames to produce.
        method: One of

            * ``"first_only"``   - repeat the first frame for every latent,
              regardless of how many frames were supplied.
            * ``"select_first"`` - take the first frame of each group.
            * ``"select_last"``  - take the last frame of each group.
            * ``"union"``        - ``1 - clamp(sum over group, max=1)``.
              NOTE(review): this is the complement of the clamped sum,
              unlike the select methods which return mask values as-is;
              confirm the inversion is intended by the callers.

    Returns:
        A tensor with ``num_latents`` entries along dim 0.

    Raises:
        ValueError: If there are fewer frames than expected (and more than
            the single-frame shortcut allows), or ``method`` is unknown.
    """
    # masks are floats (e.g., 0.0 or 1.0), do not convert to boolean

    n_frames = masks.shape[0]
    frames_per_latent = 4
    expected_frames = (num_latents * frames_per_latent) - 3

    if method == "first_only":
        # Must be handled before the frame-count checks: otherwise an input
        # with exactly the expected number of frames would fall through to
        # the dispatch below and be rejected as an unknown method.
        return masks[:1].repeat(num_latents, 1, 1)

    if n_frames == expected_frames:
        # Perfect match
        pass
    elif n_frames == 1:
        # If only one frame, repeat it for all latents
        return masks[:1].repeat(num_latents, 1, 1)
    elif n_frames > expected_frames:
        # Truncate if there are more frames than expected
        masks = masks[:expected_frames]
    else:
        # Not enough frames
        raise ValueError(
            f"For {num_latents} latents, expected {expected_frames} frames "
            f"((num_latents * 4) - 3), but got {n_frames} frames"
        )

    # Add padding frames (copies of the last frame) to form a multiple of
    # frames_per_latent; repeat() already materialises new memory.
    padding = masks[-1:].repeat(frames_per_latent - 1, 1, 1)
    padded_masks = torch.cat([masks, padding], dim=0)

    # Reshape into (num_latents, frames_per_latent, H, W)
    grouped_masks = padded_masks.reshape(num_latents, frames_per_latent, *masks.shape[1:])

    if method == "select_first":
        return grouped_masks[:, 0]
    elif method == "select_last":
        return grouped_masks[:, -1]
    elif method == "union":
        return 1 - torch.clamp(grouped_masks.sum(dim=1), max=1)
    else:
        raise ValueError(f"Unknown consolidation method: {method}")