# jaxmetaverse's picture
# Upload folder using huggingface_hub
# 82ea528 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce
def avg_pool_nd(dims, *args, **kwargs):
    """
    Build the average-pooling layer that matches a spatial rank.

    :param dims: number of spatial dimensions (1, 2 or 3).
    :param args: positional arguments forwarded to the pooling constructor.
    :param kwargs: keyword arguments forwarded to the pooling constructor.
    :return: an ``nn.AvgPool1d``, ``nn.AvgPool2d`` or ``nn.AvgPool3d`` instance.
    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    candidates = ((1, nn.AvgPool1d), (2, nn.AvgPool2d), (3, nn.AvgPool3d))
    for rank, pool_cls in candidates:
        if dims == rank:
            return pool_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
def conv_nd(dims, *args, **kwargs):
    """
    Build the convolution layer that matches a spatial rank.

    :param dims: number of spatial dimensions (1, 2 or 3).
    :param args: positional arguments forwarded to the convolution constructor.
    :param kwargs: keyword arguments forwarded to the convolution constructor.
    :return: an ``nn.Conv1d``, ``nn.Conv2d`` or ``nn.Conv3d`` instance.
    :raises ValueError: if ``dims`` is not 1, 2 or 3.
    """
    candidates = ((1, nn.Conv1d), (2, nn.Conv2d), (3, nn.Conv3d))
    for rank, conv_cls in candidates:
        if dims == rank:
            return conv_cls(*args, **kwargs)
    raise ValueError(f"unsupported dimensions: {dims}")
class Downsample(nn.Module):
    """
    Halve the spatial resolution with either a strided convolution or
    average pooling.

    :param channels: number of channels expected in the input.
    :param use_conv: if truthy, downsample with a kernel-3 strided convolution;
                     otherwise use average pooling.
    :param dims: 1, 2 or 3 for 1D/2D/3D signals. For 3D the stride is
                 ``(1, 2, 2)``, so only the last two axes are downsampled.
    :param out_channels: number of output channels; falls back to ``channels``
                         when falsy. Must equal ``channels`` when pooling.
    :param padding: padding passed to the convolution (ignored when pooling).
    """

    def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.dims = dims

        # 3D inputs keep their first spatial axis untouched.
        stride = (1, 2, 2) if dims == 3 else 2

        if not use_conv:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
            return

        self.op = conv_nd(
            dims,
            self.channels,
            self.out_channels,
            3,
            stride=stride,
            padding=padding,
        )

    def forward(self, x):
        """Apply the downsampling op; ``x`` must have ``channels`` on axis 1."""
        assert x.shape[1] == self.channels
        return self.op(x)
class ResnetBlock(nn.Module):
    """
    Residual block: (optional downsample) -> (optional input conv) ->
    BN -> ReLU -> conv -> BN -> ReLU -> conv, added to a shortcut.

    :param in_c: channels of the incoming tensor.
    :param out_c: channels produced by the block.
    :param down: when equal to ``True``, a ``Downsample`` over ``in_c``
                 channels is applied before anything else.
    :param ksize: kernel size of the input conv, the second conv and the
                  shortcut conv (the first conv always uses kernel 3).
    :param sk: when equal to ``False`` the shortcut goes through its own
               convolution; otherwise the shortcut is the identity.
    :param use_conv: forwarded to ``Downsample``.
    """

    def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
        super().__init__()
        pad = ksize // 2
        # Equality (not truthiness) is kept so non-bool flags behave as before.
        conv_shortcut = sk == False  # noqa: E712

        # An input projection is needed whenever the channel count changes,
        # and is also created whenever the shortcut is convolutional.
        if in_c != out_c or conv_shortcut:
            self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, pad)
        else:
            self.in_conv = None

        self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
        self.act = nn.ReLU()
        self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, pad)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.bn2 = nn.BatchNorm2d(out_c)

        # The shortcut conv works on out_c channels because it is applied
        # after the input projection.
        self.skep = nn.Conv2d(out_c, out_c, ksize, 1, pad) if conv_shortcut else None

        self.down = down
        if self.down == True:  # noqa: E712
            self.down_opt = Downsample(in_c, use_conv=use_conv)

    def forward(self, x):
        """Run the block on ``x`` (channels on axis 1) and return the sum of
        the residual branch and the shortcut."""
        if self.down == True:  # noqa: E712
            x = self.down_opt(x)
        if self.in_conv is not None:
            x = self.in_conv(x)

        residual = self.block1(self.act(self.bn1(x)))
        residual = self.block2(self.act(self.bn2(residual)))

        if self.skep is None:
            return residual + x
        return residual + self.skep(x)
class VAESpatialEmulator(nn.Module):
    """
    Fold non-overlapping spatial patches into the channel axis.

    :param kernel_size: ``(Hp, Wp)`` patch height and width.
    """

    def __init__(self, kernel_size=(8, 8)):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, x):
        """
        x: torch.Tensor: shape [B C T H W]

        Trailing rows/columns that do not fill a whole patch are cropped.
        Returns a tensor of shape [B (Hp*Wp*C) T H//Hp W//Wp].
        """
        patch_h, patch_w = self.kernel_size
        height = x.shape[-2]
        width = x.shape[-1]

        # Largest multiples of the patch size that fit into H and W.
        keep_h = (height // patch_h) * patch_h
        keep_w = (width // patch_w) * patch_w
        cropped = x[..., :keep_h, :keep_w]

        return rearrange(
            cropped,
            "B C T (Nh Hp) (Nw Wp) -> B (Hp Wp C) T Nh Nw",
            Hp=patch_h,
            Wp=patch_w,
        )
class VAETemporalEmulator(nn.Module):
    """
    Compress the time axis chunk by chunk: each chunk keeps its first frame
    and mean-pools the remaining frames in groups of ``kernel_size``.

    :param micro_frame_size: number of frames per chunk.
    :param kernel_size: number of consecutive frames averaged together.
    """

    def __init__(self, micro_frame_size, kernel_size=4):
        super().__init__()
        self.micro_frame_size = micro_frame_size
        self.kernel_size = kernel_size

    def forward(self, x_z):
        """
        x_z: torch.Tensor: shape [B C T H W]

        Returns the per-chunk results concatenated along the time axis.
        """
        pieces = []
        total_frames = x_z.shape[2]
        for start in range(0, total_frames, self.micro_frame_size):
            chunk = x_z[:, :, start : start + self.micro_frame_size]
            first_frame = chunk[:, :, 0:1]
            rest = chunk[:, :, 1:]

            # Drop trailing frames that do not fill a complete pooling group.
            usable = (rest.shape[2] // self.kernel_size) * self.kernel_size
            pooled = reduce(
                rest[:, :, :usable],
                "B C (T n) H W -> B C T H W",
                n=self.kernel_size,
                reduction="mean",
            )
            pieces.extend((first_frame, pooled))
        return torch.cat(pieces, dim=2)
class TrajExtractor(nn.Module):
    """
    Multi-stage feature extractor built from a pixel-unshuffle, an input
    convolution, and ``nums_rb`` ``ResnetBlock`` modules per channel stage.

    :param vae_downsize: stored on the instance; not otherwise used here.
    :param patch_size: spatial factor for ``nn.PixelUnshuffle``; H and W are
                       zero-padded up to a multiple of it.
    :param channels: output width of each stage. The sequence is copied, so
                     the caller's object (or the default) is never shared
                     between instances.
    :param nums_rb: number of residual blocks per stage.
    :param cin: channel count of the input tensor.
    :param ksize: kernel size forwarded to every ``ResnetBlock``.
    :param sk: skip-connection flag forwarded to every ``ResnetBlock``.
    :param use_conv: forwarded to every ``ResnetBlock``.
    """

    def __init__(
        self,
        vae_downsize=(4, 8, 8),
        patch_size=2,
        channels=(320, 640, 1280, 1280),
        nums_rb=3,
        cin=2,
        ksize=3,
        sk=False,
        use_conv=True,
    ):
        super().__init__()
        self.vae_downsize = vae_downsize
        self.downsize_patchify = nn.PixelUnshuffle(patch_size)
        self.patch_size = (1, patch_size, patch_size)
        # Defensive copy: a shared mutable default would leak between instances.
        self.channels = list(channels)
        self.nums_rb = nums_rb

        blocks = []
        for stage, width in enumerate(self.channels):
            for j in range(nums_rb):
                # The first block of every stage after the first one changes
                # the width from the previous stage's to the current one.
                widens = stage != 0 and j == 0
                in_width = self.channels[stage - 1] if widens else width
                blocks.append(
                    ResnetBlock(
                        in_width,
                        width,
                        down=False,
                        ksize=ksize,
                        sk=sk,
                        use_conv=use_conv,
                    )
                )
        self.body = nn.ModuleList(blocks)

        # PixelUnshuffle multiplies the channel count by patch_size ** 2.
        self.conv_in = nn.Conv2d(cin * patch_size**2, self.channels[0], 3, 1, 1)

        def conv_init(module):
            # Kaiming-normal weights and zero biases for every conv layer.
            if isinstance(module, (nn.Conv2d, nn.Conv1d)):
                nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(conv_init)

    def forward(self, x):
        """
        x: torch.Tensor: shape [B C T H W]

        Returns a list with one feature map per channel stage, each shaped
        [(B T) C_stage H' W'].
        """
        # Zero-pad the end of T, H and W up to a multiple of the patch size.
        T, H, W = x.shape[-3:]
        pad_t, pad_h, pad_w = (
            -size % patch for size, patch in zip((T, H, W), self.patch_size)
        )
        if pad_t or pad_h or pad_w:
            # F.pad takes (left, right) pairs starting from the last axis.
            x = F.pad(x, (0, pad_w, 0, pad_h, 0, pad_t))

        # Fold frames into the batch so the 2D layers see single images.
        x = rearrange(x, "B C T H W -> (B T) C H W")
        x = self.downsize_patchify(x)

        features = []
        x = self.conv_in(x)
        for stage in range(len(self.channels)):
            for j in range(self.nums_rb):
                x = self.body[stage * self.nums_rb + j](x)
            features.append(x)
        return features
class FloatGroupNorm(nn.GroupNorm):
    """
    GroupNorm that computes in the dtype of its affine parameters and casts
    the result back to the dtype of the input.
    """

    def forward(self, x):
        """
        Normalize ``x`` and return a tensor with ``x``'s original dtype.

        With ``affine=False`` the layer has no weight/bias, so there is no
        parameter dtype to match and ``x`` is normalized in its own dtype.
        """
        reference = self.bias if self.bias is not None else self.weight
        compute_dtype = x.dtype if reference is None else reference.dtype
        return super().forward(x.to(compute_dtype)).type(x.dtype)
def zero_module(module):
    """
    Set every parameter of ``module`` to zero in place and return the same
    module object.
    """
    # Gradient tracking is disabled so the in-place write is not recorded.
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
class MGF(nn.Module):
    """
    Modulate hidden features with a scale (gamma) and shift (beta) predicted
    from a flow tensor: ``h + norm(h) * gamma + beta``.

    Each of gamma and beta is produced by a 2D convolution followed by a
    zero-initialized 1D convolution, so the modulation is zero at
    initialization.

    :param flow_in_channel: channels of the flow input.
    :param out_channels: channels of ``h``; the intermediate width is
                         ``out_channels // 4``.
    """

    def __init__(self, flow_in_channel=128, out_channels=1152):
        super().__init__()
        self.out_channels = out_channels
        hidden = self.out_channels // 4

        def make_temporal():
            # Zero-initialized so the branch contributes nothing at start.
            return zero_module(
                nn.Conv1d(
                    hidden,
                    self.out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="replicate",
                )
            )

        self.flow_gamma_spatial = nn.Conv2d(flow_in_channel, hidden, 3, padding=1)
        self.flow_gamma_temporal = make_temporal()
        self.flow_beta_spatial = nn.Conv2d(flow_in_channel, hidden, 3, padding=1)
        self.flow_beta_temporal = make_temporal()
        self.flow_cond_norm = FloatGroupNorm(32, self.out_channels)

    def _temporal_mix(self, feat, conv, T):
        """Run the 1D conv ``conv`` over ``feat`` ([N, c, h, w]) and restore
        the 4D layout."""
        _, _, height, width = feat.shape
        if feat.shape[0] == 1:
            # Single leading entry: the 1D conv runs over the flattened
            # (h w) axis instead of over frames.
            flat = rearrange(feat, "b c h w -> b c (h w)")
            return rearrange(conv(flat), "b c (h w) -> b c h w", h=height, w=width)

        # Otherwise the leading axis is (batch * T); convolve over frames
        # independently for every spatial position.
        per_pixel = rearrange(feat, "(b f) c h w -> (b h w) c f", f=T)
        return rearrange(
            conv(per_pixel), "(b h w) c f -> (b f) c h w", h=height, w=width
        )

    def forward(self, h, flow, T):
        """
        :param h: hidden features; presumably [(B T), out_channels, H, W] so
                  they broadcast against gamma/beta — confirm with caller.
        :param flow: flow features [(B T), flow_in_channel, H, W], or ``None``
                     to return ``h`` unchanged.
        :param T: number of frames folded into the leading axis of ``flow``.
        """
        if flow is None:
            return h

        gamma_flow = self.flow_gamma_spatial(flow)
        beta_flow = self.flow_beta_spatial(flow)
        gamma_flow = self._temporal_mix(gamma_flow, self.flow_gamma_temporal, T)
        beta_flow = self._temporal_mix(beta_flow, self.flow_beta_temporal, T)

        return h + self.flow_cond_norm(h) * gamma_flow + beta_flow