# Copyright (c) Facebook, Inc. and its affiliates.
import torch
import torch.nn as nn
class Mlp(nn.Module):
    """Two-layer MLP (Linear -> activation -> Linear) with optional dropout."""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop_rate=0.0,
    ):
        super().__init__()
        self.drop_rate = drop_rate
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        if self.drop_rate > 0.0:
            self.drop = nn.Dropout(drop_rate)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        if self.drop_rate > 0.0:
            x = self.drop(x)
        x = self.fc2(x)
        if self.drop_rate > 0.0:
            x = self.drop(x)
        return x
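
# Illustrative usage sketch (not part of the original file): the Mlp block is
# typically applied token-wise to a (batch, tokens, dim) tensor. The shapes
# below are arbitrary assumptions chosen only for the example.
#
#   mlp = Mlp(in_features=768, hidden_features=3072, drop_rate=0.1)
#   tokens = torch.randn(2, 197, 768)
#   out = mlp(tokens)  # same shape as the input: (2, 197, 768)
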
class Permute(nn.Module):
    """Thin nn.Module wrapper around Tensor.permute, usable inside nn.Sequential."""

    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.permute(*self.dims)
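
# Illustrative usage sketch (assumption, not from the original file): reorder a
# channels-first image batch to channels-last, e.g. between layers of an
# nn.Sequential.
#
#   to_channels_last = Permute((0, 2, 3, 1))
#   x = torch.randn(8, 3, 224, 224)   # (B, C, H, W)
#   y = to_channels_last(x)           # -> (8, 224, 224, 3)
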
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Stochastic Depth per sample: randomly zeroes whole samples in the batch
    and rescales the survivors by 1 / keep_prob so the expected value is unchanged.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # Broadcastable mask with one entry per sample, so this works for tensors
    # of any rank, not just 2D ConvNet feature maps.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize: 1 with probability keep_prob, else 0
    output = x.div(keep_prob) * mask
    return output

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
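
# Illustrative sketch (assumption, not from the original file): DropPath is
# meant to wrap the residual branch of a block, so that the whole branch is
# skipped for a random subset of samples during training.
#
#   class ResidualBlock(nn.Module):
#       def __init__(self, dim, drop_path_rate=0.1):
#           super().__init__()
#           self.mlp = Mlp(in_features=dim, hidden_features=4 * dim)
#           self.drop_path = DropPath(drop_path_rate)
#
#       def forward(self, x):
#           return x + self.drop_path(self.mlp(x))
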
class TwoStreamFusion(nn.Module):
    def __init__(self, mode, dim=None, kernel=3, padding=1):
        """
        A general module for fusing two equal-sized streams that arrive
        channel-concatenated along dim=2. The following modes are supported:

        "add" / "max" / "min" / "avg": the respective elementwise operation
            on the two halves.
        "concat": no-op; the input is already the channel-concatenated version.
        "concat_linear_{dim_mult}_{drop_rate}": MLP fusion with hidden dim
            "dim_mult" (optional, default 1.0) times the input dim and
            optional dropout "drop_rate" (default 0.0).
        "ln+concat_linear_{dim_mult}_{drop_rate}": the same MLP fusion,
            applied after a LayerNorm on the input.
        """
        super().__init__()
        self.mode = mode
        if mode == "add":
            self.fuse_fn = lambda x: torch.stack(torch.chunk(x, 2, dim=2)).sum(dim=0)
        elif mode == "max":
            self.fuse_fn = (
                lambda x: torch.stack(torch.chunk(x, 2, dim=2)).max(dim=0).values
            )
        elif mode == "min":
            self.fuse_fn = (
                lambda x: torch.stack(torch.chunk(x, 2, dim=2)).min(dim=0).values
            )
        elif mode == "avg":
            self.fuse_fn = lambda x: torch.stack(torch.chunk(x, 2, dim=2)).mean(dim=0)
        elif mode == "concat":
            # x itself is the channel-concatenated version
            self.fuse_fn = lambda x: x
        elif "concat_linear" in mode:
            # Parse the optional "{dim_mult}" and "{drop_rate}" suffixes.
            if len(mode.split("_")) == 2:
                dim_mult = 1.0
                drop_rate = 0.0
            elif len(mode.split("_")) == 3:
                dim_mult = float(mode.split("_")[-1])
                drop_rate = 0.0
            elif len(mode.split("_")) == 4:
                dim_mult = float(mode.split("_")[-2])
                drop_rate = float(mode.split("_")[-1])
            else:
                raise NotImplementedError

            if mode.split("+")[0] == "ln":
                self.fuse_fn = nn.Sequential(
                    nn.LayerNorm(dim),
                    Mlp(
                        in_features=dim,
                        hidden_features=int(dim * dim_mult),
                        act_layer=nn.GELU,
                        out_features=dim,
                        drop_rate=drop_rate,
                    ),
                )
            else:
                self.fuse_fn = Mlp(
                    in_features=dim,
                    hidden_features=int(dim * dim_mult),
                    act_layer=nn.GELU,
                    out_features=dim,
                    drop_rate=drop_rate,
                )
        else:
            raise NotImplementedError

    def forward(self, x):
        if "concat_linear" in self.mode:
            # MLP fusion is applied as a residual on the concatenated input.
            return self.fuse_fn(x) + x
        else:
            return self.fuse_fn(x)
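
if __name__ == "__main__":
    # Minimal smoke test (assumption, not part of the original file): two
    # streams of width 64 arrive concatenated along the last dim and are fused.
    two_stream = torch.randn(2, 16, 128)  # (batch, tokens, 2 * stream_dim)

    avg_fuse = TwoStreamFusion(mode="avg")
    print(avg_fuse(two_stream).shape)  # torch.Size([2, 16, 64])

    mlp_fuse = TwoStreamFusion(mode="concat_linear_2.0_0.1", dim=128)
    print(mlp_fuse(two_stream).shape)  # torch.Size([2, 16, 128])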