import copy

import torch
from torch import nn
from transformers import ElectraPreTrainedModel, ElectraModel


class SwishGLU(nn.Module):
    """SwiGLU-style gated activation: projects the input to twice the output
    width, then multiplies one half by SiLU (Swish) of the other half."""

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.projection = nn.Linear(input_dim, 2 * output_dim)
        self.activation = nn.SiLU()

    def forward(self, x):
        # Split the doubled projection into the value and gate halves.
        projected, gate = self.projection(x).tensor_split(2, dim=-1)
        return projected * self.activation(gate)


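# Illustrative sketch (not part of the model): with input_dim == output_dim,
# SwishGLU preserves the trailing dimension and can stand in where a plain
# activation would go.
#
#   glu = SwishGLU(input_dim=768, output_dim=768)
#   out = glu(torch.randn(4, 16, 768))  # out.shape == (4, 16, 768)

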
class PoolingLayer(nn.Module):
    """Pools token embeddings into one sequence embedding via the [CLS]
    token, a masked mean, or a masked max."""

    def __init__(self, pooling_type='cls'):
        super().__init__()
        self.pooling_type = pooling_type

    def forward(self, last_hidden_state, attention_mask):
        if self.pooling_type == 'cls':
            # Embedding of the first ([CLS]) token.
            return last_hidden_state[:, 0, :]
        elif self.pooling_type == 'mean':
            # Mean over non-padding tokens; the clamp guards against division
            # by zero for an all-padding row.
            mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
            return (last_hidden_state * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
        elif self.pooling_type == 'max':
            # Fill padding positions with -inf so they can never win the max
            # (multiplying by the mask would let zeros beat negative values).
            masked = last_hidden_state.masked_fill(
                attention_mask.unsqueeze(-1) == 0, float('-inf')
            )
            return masked.max(dim=1).values
        else:
            raise ValueError(f"Unknown pooling method: {self.pooling_type}")


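# Illustrative sketch (not part of the model): given hidden states of shape
# (batch, seq_len, hidden) and a (batch, seq_len) attention mask, every
# pooling strategy returns a (batch, hidden) tensor.
#
#   pool = PoolingLayer(pooling_type='mean')
#   sentence_embedding = pool(last_hidden_state, attention_mask)

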
class Classifier(nn.Module):
    """MLP head: `num_layers` hidden layers with the given activation (and
    optional dropout), followed by a linear output layer."""

    def __init__(self, input_dim, hidden_dim, hidden_activation, num_layers, n_classes, dropout_rate=0.0):
        super().__init__()
        layers = [nn.Linear(input_dim, hidden_dim), hidden_activation]
        if dropout_rate > 0:
            layers.append(nn.Dropout(dropout_rate))

        for _ in range(num_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            # Deep-copy the activation so parametric activations such as
            # SwishGLU do not share weights across layers.
            layers.append(copy.deepcopy(hidden_activation))
            if dropout_rate > 0:
                layers.append(nn.Dropout(dropout_rate))

        layers.append(nn.Linear(hidden_dim, n_classes))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


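# Illustrative sketch (not part of the model): a two-layer head mapping
# 256-dim pooled embeddings to 5 classes.
#
#   head = Classifier(input_dim=256, hidden_dim=128, hidden_activation=nn.GELU(),
#                     num_layers=2, n_classes=5, dropout_rate=0.1)
#   logits = head(torch.randn(8, 256))  # logits.shape == (8, 5)

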
class ElectraClassifier(ElectraPreTrainedModel):
    """ELECTRA encoder with a pooling layer and an MLP classification head.

    Beyond the standard ElectraConfig fields, the config is expected to carry
    the custom attributes `pooling`, `hidden_dim`, `hidden_activation`,
    `num_layers`, and `dropout_rate`.
    """

    def __init__(self, config):
        super().__init__(config)
        self.electra = ElectraModel(config)

        # ElectraModel defines no pooler, but drop one defensively if present.
        if hasattr(self.electra, 'pooler'):
            self.electra.pooler = None

        self.pooling = PoolingLayer(pooling_type=config.pooling)

        # Resolve the hidden activation: either our SwishGLU or any activation
        # class from torch.nn (e.g. 'GELU', 'ReLU') named in the config.
        activation_name = config.hidden_activation
        if activation_name == 'SwishGLU':
            hidden_activation = SwishGLU(
                input_dim=config.hidden_dim,
                output_dim=config.hidden_dim
            )
        else:
            activation_class = getattr(nn, activation_name)
            hidden_activation = activation_class()

        self.classifier = Classifier(
            input_dim=config.hidden_size,
            hidden_dim=config.hidden_dim,
            hidden_activation=hidden_activation,
            num_layers=config.num_layers,
            n_classes=config.num_labels,
            dropout_rate=config.dropout_rate
        )
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        outputs = self.electra(input_ids, attention_mask=attention_mask, **kwargs)
        pooled_output = self.pooling(outputs.last_hidden_state, attention_mask)
        logits = self.classifier(pooled_output)
        return logits


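# Minimal smoke test (a sketch, not part of the original module): builds a
# tiny, randomly initialised ELECTRA and checks the output shape. The extra
# config fields (pooling, hidden_dim, hidden_activation, num_layers,
# dropout_rate) are the custom attributes this model assumes; ElectraConfig
# stores unknown keyword arguments as plain attributes.
if __name__ == "__main__":
    from transformers import ElectraConfig

    config = ElectraConfig(
        vocab_size=100,
        embedding_size=32,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        num_labels=3,
        # Custom fields consumed by ElectraClassifier:
        pooling='mean',
        hidden_dim=64,
        hidden_activation='SwishGLU',
        num_layers=2,
        dropout_rate=0.1,
    )
    model = ElectraClassifier(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    attention_mask = torch.ones(2, 8, dtype=torch.long)
    logits = model(input_ids, attention_mask=attention_mask)
    print(logits.shape)  # torch.Size([2, 3])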