|
import torch |
|
from torch import nn |
|
import cwm.eval.Segmentation.utils as utils |
|
from external.raft_interface import RAFTInterface |
|
|
|
class SegmentExtractor(nn.Module):
    """Probe a CWM model with a random counterfactual motion and measure flow.

    For each call, one "moving" and one "static" position are sampled from a
    sampling distribution, a random integer displacement is drawn, the model
    is asked for the corresponding counterfactual prediction, and the flow
    between the input and that prediction is computed with ``RAFTInterface``.

    ``num_segments`` and ``iters`` are stored on the instance but are not used
    by any code visible in this class.
    """

    def __init__(self, num_segments=1, iters=4, motion_range=4):
        """Store the configuration and build the flow interface.

        Args:
            num_segments: stored as ``self.num_segments`` (unused here).
            iters: stored as ``self.iters`` (unused here).
            motion_range: bound of the random displacement drawn in
                ``forward``; see the note there about the exclusive upper
                bound.
        """
        # nn.Module must be initialised before attributes (in particular
        # sub-modules) are assigned and before the module can be called.
        super().__init__()
        self.num_segments = num_segments
        self.iters = iters
        self.motion_range = motion_range
        self.flow_interface = RAFTInterface()

    def get_sampling_dist(self, x, model):
        """Return the distribution used to sample moving/static positions.

        Not implemented yet: the original body was an empty stub that
        silently returned ``None``, which only failed later in ``forward``
        with an unrelated error. Fail loudly here instead. Subclasses may
        override this method.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "get_sampling_dist is not implemented; pass `sampling_dist` "
            "explicitly to forward() or override this method."
        )

    def forward(self, x, model, sampling_dist=None, mask=None):
        """Run one counterfactual-motion probe and return the resulting flow.

        Args:
            x: batch of imagenet-normalized image tensors. The original
                docstring gives the shape as [B, 3, H, W], but the
                ``x[:, :, 0]`` indexing below suggests an additional
                frame axis (e.g. [B, 3, T, H, W]) -- TODO confirm.
            model: a pre-trained CWM model exposing ``get_counterfactual``.
            sampling_dist: optional distribution handed to
                ``utils.sample_positions_from_dist``. When ``None`` it is
                obtained from ``self.get_sampling_dist(x, model)``.
            mask: forwarded unchanged to ``model.get_counterfactual``.
                NOTE(review): the original code referenced an undefined
                name ``mask``; it is now an explicit optional argument.
                Confirm what the model expects when it is ``None``.

        Returns:
            Whatever ``self.flow_interface`` returns for the pair
            ``(x[:, :, 0], prediction)``.
        """
        # `not tensor` is ambiguous (and raises) for multi-element tensors,
        # so test explicitly for a missing argument.
        if sampling_dist is None:
            sampling_dist = self.get_sampling_dist(x, model)

        batch_size = x.shape[0]

        # One position that will be moved, and one that is kept static,
        # the latter drawn from the complementary distribution.
        moving_pos = utils.sample_positions_from_dist(num=1, dist=sampling_dist)
        static_pos = utils.sample_positions_from_dist(
            num=1, dist=(1 - sampling_dist)
        )

        # torch.randint's upper bound is exclusive, so displacements lie in
        # [-motion_range, motion_range - 1]; kept as in the original.
        # Created on x's device so it can be combined with model inputs.
        movement = torch.randint(
            -self.motion_range,
            self.motion_range,
            (batch_size, 1, 2),
            device=x.device,
        )

        pred = model.get_counterfactual(
            x,
            mask,
            moving_pos=moving_pos,
            static_pos=static_pos,
            movement=movement,
        )

        # The original computed the flow and then discarded it; return it so
        # callers can use the result.
        flow = self.flow_interface(x[:, :, 0], pred)
        return flow
|
|
|
|
|
|
|
|
|
|
|
|