import torch
from torch import nn
from functools import partial
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange
pair = lambda x: x if isinstance(x, tuple) else (x, x)
class PreNormResidual(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.fn(self.norm(x)) + x
def FeedForward(dim, expansion_factor=4, dropout=0., dense=nn.Linear):
    inner_dim = int(dim * expansion_factor)
    return nn.Sequential(
        dense(dim, inner_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        dense(inner_dim, dim),
        nn.Dropout(dropout)
    )
class MappingSub2W(nn.Module):
    """MLP-Mixer-style mapping from stacked per-region codes (B, N, 34, dim) to codes of shape (B, N, dim)."""
    def __init__(self, N=8, dim=512, depth=6, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
        super(MappingSub2W, self).__init__()
        num_patches = N * 34
        # Token mixing uses a 1x1 Conv1d over the patch axis, channel mixing uses Linear over the feature axis.
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
        self.layer = nn.Sequential(
            Rearrange('b c h w -> b (c h) w'),
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim),
            Rearrange('b c h -> b h c'),
            nn.Linear(34 * N, 34 * N),
            nn.LayerNorm(34 * N),
            nn.GELU(),
            nn.Linear(34 * N, N),
            Rearrange('b h c -> b c h')
        )

    def forward(self, x):
        return self.layer(x)
class MappingW2Sub(nn.Module):
    """Maps codes of shape (B, N, dim) to per-region mean, log-variance and sample, each of shape (B, N, 34, dim)."""
    def __init__(self, N=8, dim=512, depth=8, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
        super(MappingW2Sub, self).__init__()
        self.N = N
        num_patches = N * 34
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
        self.layer = nn.Sequential(
            Rearrange('b c h -> b h c'),
            nn.Linear(N, num_patches),
            Rearrange('b h c -> b c h'),
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim)
        )
        self.mu_fc = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(2)],
            nn.LayerNorm(dim),
            nn.Tanh(),
            Rearrange('b c h -> b (c h)')
        )
        self.var_fc = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(2)],
            nn.LayerNorm(dim),
            nn.Tanh(),
            Rearrange('b c h -> b (c h)')
        )

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample from N(mu, var) using samples from N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x):
        f = self.layer(x)
        mu = self.mu_fc(f)
        var = self.var_fc(f)  # interpreted as a log-variance by reparameterize
        z = self.reparameterize(mu, var)
        z = rearrange(z, 'a (b c d) -> a b c d', b=self.N, c=34)
        return (rearrange(mu, 'a (b c d) -> a b c d', b=self.N, c=34),
                rearrange(var, 'a (b c d) -> a b c d', b=self.N, c=34),
                z)
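# The helper below is not part of the original file; it is a minimal shape-check sketch,
# assuming the defaults N=8 and dim=512, so MappingSub2W consumes (B, N, 34, 512) and
# MappingW2Sub consumes (B, N, 512). It only runs when called explicitly.
def _demo_mapping_shapes(batch_size=2, N=8, dim=512):
    sub = torch.randn(batch_size, N, 34, dim)         # stacked per-region codes
    w = MappingSub2W(N=N, dim=dim)(sub)               # -> (B, N, dim)
    mu, log_var, z = MappingW2Sub(N=N, dim=dim)(w)    # each -> (B, N, 34, dim)
    print(w.shape, mu.shape, log_var.shape, z.shape)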
class HeadEncoder(nn.Module):
    """Per-region head: downsamples a masked feature map and predicts N latent means and log-variances of size dim."""
    def __init__(self, N=8, dim=512, depth=2, expansion_factor=4., expansion_factor_token=0.5, dropout=0.1):
        super(HeadEncoder, self).__init__()
        channels = [32, 64, 64, 64]
        self.N = N
        num_patches = N
        chan_first, chan_last = partial(nn.Conv1d, kernel_size=1), nn.Linear
        self.s1 = nn.Sequential(
            nn.Conv2d(channels[0], channels[1], kernel_size=5, padding=2, stride=2),
            nn.BatchNorm2d(channels[1]),
            nn.LeakyReLU(),
            nn.Conv2d(channels[1], channels[2], kernel_size=5, padding=2, stride=2),
            nn.BatchNorm2d(channels[2]),
            nn.LeakyReLU(),
            nn.Conv2d(channels[2], channels[3], kernel_size=5, padding=2, stride=2),
            nn.BatchNorm2d(channels[3]),
            nn.LeakyReLU())
        # Flattening assumes an 8x8 spatial map after the three stride-2 convs, i.e. a 64x64 input feature map.
        self.mlp1 = nn.Linear(channels[3] * 8 * 8, 512)
        self.up_N = nn.Linear(1, N)
        self.mu_fc = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim),
            nn.Tanh()
        )
        self.var_fc = nn.Sequential(
            *[nn.Sequential(
                PreNormResidual(dim, FeedForward(num_patches, expansion_factor, dropout, chan_first)),
                PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last))
            ) for _ in range(depth)],
            nn.LayerNorm(dim),
            nn.Tanh()
        )

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: sample from N(mu, var) using samples from N(0, 1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, x):
        feature = self.s1(x)
        s2 = torch.flatten(feature, start_dim=1)
        s2 = self.mlp1(s2).unsqueeze(2)       # (B, 512, 1)
        s2 = self.up_N(s2)                    # (B, 512, N)
        s2 = rearrange(s2, 'b h c -> b c h')  # (B, N, 512)
        mu = self.mu_fc(s2)
        var = self.var_fc(s2)                 # log-variance
        z = self.reparameterize(mu, var)
        return mu, var, z
class RegionEncoder(nn.Module):
    """Shared convolutional trunk followed by one HeadEncoder per region mask (34 regions)."""
    def __init__(self, N=8):
        super(RegionEncoder, self).__init__()
        channels = [8, 16, 32, 32, 64, 64]
        self.s1 = nn.Conv2d(3, channels[0], kernel_size=3, padding=1, stride=2)
        self.s2 = nn.Sequential(
            nn.Conv2d(channels[0], channels[1], kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(channels[1]),
            nn.LeakyReLU(),
            nn.Conv2d(channels[1], channels[2], kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(channels[2]),
            nn.LeakyReLU()
        )
        self.heads = nn.ModuleList()
        for i in range(34):
            self.heads.append(HeadEncoder(N=N))

    def forward(self, x, all_mask=None):
        s1 = self.s1(x)
        s2 = self.s2(s1)
        result = []
        mus = []
        log_vars = []
        for i, head in enumerate(self.heads):
            # Mask the shared features with the i-th region mask before the per-region head.
            m = all_mask[:, i, :].unsqueeze(1)
            mu, var, z = head(s2 * m)
            result.append(z.unsqueeze(2))
            mus.append(mu.unsqueeze(2))
            log_vars.append(var.unsqueeze(2))
        return torch.cat(mus, dim=2), torch.cat(log_vars, dim=2), torch.cat(result, dim=2)
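# Minimal smoke test, not part of the original file. The 512x512 image size and the 64x64 mask
# resolution are assumptions: the trunk downsamples by 8x and HeadEncoder's flatten expects an
# 8x8 map, so inputs of 512x512 with masks at 1/8 resolution satisfy the hard-coded sizes.
if __name__ == "__main__":
    B, N = 2, 8
    img = torch.randn(B, 3, 512, 512)
    masks = torch.randint(0, 2, (B, 34, 64, 64)).float()
    encoder = RegionEncoder(N=N)
    mus, log_vars, subs = encoder(img, masks)   # each (B, N, 34, 512)
    w = MappingSub2W(N=N)(subs)                 # (B, N, 512)
    mu, log_var, z = MappingW2Sub(N=N)(w)       # each (B, N, 34, 512)
    print(mus.shape, w.shape, z.shape)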