import torch
import torch.nn as nn

from unidepth.utils.misc import default

from .activation import SwiGLU


class MLP(nn.Module):
    def __init__(
        self,
        input_dim: int,
        expansion: int = 4,
        dropout: float = 0.0,
        gated: bool = False,
        output_dim: int | None = None,
    ):
        super().__init__()
        if gated:
            # Reduce the expansion for the gated (SwiGLU) variant so the
            # parameter count stays roughly comparable to the plain MLP.
            expansion = int(expansion * 2 / 3)
        hidden_dim = int(input_dim * expansion)
        output_dim = default(output_dim, input_dim)
        self.norm = nn.LayerNorm(input_dim)
        self.proj1 = nn.Linear(input_dim, hidden_dim)
        self.proj2 = nn.Linear(hidden_dim, output_dim)
        self.act = nn.GELU() if not gated else SwiGLU()
        self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Pre-norm feed-forward block: norm -> expand -> activation -> project -> dropout.
        x = self.norm(x)
        x = self.proj1(x)
        x = self.act(x)
        x = self.proj2(x)
        x = self.dropout(x)
        return x
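
if __name__ == "__main__":
    # Minimal usage sketch (assumption: executed as a module from the package
    # root, e.g. via `python -m ...`, so the relative SwiGLU import resolves).
    # The feature size and batch/token shape below are illustrative only.
    mlp = MLP(input_dim=32)      # default: non-gated GELU MLP with 4x expansion
    x = torch.randn(2, 16, 32)   # (batch, tokens, channels)
    y = mlp(x)
    print(y.shape)               # expected: torch.Size([2, 16, 32])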