import torch.nn as nn


class HeadProjectorResidual(nn.Module):
    """Project an embedding to a new dimension with a gated residual refinement.

    The input is first mapped to ``output_embedding_dim`` by a linear
    projection; that projected tensor serves as the residual shortcut. A
    GELU -> Linear -> Dropout branch refines it, the shortcut is added
    back, and LayerNorm stabilises the result.

    Args:
        input_embedding_dim: Dimensionality of the incoming embeddings.
        output_embedding_dim: Dimensionality of the projected output.
        dropout: Dropout probability applied to the refinement branch.
    """

    def __init__(
        self,
        input_embedding_dim: int = 1000,
        output_embedding_dim: int = 512,
        dropout: float = 0.4,
    ):
        super().__init__()
        # NOTE: attribute names are part of the checkpoint format
        # (state_dict keys) — keep them stable.
        self.projection = nn.Linear(input_embedding_dim, output_embedding_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(output_embedding_dim, output_embedding_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(output_embedding_dim)

    def forward(self, x):
        """Return the normalised residual projection of ``x``.

        ``x`` has shape ``(..., input_embedding_dim)``; the output has
        shape ``(..., output_embedding_dim)``.
        """
        shortcut = self.projection(x)
        refined = self.dropout(self.fc(self.gelu(shortcut)))
        # Residual connection uses the *projected* input, not the raw input,
        # since the raw input may have a different dimensionality.
        return self.layer_norm(refined + shortcut)