import kornia
import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

import cwm.eval.Flow.masking_flow as masking


def boltzmann(x, beta=1, eps=1e-9):
    """Sharpen a spatial map by exponentiating with inverse temperature beta,
    normalizing so the per-map maximum is 1. beta=None is a no-op."""
    if beta is None:
        return x
    x = torch.exp(x * beta)
    return x / x.amax((-1, -2), keepdim=True).clamp(min=eps)
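

# A small illustrative sketch (an addition, not original code): larger beta
# sharpens the map toward its argmax, and the output maximum is normalized to 1.
def _example_boltzmann():
    x = torch.tensor([[0.1, 0.2], [0.3, 0.4]]).view(1, 1, 2, 2)
    out = boltzmann(x, beta=10.0)
    assert torch.isclose(out.amax((-1, -2)), torch.tensor(1.0)).all()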


IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def imagenet_normalize(x, temporal_dim=1):
    """Normalize with ImageNet statistics.

    Expects [B, T, C, H, W] for temporal_dim=1 (default) or [B, C, T, H, W] for temporal_dim=2.
    """
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(x)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(x)[None, None, :, None, None]
    if temporal_dim == 2:
        mean = mean.transpose(1, 2)
        std = std.transpose(1, 2)
    return (x - mean) / std


def imagenet_unnormalize(x, temporal_dim=2):
    """Invert imagenet_normalize; expects [B, C, T, H, W] for the default temporal_dim=2."""
    mean = torch.as_tensor(IMAGENET_DEFAULT_MEAN).to(x)[None, None, :, None, None]
    std = torch.as_tensor(IMAGENET_DEFAULT_STD).to(x)[None, None, :, None, None]
    if temporal_dim == 2:
        mean = mean.transpose(1, 2)
        std = std.transpose(1, 2)
    return x * std + mean
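

# A minimal round-trip sketch (an addition for illustration, not original code):
# unnormalizing with the matching temporal_dim recovers the input.
def _example_imagenet_roundtrip():
    video = torch.rand(2, 3, 3, 32, 32)  # [B, T, C, H, W]
    normed = imagenet_normalize(video, temporal_dim=1)
    recovered = imagenet_unnormalize(normed, temporal_dim=1)
    assert torch.allclose(video, recovered, atol=1e-6)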


def coordinate_ims(batch_size, seq_length, imsize, normalize=True, dtype_out=torch.float32):
    """Return a [B, T, H, W, 2] grid of (h, w) coordinates, in [-1, 1] if normalize=True,
    else integer pixel indices. If seq_length == 0, the temporal dimension is squeezed
    and a [B, H, W, 2] grid is returned."""
    static = False
    if seq_length == 0:
        static = True
        seq_length = 1
    B = batch_size
    T = seq_length
    H, W = imsize
    ones = torch.ones([B, H, W, 1], dtype=dtype_out)
    if normalize:
        h = torch.divide(torch.arange(H).to(ones), torch.tensor(H - 1, dtype=dtype_out))
        h = 2.0 * ((h.view(1, H, 1, 1) * ones) - 0.5)
        w = torch.divide(torch.arange(W).to(ones), torch.tensor(W - 1, dtype=dtype_out))
        w = 2.0 * ((w.view(1, 1, W, 1) * ones) - 0.5)
    else:
        h = torch.arange(H).to(ones).view(1, H, 1, 1) * ones
        w = torch.arange(W).to(ones).view(1, 1, W, 1) * ones
    h = torch.stack([h] * T, 1)
    w = torch.stack([w] * T, 1)
    hw_ims = torch.cat([h, w], -1)
    if static:
        hw_ims = hw_ims[:, 0]
    return hw_ims
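

# A minimal sketch (not part of the original module) showing the output shapes:
# normalized grids span [-1, 1]; unnormalized grids hold integer pixel indices.
def _example_coordinate_ims():
    grid = coordinate_ims(2, 4, (8, 8))  # [2, 4, 8, 8, 2], values in [-1, 1]
    assert grid.shape == (2, 4, 8, 8, 2)
    static = coordinate_ims(2, 0, (8, 8), normalize=False)  # [2, 8, 8, 2], pixel indices
    assert static.shape == (2, 8, 8, 2)
    assert static[..., 0].max() == 7 and static[..., 1].max() == 7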


def get_distribution_centroid(dist, eps=1e-9, normalize=False):
    """Compute the spatial centroid of a [B, T, 1, H, W] probability map in (h, w) coordinates."""
    B, T, C, H, W = dist.shape
    assert C == 1
    dist_sum = dist.sum((-2, -1), keepdim=True).clamp(min=eps)
    dist = dist / dist_sum

    grid = coordinate_ims(B, T, [H, W], normalize=normalize).to(dist.device)
    grid = grid.permute(0, 1, 4, 2, 3)
    centroid = (grid * dist).sum((-2, -1))
    return centroid
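

# A small sanity-check sketch (an addition, not original code): the centroid of
# a delta distribution recovers the coordinates of its nonzero pixel.
def _example_get_distribution_centroid():
    dist = torch.zeros(1, 1, 1, 8, 8)
    dist[0, 0, 0, 3, 5] = 1.0
    centroid = get_distribution_centroid(dist, normalize=False)  # [B, T, 2]
    assert torch.allclose(centroid, torch.tensor([[[3.0, 5.0]]]))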


class FlowToRgb(object):
    """Render a 2-channel flow field as an RGB image via the standard HSV flow wheel."""

    def __init__(self, max_speed=1.0, from_image_coordinates=True, from_sampling_grid=False):
        self.max_speed = max_speed
        self.from_image_coordinates = from_image_coordinates
        self.from_sampling_grid = from_sampling_grid

    def __call__(self, flow):
        assert flow.size(-3) == 2, flow.shape
        if self.from_sampling_grid:
            flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
            flow_y = -flow_y
        elif not self.from_image_coordinates:
            flow_x, flow_y = torch.split(flow, [1, 1], dim=-3)
        else:
            flow_h, flow_w = torch.split(flow, [1, 1], dim=-3)
            flow_x, flow_y = [flow_w, -flow_h]

        angle = torch.atan2(flow_y, flow_x)
        speed = torch.sqrt(flow_x ** 2 + flow_y ** 2) / self.max_speed

        # torch.remainder (unlike torch.fmod) maps negative angles into [0, 2*pi),
        # the hue range kornia's hsv_to_rgb expects.
        hue = torch.remainder(angle, 2 * np.pi)
        sat = torch.ones_like(hue)
        val = speed

        hsv = torch.cat([hue, sat, val], -3)
        rgb = kornia.color.hsv_to_rgb(hsv)
        return rgb
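

# A brief usage sketch (added for illustration; the constants are arbitrary):
# uniform rightward motion maps to a single hue, with brightness proportional
# to speed relative to max_speed.
def _example_flow_to_rgb():
    flow = torch.zeros(1, 2, 16, 16)
    flow[:, 1] = 2.0  # uniform rightward flow in (h, w) image coordinates
    rgb = FlowToRgb(max_speed=4.0)(flow)
    assert rgb.shape == (1, 3, 16, 16)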


class Patchify(nn.Module):
    """Convert a set of images or a movie into patch vectors."""

    def __init__(self,
                 patch_size=(16, 16),
                 temporal_dim=1,
                 squeeze_channel_dim=True):
        super().__init__()
        self.set_patch_size(patch_size)
        self.temporal_dim = temporal_dim
        assert self.temporal_dim in [1, 2], self.temporal_dim
        self._squeeze_channel_dim = squeeze_channel_dim

    @property
    def num_patches(self):
        if (self.T is None) or (self.H is None) or (self.W is None):
            return None
        return (self.T // self.pt) * (self.H // self.ph) * (self.W // self.pw)

    def set_patch_size(self, patch_size):
        self.patch_size = patch_size
        if len(self.patch_size) == 2:
            self.ph, self.pw = self.patch_size
            self.pt = 1
            self._patches_are_3d = False
        elif len(self.patch_size) == 3:
            self.pt, self.ph, self.pw = self.patch_size
            self._patches_are_3d = True
        else:
            # wrap in a 1-tuple so %-formatting doesn't unpack the patch_size sequence
            raise ValueError("patch_size must be a 2- or 3-tuple, but is %s" % (self.patch_size,))

        self.shape_inp = self.rank_inp = self.H = self.W = self.T = None
        self.D = self.C = self.E = self.embed_dim = None

    def _check_shape(self, x):
        self.shape_inp = x.shape
        self.rank_inp = len(self.shape_inp)
        self.H, self.W = self.shape_inp[-2:]
        assert (self.H % self.ph) == 0 and (self.W % self.pw) == 0, (self.shape_inp, self.patch_size)
        if (self.rank_inp == 5) and self._patches_are_3d:
            self.T = self.shape_inp[self.temporal_dim]
            assert (self.T % self.pt) == 0, (self.T, self.pt)
        elif self.rank_inp == 5:
            self.T = self.shape_inp[self.temporal_dim]
        else:
            self.T = 1

    def split_by_time(self, x):
        shape = x.shape
        assert shape[1] % self.T == 0, (shape, self.T)
        return x.view(shape[0], self.T, shape[1] // self.T, *shape[2:])

    def merge_by_time(self, x):
        shape = x.shape
        return x.view(shape[0], shape[1] * shape[2], *shape[3:])

    def video_to_patches(self, x):
        if self.rank_inp == 4:
            assert self.pt == 1, (self.pt, x.shape)
            x = rearrange(x, 'b c (h ph) (w pw) -> b (h w) (ph pw) c', ph=self.ph, pw=self.pw)
        else:
            assert self.rank_inp == 5, (x.shape, self.rank_inp, self.shape_inp)
            dim_order = 'b (t pt) c (h ph) (w pw)' if self.temporal_dim == 1 else 'b c (t pt) (h ph) (w pw)'
            x = rearrange(x, dim_order + ' -> b (t h w) (pt ph pw) c', pt=self.pt, ph=self.ph, pw=self.pw)

        self.N, self.D, self.C = x.shape[-3:]
        self.embed_dim = self.E = self.D * self.C
        return x

    def patches_to_video(self, x):
        shape = x.shape
        rank = len(shape)
        if rank == 4:
            B, _N, _D, _C = shape
        else:
            assert rank == 3, rank
            B, _N, _E = shape
            assert (_E % self.D == 0), (_E, self.D)
            x = x.view(B, _N, self.D, -1)

        if _N < self.num_patches:
            masked_patches = self.get_masked_patches(
                x,
                num_patches=(self.num_patches - _N),
                mask_mode=self.mask_mode)
            x = torch.cat([x, masked_patches], 1)

        x = rearrange(
            x,
            'b (t h w) (pt ph pw) c -> b c (t pt) (h ph) (w pw)',
            pt=self.pt, ph=self.ph, pw=self.pw,
            t=(self.T // self.pt), h=(self.H // self.ph), w=(self.W // self.pw))

        if self.rank_inp == 5 and (self.temporal_dim == 1):
            x = x.transpose(1, 2)
        elif self.rank_inp == 4:
            assert x.shape[2] == 1, x.shape
            x = x[:, :, 0]
        return x

    @staticmethod
    def get_masked_patches(x, num_patches, mask_mode='zeros'):
        shape = x.shape
        patches_shape = (shape[0], num_patches, *shape[2:])
        if mask_mode == 'zeros':
            return torch.zeros(patches_shape).to(x.device).to(x.dtype).detach()
        elif mask_mode == 'gray':
            return 0.5 * torch.ones(patches_shape).to(x.device).to(x.dtype).detach()
        else:
            raise NotImplementedError("Haven't implemented mask_mode == %s" % mask_mode)

    def average_within_patches(self, z):
        if len(z.shape) == 3:
            z = rearrange(z, 'b n (d c) -> b n d c', c=self.C)
        return z.mean(-2, True).repeat(1, 1, z.shape[-2], 1)

    def forward(self, x, to_video=False, mask_mode='zeros'):
        if not to_video:
            self._check_shape(x)
            x = self.video_to_patches(x)
            return x if not self._squeeze_channel_dim else x.view(x.size(0), self.N, -1)
        else:
            assert (self.shape_inp is not None) and (self.num_patches is not None)
            self.mask_mode = mask_mode
            return self.patches_to_video(x)
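

# A round-trip sketch (added for illustration): patchify a 2-frame clip and
# reconstruct it; with no patches dropped, the reconstruction is exact.
def _example_patchify_roundtrip():
    patchify = Patchify(patch_size=(1, 8, 8), temporal_dim=1, squeeze_channel_dim=True)
    video = torch.rand(2, 2, 3, 32, 32)  # [B, T, C, H, W]
    patches = patchify(video)            # [B, N, pt*ph*pw*C]
    assert patches.shape == (2, 2 * 4 * 4, 8 * 8 * 3)
    recon = patchify(patches, to_video=True)
    assert torch.allclose(video, recon)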


class DerivativeFlowGenerator(nn.Module):
    """Estimate the flow of a two-frame predictor using torch autograd."""

    def __init__(self,
                 predictor,
                 perturbation_patch_size=None,
                 aggregation_patch_size=None,
                 agg_power=None,
                 agg_channel_func=None,
                 num_samples=1,
                 leave_one_out_sampling=False,
                 average_jacobian=True,
                 confidence_thresh=None,
                 temporal_dim=2,
                 imagenet_normalize_inputs=True):
        super().__init__()

        self.predictor = predictor
        self.patchify = Patchify(self.patch_size, temporal_dim=1, squeeze_channel_dim=True)
        self.set_temporal_dim(temporal_dim)
        self.imagenet_normalize_inputs = imagenet_normalize_inputs

        self.perturbation_patch_size = self._get_patch_size(perturbation_patch_size) or self.patch_size
        self.aggregation_patch_size = self._get_patch_size(aggregation_patch_size) or self.patch_size
        self.agg_patchify = Patchify(self.aggregation_patch_size,
                                     temporal_dim=1,
                                     squeeze_channel_dim=False)
        self.agg_channel_func = agg_channel_func or (lambda x: F.relu(x).sum(-3, True))
        self.average_jacobian = average_jacobian
        self.confidence_thresh = confidence_thresh

        self.num_samples = num_samples
        self.leave_one_out_sampling = leave_one_out_sampling
        self.agg_power = agg_power
        self.t_dim = temporal_dim

    def _get_patch_size(self, p):
        if p is None:
            return None
        elif isinstance(p, int):
            return (1, p, p)
        elif len(p) == 2:
            return (1, p[0], p[1])
        else:
            assert len(p) == 3, p
            return (p[0], p[1], p[2])

    def set_temporal_dim(self, t_dim):
        if t_dim == 1:
            self.predictor.t_dim = 1
            self.predictor.c_dim = 2
        elif t_dim == 2:
            self.predictor.c_dim = 1
            self.predictor.t_dim = 2
        else:
            raise ValueError("temporal_dim must be 1 or 2")

    @property
    def c_dim(self):
        if self.predictor is None:
            return None
        return self.predictor.c_dim

    @property
    def patch_size(self):
        if self.predictor is None:
            return None
        elif hasattr(self.predictor, 'patch_size'):
            return self.predictor.patch_size
        elif hasattr(self.predictor.encoder.patch_embed, 'proj'):
            return self.predictor.encoder.patch_embed.proj.kernel_size
        else:
            return None

    @property
    def S(self):
        return self.num_samples

    @property
    def sequence_length(self):
        if self.predictor is None:
            return None
        elif hasattr(self.predictor, 'sequence_length'):
            return self.predictor.sequence_length
        elif hasattr(self.predictor, 'num_frames'):
            return self.predictor.num_frames
        else:
            return 2

    @property
    def mask_shape(self):
        if self.predictor is None:
            return None
        elif hasattr(self.predictor, 'mask_shape'):
            return self.predictor.mask_shape

        assert self.patch_size is not None
        pt, ph, pw = self.patch_size
        return (self.sequence_length // pt,
                self.inp_shape[-2] // ph,
                self.inp_shape[-1] // pw)

    @property
    def perturbation_mask_shape(self):
        return (
            self.mask_shape[0],
            self.inp_shape[-2] // self.perturbation_patch_size[-2],
            self.inp_shape[-1] // self.perturbation_patch_size[-1]
        )

    @property
    def p_mask_shape(self):
        return self.perturbation_mask_shape

    @property
    def aggregation_mask_shape(self):
        return (
            1,
            self.inp_shape[-2] // self.aggregation_patch_size[-2],
            self.inp_shape[-1] // self.aggregation_patch_size[-1]
        )

    @property
    def a_mask_shape(self):
        return self.aggregation_mask_shape

    def get_perturbation_input(self, x):
        self.set_input(x)
        y = torch.zeros((self.B, *self.p_mask_shape), dtype=x.dtype, device=x.device, requires_grad=True)
        y = y.unsqueeze(2).repeat(1, 1, x.shape[2], 1, 1)
        return y

    def pred_patches_to_video(self, y, x, mask):
        """Place the input at visible positions and predictions at masked positions."""
        B, C = y.shape[0], y.shape[-1]
        self.patchify._check_shape(x)
        self.patchify.D = np.prod(self.patch_size)
        x = self.patchify(x)
        y_out = torch.zeros_like(x)
        x_vis = x[~mask]

        y_out[~mask] = x_vis.view(-1, C)
        # reshape handles the non-contiguous case that view cannot
        y_out[mask] = y.reshape(-1, C)

        return self.patchify(y_out, to_video=True)

    def set_image_size(self, *args, **kwargs):
        assert self.predictor is not None, "Can't set the image size without a predictor"
        if hasattr(self.predictor, 'set_image_size'):
            self.predictor.set_image_size(*args, **kwargs)
        else:
            self.predictor.image_size = args[0]

    def predict(self, x=None, mask=None, forward_full=False):
        if x is None:
            x = self.x
        if mask is None:
            mask = self.generate_mask(x)

        self.set_image_size(x.shape[-2:])
        y = self.predictor(
            self._preprocess(x),
            mask if (x.size(0) == 1) else self.mask_rectangularizer(mask),
            forward_full=forward_full)

        y = self.pred_patches_to_video(y, x, mask=mask)

        # keep only the last predicted frame
        return y[:, -1:]

    def _get_perturbation_func(self, x=None, mask=None):
        if x is not None:
            self.set_input(x, mask)

        def forward_mini_image(y):
            # upsample the patch-level perturbation to pixel resolution, add it
            # to the input, and aggregate the prediction within aggregation patches
            y = y.repeat_interleave(self.perturbation_patch_size[-2], -2)
            y = y.repeat_interleave(self.perturbation_patch_size[-1], -1)
            x_pred = self.predict(self.x + y, self.mask)
            x_pred = self.agg_patchify(x_pred).mean(-2).sum(-1).view(self.B, *self.a_mask_shape)
            return x_pred[self.targets]

        return forward_mini_image

    def _postprocess_jacobian(self, jac):
        _jac = torch.zeros((self.B, *self.a_mask_shape, *jac.shape[1:])).to(jac.device).to(jac.dtype)
        _jac[self.targets] = jac
        jac = self.agg_channel_func(_jac)
        assert jac.size(-3) == 1, jac.shape
        jac = jac.squeeze(-3)[..., 0, :, :]
        jac = jac.view(self.B, self.a_mask_shape[-2], self.a_mask_shape[-1],
                       self.B, self.p_mask_shape[-2], self.p_mask_shape[-1])
        # keep only the diagonal batch-to-batch blocks of the jacobian
        bs = torch.arange(0, self.B).long().to(jac.device)
        jac = jac[bs, :, :, bs, :, :]
        return jac

    def _confident_jacobian(self, jac):
        if self.confidence_thresh is None:
            return torch.ones_like(jac[:, None, ..., 0, 0])
        conf = (jac.amax((-2, -1)) > self.confidence_thresh).float()[:, None]
        return conf

    def set_input(self, x, mask=None, timestamps=None):
        shape = x.shape
        if len(shape) == 4:
            x = x.unsqueeze(1)
        else:
            assert len(shape) == 5, \
                "Input must be a movie of shape [B,T,C,H,W] " + \
                "or a single frame of shape [B,C,H,W]"

        self.inp_shape = x.shape
        self.x = x
        self.B = self.inp_shape[0]
        self.T = self.inp_shape[1]
        self.C = self.inp_shape[2]
        if mask is not None:
            self.mask = mask

        if timestamps is not None:
            self.timestamps = timestamps

    def _preprocess(self, x):
        if self.imagenet_normalize_inputs:
            x = imagenet_normalize(x)
        if self.t_dim != 1:
            x = x.transpose(self.t_dim, self.c_dim)
        return x

    def _jacobian_to_flows(self, jac):
        # either take the argmax of the jacobian or sharpen it with a power
        if self.agg_power is None:
            jac = (jac == jac.amax((-2, -1), True)).float()
        else:
            jac = torch.pow(jac, self.agg_power)

        # flow = centroid of the (sharpened) jacobian minus the pixel coordinates,
        # scaled from patch units back to pixels
        jac = jac.view(self.B * np.prod(self.a_mask_shape[-2:]), 1, 1, *self.p_mask_shape[-2:])
        centroids = get_distribution_centroid(jac, normalize=False).view(
            self.B, self.a_mask_shape[-2], self.a_mask_shape[-1], 2)
        rescale = [self.a_mask_shape[-2] / self.p_mask_shape[-2],
                   self.a_mask_shape[-1] / self.p_mask_shape[-1]]
        centroids = centroids * torch.tensor(rescale, device=centroids.device).view(1, 1, 1, 2)

        flows = centroids - \
            coordinate_ims(1, 0, self.a_mask_shape[-2:], normalize=False).to(jac.device)
        flows = flows.permute(0, 3, 1, 2)
        px_scale = torch.tensor(self.aggregation_patch_size[-2:]).float().to(flows.device).view(1, 2, 1, 1)
        flows *= px_scale

        return flows

    def set_targets(self, targets=None, frame=-1):
        frame = frame % self.mask_shape[0]
        if targets is None:
            targets = self.get_mask_image(self.mask)[:, frame:frame + 1]
        else:
            assert len(targets.shape) == 4, targets.shape
            targets = targets[:, frame:frame + 1]
        self.targets = ~masking.upsample_masks(~targets, self.a_mask_shape[-2:])

    def _get_mask_partition(self, mask):
        mask = self.get_mask_image(mask)
        mask_list = masking.partition_masks(
            mask[:, 1:], num_samples=self.S, leave_one_out=self.leave_one_out_sampling)
        return [torch.cat([mask[:, 0:1].view(m.size(0), -1), m], -1)
                for m in mask_list]

    def _compute_jacobian(self, y):
        perturbation_func = self._get_perturbation_func()
        jac = torch.autograd.functional.jacobian(
            perturbation_func,
            y,
            vectorize=False)
        jac = self._postprocess_jacobian(jac)
        return jac

    def _upsample_mask(self, mask):
        return masking.upsample_masks(
            mask.view(mask.size(0), -1, *self.mask_shape[-2:]).float(), self.inp_shape[-2:])

    def get_mask_image(self, mask, upsample=False, invert=False, shape=None):
        if shape is None:
            shape = self.mask_shape
        mask = mask.view(-1, *shape)
        if upsample:
            mask = self._upsample_mask(mask)
        if invert:
            mask = 1 - mask
        return mask

    def forward(self, x, mask, targets=None):
        self.set_input(x, mask)
        y = self.get_perturbation_input(x)
        mask_list = self._get_mask_partition(mask)

        jacobian, flows, confident = [], [], []
        for mask_sample in mask_list:
            self.set_input(x, mask_sample)
            self.set_targets(targets)

            jac = self._compute_jacobian(y)
            conf_jac = masking.upsample_masks(self._confident_jacobian(jac), self.a_mask_shape[-2:])
            jacobian.append(jac)
            confident.append(conf_jac)
            if not self.average_jacobian:
                flow = self._jacobian_to_flows(jac) * self.targets * conf_jac * \
                    masking.upsample_masks(self.get_mask_image(self.mask)[:, 1:], self.a_mask_shape[-2:])
                flows.append(flow)

        jacobian = torch.stack(jacobian, -1)
        confident = torch.stack(confident, -1)
        valid = torch.stack([masking.upsample_masks(
            self.get_mask_image(m)[:, 1:], self.a_mask_shape[-2:]) for m in mask_list], -1)
        valid = valid * confident

        if self.average_jacobian:
            # average the jacobian over mask samples, then convert it to flow
            _valid = valid[:, 0].unsqueeze(-2).unsqueeze(-2)
            jac = (jacobian * _valid.float()).sum(-1) / _valid.float().sum(-1).clamp(min=1)
            flows = self._jacobian_to_flows(jac) * \
                masking.upsample_masks(_valid[:, None, ..., 0, 0, :].amax(-1).bool(), self.a_mask_shape[-2:])
            if targets is not None:
                self.set_targets(targets)
                flows *= self.targets
        else:
            # average the per-sample flows instead
            flows = torch.stack(flows, -1)
            flows = flows.sum(-1) / valid.float().sum(-1).clamp(min=1)

        valid = valid * (targets[:, -1:].unsqueeze(-1) if targets is not None else 1)

        return (jacobian, flows, valid)
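

# An end-to-end sketch (an illustrative addition, not original code). The toy
# predictor below is a hypothetical stand-in satisfying the interface this
# module assumes (a `patch_size` attribute, `num_frames`, and a call signature
# of (frames, mask, forward_full)); it predicts gray patches, so the recovered
# "flow" is meaningless and the point is only to show the plumbing. It also
# assumes masking.partition_masks / upsample_masks behave as used above.
class _ToyPredictor(nn.Module):
    patch_size = (1, 8, 8)
    num_frames = 2

    def forward(self, x, mask, forward_full=False):
        # one gray patch vector (pt*ph*pw*C = 192 dims) per masked position
        num_masked = int(mask.sum())
        return 0.5 * torch.ones(num_masked, 8 * 8 * 3, device=x.device, dtype=x.dtype)


def _example_derivative_flow_generator():
    generator = DerivativeFlowGenerator(_ToyPredictor(), num_samples=1)
    video = torch.rand(1, 2, 3, 32, 32)  # [B, T, C, H, W]
    # boolean mask over the 2*4*4 = 32 patch positions: frame 0 visible, frame 1 masked
    mask = torch.cat([torch.zeros(1, 16, dtype=torch.bool),
                      torch.ones(1, 16, dtype=torch.bool)], -1)
    jacobian, flows, valid = generator(video, mask)
    return jacobian, flows, valid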