# Source snapshot: commit 6dfcb0f (file size: 1,373 bytes).

import torch
from torch import nn
import cwm.eval.Segmentation.utils as utils
from external.raft_interface import RAFTInterface

class SegmentExtractor(nn.Module):
    def __init__(self, num_segments: int = 1, iters: int = 4, motion_range: int = 4):
        """Configure the segment extractor.

        Args:
            num_segments: stored as ``self.num_segments``; presumably the
                number of segments to extract (its use is not visible in
                this chunk — confirm).
            iters: stored as ``self.iters``; presumably the number of
                "add more moving and static motions" iterations in
                ``forward`` (not visible in this chunk — confirm).
            motion_range: bound for the random per-sample movement that
                ``forward`` draws with
                ``torch.randint(-motion_range, motion_range, ...)``.
        """
        # nn.Module.__init__ must run before attributes are assigned: it
        # creates the internal registries (_modules, _parameters, hooks).
        # Without it, assigning a submodule raises AttributeError and calling
        # the module (which goes through nn.Module.__call__) fails.
        super().__init__()
        self.num_segments = num_segments
        self.iters = iters
        self.motion_range = motion_range
        # Flow estimator that forward calls as flow_interface(x[:, :, 0], pred).
        self.flow_interface = RAFTInterface()

    def get_sampling_dist(self, x, model):
        pass

    def forward(self, x, model, sampling_dist=None):
        """
        x: [B, 3, H, W] a batch of imagenet-normalized image tensor
        model: a pre-trained CWM model
        """
        if not sampling_dist:
            sampling_dist = self.get_sampling_dist(x, model)

        ## Step 1: sample initial moving and static locations from the distribution
        moving_pos = utils.sample_positions_from_dist(num=1, dist=sampling_dist) # [B, num, 2]
        static_pos = utils.sample_positions_from_dist(num=1, dist=(1-sampling_dist)) # [B, num, 2]
        movement = torch.randint(-self.motion_range, self.motion_range, (B, 1, 2)) # [B, 1, 2]

        ## Step 2: compute initial flow maps
        pred = model.get_counterfactual(x, mask, moving_pos=moving_pos, static_pos=static_pos, movement=movement)
        flow = self.flow_interface(x[:, :, 0], pred)

        ## Step 3: iterate to add more moving and static motions