|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from einops import rearrange, reduce |
|
|
|
|
|
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build an average-pooling layer for a 1D, 2D, or 3D signal.

    :param dims: dimensionality of the signal (1, 2, or 3).
    :param args: positional arguments forwarded to the pooling constructor.
    :param kwargs: keyword arguments forwarded to the pooling constructor.
    :return: an ``nn.AvgPool1d``, ``nn.AvgPool2d`` or ``nn.AvgPool3d`` module.
    :raises ValueError: if ``dims`` is not 1, 2, or 3.
    """
    pool_by_rank = (
        (1, nn.AvgPool1d),
        (2, nn.AvgPool2d),
        (3, nn.AvgPool3d),
    )
    for rank, pool_cls in pool_by_rank:
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
|
def conv_nd(dims, *args, **kwargs):
    """
    Build a convolution layer for a 1D, 2D, or 3D signal.

    :param dims: dimensionality of the signal (1, 2, or 3).
    :param args: positional arguments forwarded to the convolution constructor.
    :param kwargs: keyword arguments forwarded to the convolution constructor.
    :return: an ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` module.
    :raises ValueError: if ``dims`` is not 1, 2, or 3.
    """
    conv_by_rank = (
        (1, nn.Conv1d),
        (2, nn.Conv2d),
        (3, nn.Conv3d),
    )
    for rank, conv_cls in conv_by_rank:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
|
|
|
|
|
class Downsample(nn.Module):
    """
    Halve the resolution of a signal with either a strided convolution or
    average pooling.

    :param channels: number of channels expected in the input.
    :param use_conv: if true, downsample with a kernel-size-3 strided
                     convolution; otherwise use average pooling.
    :param dims: 1, 2 or 3 for a 1D, 2D or 3D signal. For 3D the stride is
                 (1, 2, 2), so only the last two dimensions are reduced.
    :param out_channels: channels produced by the layer; falls back to
                 ``channels`` when omitted. Must equal ``channels`` when
                 pooling is used.
    :param padding: padding handed to the convolution (unused for pooling).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims

        # 3D signals keep their first non-channel axis at full resolution.
        stride = (1, 2, 2) if dims == 3 else 2

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
        else:
            self.op = conv_nd(
                dims,
                self.channels,
                self.out_channels,
                3,
                stride=stride,
                padding=padding,
            )

    def forward(self, x):
        """
        Downsample ``x``; its channel axis (dim 1) must match ``channels``.
        """
        incoming_channels = x.shape[1]
        assert incoming_channels == self.channels
        return self.op(x)
|
|
|
|
|
class ResnetBlock(nn.Module):
    """
    Residual block built from two (BatchNorm -> ReLU -> Conv) stages plus a
    shortcut connection.

    :param in_c: channels of the incoming feature map.
    :param out_c: channels of the produced feature map.
    :param down: when equal to ``True``, a ``Downsample`` layer on ``in_c``
                 channels is applied before anything else.
    :param ksize: kernel size of the input, second and shortcut convolutions
                  (the first body convolution is always 3x3).
    :param sk: when equal to ``False`` the shortcut passes through its own
               convolution; otherwise the shortcut is the identity.
    :param use_conv: forwarded to ``Downsample``.
    """

    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        pad = ksize // 2

        # ``sk`` is deliberately compared by equality rather than truthiness so
        # that non-bool arguments keep the behaviour they always had.
        conv_shortcut = sk == False  # noqa: E712

        # NOTE: sub-modules are created in a fixed order; it determines both
        # the state_dict layout and the order in which random init is drawn.
        needs_input_conv = in_c != out_c or conv_shortcut
        self.in_conv = (
            nn.Conv2d(in_c, out_c, kernel_size=ksize, stride=1, padding=pad)
            if needs_input_conv
            else None
        )
        self.block1 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(out_c, out_c, kernel_size=ksize, stride=1, padding=pad)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.skep = (
            nn.Conv2d(out_c, out_c, kernel_size=ksize, stride=1, padding=pad)
            if conv_shortcut
            else None
        )

        self.down = down
        if down == True:  # noqa: E712 - equality semantics kept on purpose
            self.down_opt = Downsample(in_c, use_conv=use_conv)

    def forward(self, x):
        """
        Apply optional downsampling and input projection, then return the
        residual sum of the two-stage body and the shortcut.
        """
        if self.down == True:  # noqa: E712
            x = self.down_opt(x)
        if self.in_conv is not None:
            x = self.in_conv(x)

        h = self.block1(self.act(self.bn1(x)))
        h = self.block2(self.act(self.bn2(h)))

        shortcut = x if self.skep is None else self.skep(x)
        return h + shortcut
|
|
|
|
|
class VAESpatialEmulator(nn.Module):
    """
    Fold non-overlapping spatial patches into the channel axis.

    :param kernel_size: ``(height, width)`` of each patch.
    """

    def __init__(self, kernel_size=(8, 8)):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        """
        x: torch.Tensor: shape [B C T H W]

        Returns a tensor of shape [B, Hp*Wp*C, T, H // Hp, W // Wp], where
        (Hp, Wp) is ``kernel_size``. Trailing rows/columns that do not fill a
        whole patch are discarded.
        """
        patch_h, patch_w = self.kernel_size
        height = x.shape[-2]
        width = x.shape[-1]

        # Largest extent that is a whole number of patches.
        rows = height // patch_h * patch_h
        cols = width // patch_w * patch_w
        cropped = x[..., :rows, :cols]

        return rearrange(
            cropped,
            "B C T (Nh Hp) (Nw Wp) -> B (Hp Wp C) T Nh Nw",
            Hp=patch_h,
            Wp=patch_w,
        )
|
|
|
|
|
class VAETemporalEmulator(nn.Module):
    """
    Reduce the frame axis chunk by chunk.

    The input is cut into chunks of ``micro_frame_size`` frames. Within each
    chunk the first frame is kept as is and the remaining frames are averaged
    in consecutive groups of ``kernel_size``; frames left over after the last
    full group are dropped.

    :param micro_frame_size: number of frames per chunk.
    :param kernel_size: number of frames averaged into one output frame.
    """

    def __init__(self, micro_frame_size, kernel_size=4):
        super().__init__()
        self.micro_frame_size = micro_frame_size
        self.kernel_size = kernel_size

    def forward(self, x_z):
        """
        x_z: torch.Tensor: shape [B C T H W]

        Returns the per-chunk results concatenated along the frame axis.
        """
        chunk_len = self.micro_frame_size
        group = self.kernel_size
        total_frames = x_z.shape[2]

        pieces = []
        for start in range(0, total_frames, chunk_len):
            chunk = x_z[:, :, start : start + chunk_len]
            first_frame = chunk[:, :, 0:1]
            rest = chunk[:, :, 1:]

            # Keep only as many frames as fill whole averaging groups.
            usable = rest.shape[2] // group * group
            pooled = reduce(
                rest[:, :, :usable],
                "B C (T n) H W -> B C T H W",
                n=group,
                reduction="mean",
            )
            pieces += [first_frame, pooled]

        return torch.cat(pieces, dim=2)
|
|
|
|
|
class TrajExtractor(nn.Module):
    """
    Frame-wise multi-scale feature extractor for a video-shaped tensor.

    The input is padded to a multiple of the patch size, every frame is
    pixel-unshuffled, projected by ``conv_in`` and then passed through
    ``len(channels)`` stages of ``nums_rb`` residual blocks each. The output
    of every stage is collected.

    :param vae_downsize: stored on the instance; not used by this class itself.
    :param patch_size: spatial pixel-unshuffle factor.
    :param channels: output width of each stage.
    :param nums_rb: number of ``ResnetBlock`` modules per stage.
    :param cin: channels of the input before pixel-unshuffling.
    :param ksize: kernel size forwarded to every ``ResnetBlock``.
    :param sk: forwarded to every ``ResnetBlock``.
    :param use_conv: forwarded to every ``ResnetBlock``.
    """

    def __init__(
        self,
        vae_downsize=(4, 8, 8),
        patch_size=2,
        channels=(320, 640, 1280, 1280),
        nums_rb=3,
        cin=2,
        ksize=3,
        sk=False,
        use_conv=True,
    ):
        super().__init__()
        self.vae_downsize = vae_downsize

        self.downsize_patchify = nn.PixelUnshuffle(patch_size)
        self.patch_size = (1, patch_size, patch_size)
        # Keep a private list so instances never alias the caller's sequence
        # (or a shared default) through ``self.channels``.
        self.channels = list(channels)
        self.nums_rb = nums_rb

        blocks = []
        for stage, width in enumerate(self.channels):
            for j in range(nums_rb):
                # Only the first block of every stage after the first changes
                # the channel count; all other blocks keep the stage width.
                widens = stage != 0 and j == 0
                in_width = self.channels[stage - 1] if widens else width
                blocks.append(
                    ResnetBlock(
                        in_width,
                        width,
                        down=False,
                        ksize=ksize,
                        sk=sk,
                        use_conv=use_conv,
                    )
                )
        self.body = nn.ModuleList(blocks)

        # Pixel-unshuffling multiplies the channel count by patch_size**2.
        cin_ = cin * patch_size**2
        self.conv_in = nn.Conv2d(cin_, self.channels[0], 3, 1, 1)

        def conv_init(module):
            # Kaiming-normal weights and zero biases for every convolution.
            if isinstance(module, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(conv_init)

    def forward(self, x):
        """
        x: torch.Tensor: shape [B C T H W]

        Returns a list with one tensor per stage (``len(channels)`` entries),
        each laid out as [(B T), channels[i], H', W'].
        """
        T, H, W = x.shape[-3:]
        p_t, p_h, p_w = self.patch_size
        # (-n) % p is the zero padding needed to reach the next multiple of p
        # (0 when n is already aligned). Padding goes at the end of each axis.
        pad_t, pad_h, pad_w = (-T) % p_t, (-H) % p_h, (-W) % p_w
        if pad_t or pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_t))

        # Treat every frame as an independent image.
        x = rearrange(x, "B C T H W -> (B T) C H W")
        x = self.downsize_patchify(x)

        features = []
        x = self.conv_in(x)
        for stage in range(len(self.channels)):
            for j in range(self.nums_rb):
                x = self.body[stage * self.nums_rb + j](x)
            features.append(x)

        return features
|
|
|
|
|
class FloatGroupNorm(nn.GroupNorm):
    """
    GroupNorm that runs in the dtype of its affine parameters and returns the
    result in the dtype of the input.

    This lets a layer whose parameters are kept in one precision be fed
    activations of another precision without the caller casting by hand.
    """

    def forward(self, x):
        """
        Normalise ``x`` and return a tensor with the same dtype as ``x``.
        """
        in_dtype = x.dtype
        # With ``affine=False`` the layer has no bias (it is ``None``), so
        # there is no parameter dtype to follow; normalise in the input dtype.
        compute_dtype = in_dtype if self.bias is None else self.bias.dtype
        return super().forward(x.to(compute_dtype)).to(in_dtype)
|
|
|
|
|
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and hand the module back.

    :param module: the module whose parameters are zeroed.
    :return: the same ``module`` object.
    """
    # The in-place write is kept out of autograd.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
|
|
|
|
class MGF(nn.Module):
    """
    Modulate hidden states with scale/shift maps predicted from a flow tensor:
    ``h + norm(h) * gamma + beta``.

    ``gamma`` and ``beta`` are each produced by a 3x3 spatial convolution
    followed by a kernel-size-3 1D convolution. The 1D convolutions are
    zero-initialised, so a freshly built layer returns ``h`` unchanged.

    :param flow_in_channel: channels of the ``flow`` input.
    :param out_channels: channels of ``h`` and of the predicted maps; the
                         spatial convolutions use ``out_channels // 4``.
    """

    def __init__(self, flow_in_channel=128, out_channels=1152):
        super().__init__()
        self.out_channels = out_channels
        hidden = self.out_channels // 4

        def make_temporal_conv():
            # Zero-initialised 1D convolution: hidden -> out_channels.
            return zero_module(
                nn.Conv1d(
                    hidden,
                    self.out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                )
            )

        # Sub-modules are created in a fixed order (state_dict layout and
        # random-init order depend on it).
        self.flow_gamma_spatial = nn.Conv2d(flow_in_channel, hidden, 3, padding=1)
        self.flow_gamma_temporal = make_temporal_conv()
        self.flow_beta_spatial = nn.Conv2d(flow_in_channel, hidden, 3, padding=1)
        self.flow_beta_temporal = make_temporal_conv()
        self.flow_cond_norm = FloatGroupNorm(32, self.out_channels)

    def forward(self, h, flow, T):
        """
        :param h: hidden states to modulate.
        :param flow: 4D flow features fed to the spatial convolutions, or
                     ``None`` to return ``h`` untouched.
        :param T: number of frames folded into the leading axis of ``flow``;
                  only used when that axis is larger than 1.
        :return: the modulated hidden states.
        """
        if flow is None:
            return h

        gamma_flow = self.flow_gamma_spatial(flow)
        beta_flow = self.flow_beta_spatial(flow)
        _, _, hh, wh = beta_flow.shape

        if gamma_flow.shape[0] == 1:
            # Single leading entry: the 1D convolution runs along the
            # flattened spatial positions.
            to_sequence = "b c h w -> b c (h w)"
            from_sequence = "b c (h w) -> b c h w"
            fold_axes = {}
        else:
            # Leading axis is (batch * frames): the 1D convolution runs along
            # the frame axis at every spatial location.
            to_sequence = "(b f) c h w -> (b h w) c f"
            from_sequence = "(b h w) c f -> (b f) c h w"
            fold_axes = {"f": T}

        gamma_flow = rearrange(gamma_flow, to_sequence, **fold_axes)
        beta_flow = rearrange(beta_flow, to_sequence, **fold_axes)
        gamma_flow = self.flow_gamma_temporal(gamma_flow)
        beta_flow = self.flow_beta_temporal(beta_flow)
        gamma_flow = rearrange(gamma_flow, from_sequence, h=hh, w=wh)
        beta_flow = rearrange(beta_flow, from_sequence, h=hh, w=wh)

        return h + self.flow_cond_norm(h) * gamma_flow + beta_flow
|
|