import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput

from .configuration_mlp import MLPConfig


class MLP(PreTrainedModel):
    r"""
    A simple MLP model that takes a 3D input [batch_size, seq_length, embedding_size]
    and performs multi-label classification using BCE loss.
    """

    config_class = MLPConfig

    def __init__(self, config: MLPConfig):
        super().__init__(config)
        self.config = config

        # Define an MLP stack
        layers = []
        input_dim = config.embedding_size * config.sequence_length
        for _ in range(config.num_hidden_layers):
            layers.append(nn.Linear(input_dim, config.hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(config.dropout))
            input_dim = config.hidden_size

        # Final layer: hidden -> num_labels
        layers.append(nn.Linear(input_dim, config.num_labels))

        self.mlp = nn.Sequential(*layers)

        # Initialize weights using the standard HF utility
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        labels=None,
        **kwargs,
    ):
        """
        Forward pass of the MLP.

        Args:
            inputs_embeds (torch.FloatTensor):
                A 3D tensor of shape [batch_size, seq_length, embedding_size].
            labels (torch.FloatTensor):
                Multi-hot labels for multi-label classification,
                shape [batch_size, num_labels].

        Returns:
            SequenceClassifierOutput with fields:
                - loss (optional)
                - logits
                - hidden_states (None)
                - attentions (None)
        """
        # inputs_embeds is [B, L, E]. Either flatten over seq_length or pool across tokens.

        # Option A: flatten everything to B x (L*E)
        B, L, E = inputs_embeds.shape
        # assert L == self.config.sequence_length and E == self.config.embedding_size
        x = inputs_embeds.reshape(B, L * E)

        # Option B: mean-pool across tokens (uncomment and drop Option A if you prefer pooling)
        # x = inputs_embeds.mean(dim=1)  # shape: B x E
        # (If you use mean-pooling, adjust 'input_dim' in __init__ to E, not L*E.)

        # Pass through the MLP
        logits = self.mlp(x)  # shape: [B, num_labels]

        loss = None
        if labels is not None:
            # For multi-label classification, use BCEWithLogitsLoss
            loss_fct = nn.BCEWithLogitsLoss()
            # Ensure labels are floats
            loss = loss_fct(logits, labels.float())

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )


AutoModelForSequenceClassification.register(MLPConfig, MLP)
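

# Minimal smoke test (a sketch, not part of the model definition). It assumes
# that MLPConfig in configuration_mlp.py accepts the field names used above
# (embedding_size, sequence_length, hidden_size, num_hidden_layers, dropout,
# num_labels) as keyword arguments; adjust to the actual config signature.
if __name__ == "__main__":
    import torch

    config = MLPConfig(
        embedding_size=32,
        sequence_length=8,
        hidden_size=64,
        num_hidden_layers=2,
        dropout=0.1,
        num_labels=5,
    )
    model = MLP(config)

    # Random embeddings [B, L, E] and multi-hot labels [B, num_labels]
    inputs = torch.randn(4, config.sequence_length, config.embedding_size)
    labels = torch.randint(0, 2, (4, config.num_labels)).float()

    out = model(inputs_embeds=inputs, labels=labels)
    print(out.logits.shape)  # torch.Size([4, 5])
    print(out.loss)          # scalar BCE-with-logits loss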