import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


# NOT CURRENTLY USED
class SwiGLU(nn.Module):
    """Gated activation: split the last dimension in two and gate one half.

    The input is cut into two halves along its last dimension; the result is
    ``silu(first_half) * second_half``, so the last dimension of the output is
    half that of the input.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: Tensor) -> Tensor:
        """Return ``silu(x1) * x2`` where ``x1, x2`` are the halves of ``x``.

        Args:
            x: Tensor whose last dimension has an even size.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is halved.

        Raises:
            ValueError: If ``x`` has no dimensions or its last dimension has
                an odd size.
        """
        # chunk(2) on an odd size yields unequal pieces; for size 3 they are
        # sizes 2 and 1, which broadcast silently and give a wrong result
        # instead of failing. Reject such inputs explicitly.
        if x.dim() == 0 or x.size(-1) % 2 != 0:
            raise ValueError(
                "SwiGLU expects an input whose last dimension has an even "
                f"size, got shape {tuple(x.shape)}"
            )
        gate, value = x.chunk(2, dim=-1)
        return F.silu(gate) * value


class FFN(nn.Module):
    """Feed-forward block that chains three caller-supplied stages.

    The stages are applied in order: ``in_proj``, then ``activation``, then
    ``out_proj``. They are stored as attributes under those names.
    """

    def __init__(self, in_proj, activation, out_proj) -> None:
        """Store the three stages.

        Args:
            in_proj: Callable applied to the input first.
            activation: Callable applied to the output of ``in_proj``.
            out_proj: Callable applied to the output of ``activation``.
        """
        super().__init__()
        self.in_proj = in_proj
        self.activation = activation
        self.out_proj = out_proj

    def forward(self, x: Tensor) -> Tensor:
        """Return ``out_proj(activation(in_proj(x)))``."""
        x = self.in_proj(x)
        x = self.activation(x)
        x = self.out_proj(x)
        return x