import torch.nn as nn

from transformers import PreTrainedModel, AutoModelForSequenceClassification
from transformers.modeling_outputs import SequenceClassifierOutput

from .configuration_mlp import MLPConfig
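
# The MLPConfig (defined in configuration_mlp.py, not shown here) is expected
# to provide: embedding_size, sequence_length, num_hidden_layers, hidden_size,
# dropout, and num_labels; these are the only attributes read below.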


class MLP(PreTrainedModel):
    r"""
    A simple MLP that takes a 3D input of shape
    [batch_size, seq_length, embedding_size], flattens it, and performs
    multi-label classification with a binary cross-entropy (BCE) loss.
    """

    config_class = MLPConfig

    def __init__(self, config: MLPConfig):
        super().__init__(config)
        self.config = config

        # The flattened input has one feature per (position, embedding-dim) pair.
        layers = []
        input_dim = config.embedding_size * config.sequence_length

        # Hidden blocks: Linear -> ReLU -> Dropout, repeated num_hidden_layers times.
        for _ in range(config.num_hidden_layers):
            layers.append(nn.Linear(input_dim, config.hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(config.dropout))
            input_dim = config.hidden_size

        # Final projection to one logit per label.
        layers.append(nn.Linear(input_dim, config.num_labels))

        self.mlp = nn.Sequential(*layers)

        # Initialize weights and apply final processing (standard HF hook).
        self.post_init()

    def forward(
        self,
        inputs_embeds=None,
        labels=None,
        **kwargs
    ):
        """
        Forward pass of the MLP.

        Args:
            inputs_embeds (torch.FloatTensor):
                A 3D tensor of shape [batch_size, seq_length, embedding_size].
                seq_length must equal config.sequence_length, since the input
                is flattened into a fixed-size vector.
            labels (torch.FloatTensor, optional):
                Multi-hot labels for multi-label classification, of shape
                [batch_size, num_labels].

        Returns:
            SequenceClassifierOutput with fields:
                - loss (only when labels are provided)
                - logits
                - hidden_states (None)
                - attentions (None)
        """
        # Flatten [batch_size, seq_length, embedding_size] into
        # [batch_size, seq_length * embedding_size] for the first Linear layer.
        B, L, E = inputs_embeds.shape
        x = inputs_embeds.reshape(B, L * E)

        logits = self.mlp(x)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss fuses the sigmoid with binary cross-entropy and
            # treats each label as an independent binary decision.
            loss_fct = nn.BCEWithLogitsLoss()
            loss = loss_fct(logits, labels.float())

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )


# Make the custom model resolvable through the Auto* API, so that
# AutoModelForSequenceClassification.from_config / from_pretrained return an
# MLP whenever the config is an MLPConfig.
AutoModelForSequenceClassification.register(MLPConfig, MLP)
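

# A minimal usage sketch, not part of the original module: the config values
# below are made up, and this assumes MLPConfig forwards these keyword
# arguments to PretrainedConfig in the usual way. Because of the relative
# import above, run it with `python -m <package>.<this_module>` rather than
# invoking the file directly.
if __name__ == "__main__":
    import torch

    config = MLPConfig(
        embedding_size=32,
        sequence_length=16,
        hidden_size=64,
        num_hidden_layers=2,
        dropout=0.1,
        num_labels=5,
    )
    model = MLP(config)

    batch_size = 4
    embeds = torch.randn(batch_size, config.sequence_length, config.embedding_size)
    labels = torch.randint(0, 2, (batch_size, config.num_labels)).float()

    out = model(inputs_embeds=embeds, labels=labels)
    print(out.loss, out.logits.shape)  # scalar BCE loss, torch.Size([4, 5])

    # At inference time, convert logits to multi-label predictions with a
    # sigmoid and a threshold (0.5 here; tune per task).
    preds = (out.logits.sigmoid() > 0.5).long()
    print(preds.shape)  # torch.Size([4, 5])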