import torch
from torch import nn
from typing import Optional
from dataclasses import dataclass

from transformers import PreTrainedModel
from transformers.utils import ModelOutput
from transformers.activations import ACT2FN

from .configuration_mlp import MLPConfig


@dataclass
class MLPOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None


class MLPPreTrainedModel(PreTrainedModel):
    config_class = MLPConfig

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()


class MLPModel(MLPPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.act_fn = ACT2FN[config.hidden_act]
        # Layer sizes: input -> hidden[0] -> ... -> hidden[-1] -> output.
        iho = [config.input_size, *config.hidden_size, config.output_size]
        self.linears = nn.ModuleList([
            # len(iho) - 1 layers in total, i.e. config.num_hidden_layers + 1
            # when hidden_size has one entry per hidden layer.
            nn.Linear(iho[i], iho[i + 1]) for i in range(len(iho) - 1)
        ])
        # `reduction` is the current keyword argument; `reduce` is deprecated
        # and expects a bool, so `reduce="mean"` was a bug.
        self.loss_fn = nn.CrossEntropyLoss(reduction="mean")

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, inputs, labels=None):
        # Apply the activation after every layer except the last one,
        # which produces the raw logits.
        hidden_states = inputs
        for linear in self.linears[:-1]:
            hidden_states = self.act_fn(linear(hidden_states))
        logits = self.linears[-1](hidden_states)

        # Compute the loss only when labels are provided; return the
        # MLPOutput dataclass defined above rather than a bare ModelOutput.
        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels)
        return MLPOutput(loss=loss, logits=logits)
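

# --- Usage sketch (illustrative, not part of the module) ---
# A minimal example of instantiating and running the model, assuming
# MLPConfig accepts the fields referenced above (input_size, hidden_size as
# a list, output_size, num_hidden_layers, hidden_act, initializer_range).
# The exact constructor signature lives in configuration_mlp.py and may
# differ; the values below are placeholders.
#
# config = MLPConfig(
#     input_size=784,          # e.g. flattened 28x28 images
#     hidden_size=[256, 128],  # one entry per hidden layer
#     num_hidden_layers=2,     # should match len(hidden_size)
#     output_size=10,          # number of classes
#     hidden_act="relu",
#     initializer_range=0.02,
# )
# model = MLPModel(config)
# batch = torch.randn(32, config.input_size)
# labels = torch.randint(0, config.output_size, (32,))
# out = model(batch, labels=labels)  # MLPOutput with .loss and .logits
# out.loss.backward()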