# Copyright (c) Facebook, Inc. and its affiliates.

import torch
import torch.nn as nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop_rate=0.0,
    ):
        super().__init__()
        self.drop_rate = drop_rate
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        if self.drop_rate > 0.0:
            self.drop = nn.Dropout(drop_rate)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        if self.drop_rate > 0.0:
            x = self.drop(x)
        x = self.fc2(x)
        if self.drop_rate > 0.0:
            x = self.drop(x)
        return x


class Permute(nn.Module):
    def __init__(self, dims):
        super().__init__()
        self.dims = dims

    def forward(self, x):
        return x.permute(*self.dims)


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Stochastic Depth per sample.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize
    output = x.div(keep_prob) * mask
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class TwoStreamFusion(nn.Module):
    def __init__(self, mode, dim=None, kernel=3, padding=1):
        """
        A general constructor for neural modules fusing two equal sized tensors
        in forward. Following options are supported:

        "add" / "max" / "min" / "avg"             : respective operations on the two halves.
        "concat"                                  : NOOP.
        "concat_linear_{dim_mult}_{drop_rate}"    : MLP to fuse with hidden dim "dim_mult"
                                                    (optional, def 1.) higher than input dim
                                                    with optional dropout "drop_rate" (def: 0.)
        "ln+concat_linear_{dim_mult}_{drop_rate}" : perform MLP after layernorm on the input.
        """
        super().__init__()
        self.mode = mode
        if mode == "add":
            self.fuse_fn = lambda x: torch.stack(torch.chunk(x, 2, dim=2)).sum(
                dim=0
            )
        elif mode == "max":
            self.fuse_fn = (
                lambda x: torch.stack(torch.chunk(x, 2, dim=2))
                .max(dim=0)
                .values
            )
        elif mode == "min":
            self.fuse_fn = (
                lambda x: torch.stack(torch.chunk(x, 2, dim=2))
                .min(dim=0)
                .values
            )
        elif mode == "avg":
            self.fuse_fn = lambda x: torch.stack(torch.chunk(x, 2, dim=2)).mean(
                dim=0
            )
        elif mode == "concat":
            # x itself is the channel concat version
            self.fuse_fn = lambda x: x
        elif "concat_linear" in mode:
            if len(mode.split("_")) == 2:
                dim_mult = 1.0
                drop_rate = 0.0
            elif len(mode.split("_")) == 3:
                dim_mult = float(mode.split("_")[-1])
                drop_rate = 0.0
            elif len(mode.split("_")) == 4:
                dim_mult = float(mode.split("_")[-2])
                drop_rate = float(mode.split("_")[-1])
            else:
                raise NotImplementedError

            if mode.split("+")[0] == "ln":
                self.fuse_fn = nn.Sequential(
                    nn.LayerNorm(dim),
                    Mlp(
                        in_features=dim,
                        hidden_features=int(dim * dim_mult),
                        act_layer=nn.GELU,
                        out_features=dim,
                        drop_rate=drop_rate,
                    ),
                )
            else:
                self.fuse_fn = Mlp(
                    in_features=dim,
                    hidden_features=int(dim * dim_mult),
                    act_layer=nn.GELU,
                    out_features=dim,
                    drop_rate=drop_rate,
                )
        else:
            raise NotImplementedError

    def forward(self, x):
        if "concat_linear" in self.mode:
            return self.fuse_fn(x) + x
        else:
            return self.fuse_fn(x)
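

# Illustrative usage sketch (an addition for clarity, not part of the original
# module; the tensor shapes and channel widths below are assumptions).
# TwoStreamFusion expects the two streams concatenated along the channel
# dimension (dim=2): elementwise modes such as "avg" halve the channels, while
# the "concat_linear" modes keep the concatenated width and add a residual MLP,
# so `dim` must equal the full concatenated channel size.
if __name__ == "__main__":
    batch, tokens, channels = 2, 196, 384
    # Two equal-sized streams concatenated along dim=2 -> (B, N, 2 * C).
    x = torch.randn(batch, tokens, 2 * channels)

    # Elementwise fusion of the two halves -> (B, N, C).
    avg_fuse = TwoStreamFusion("avg")
    print(avg_fuse(x).shape)  # torch.Size([2, 196, 384])

    # MLP fusion with hidden dim 2x the input and dropout 0.1, applied
    # residually on top of the concatenated input -> (B, N, 2 * C).
    mlp_fuse = TwoStreamFusion("concat_linear_2.0_0.1", dim=2 * channels)
    print(mlp_fuse(x).shape)  # torch.Size([2, 196, 768])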