# mlp/modeling_mlp.py
import torch.nn as nn
from transformers import PreTrainedModel, AutoModelForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput
from .configuration_mlp import MLPConfig


class MLP(PreTrainedModel):
    r"""
    A simple MLP model that takes a 3D input [batch_size, seq_length, embedding_size]
    and performs multi-label classification using BCE loss.
    """

    config_class = MLPConfig

    def __init__(self, config: MLPConfig):
        super().__init__(config)
        self.config = config

        # Define an MLP stack
        layers = []
        input_dim = config.embedding_size * config.sequence_length
        for _ in range(config.num_hidden_layers):
            layers.append(nn.Linear(input_dim, config.hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(config.dropout))
            input_dim = config.hidden_size

        # Final layer: hidden -> num_labels
        layers.append(nn.Linear(input_dim, config.num_labels))
        self.mlp = nn.Sequential(*layers)

        # Initialize weights using the standard HF utility
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        labels=None,
        **kwargs,
    ):
        """
        Forward pass of the MLP.

        Args:
            inputs_embeds (torch.FloatTensor):
                A 3D tensor of shape [batch_size, seq_length, embedding_size].
            labels (torch.FloatTensor, optional):
                Multi-hot labels for multi-label classification, shape [batch_size, num_labels].

        Returns:
            SequenceClassifierOutput with fields:
                - loss (optional)
                - logits
                - hidden_states (None)
                - attentions (None)
        """
        # inputs_embeds is [B, L, E]. Either flatten over seq_length or pool across tokens.
        # Option A: flatten everything: B x (L*E)
        B, L, E = inputs_embeds.shape
        # assert L == self.config.sequence_length and E == self.config.embedding_size
        x = inputs_embeds.reshape(B, L * E)

        # Option B: mean-pool across tokens (comment out the flattening above if you prefer this).
        # If you use mean-pooling, adjust 'input_dim' in __init__ to E, not L * E.
        # x = inputs_embeds.mean(dim=1)  # shape: B x E

        # Pass through the MLP
        logits = self.mlp(x)  # shape: [B, num_labels]

        loss = None
        if labels is not None:
            # For multi-label classification, use BCEWithLogitsLoss
            loss_fct = nn.BCEWithLogitsLoss()
            # Ensure labels are float
            loss = loss_fct(logits, labels.float())

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

AutoModelForSequenceClassification.register(MLPConfig, MLP)
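

# Minimal usage sketch: this assumes MLPConfig accepts the keyword arguments below
# (embedding_size, sequence_length, hidden_size, num_hidden_layers, dropout, num_labels),
# matching the attribute names the model reads above; the values are illustrative only.
if __name__ == "__main__":
    import torch

    config = MLPConfig(
        embedding_size=64,
        sequence_length=16,
        hidden_size=128,
        num_hidden_layers=2,
        dropout=0.1,
        num_labels=5,
    )
    model = MLP(config)

    # Random embeddings [batch_size, seq_length, embedding_size] and multi-hot labels
    inputs_embeds = torch.randn(4, config.sequence_length, config.embedding_size)
    labels = torch.randint(0, 2, (4, config.num_labels)).float()

    output = model(inputs_embeds=inputs_embeds, labels=labels)
    print(output.loss, output.logits.shape)  # scalar BCE loss, torch.Size([4, 5])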