import numpy as np
import torch.nn as nn
from torch import tensor, float32
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import PreTrainedModel

from .model_config import TabularTransformerConfig


class FleshkaTabularTransformer(PreTrainedModel):
    """Transformer encoder with a single-logit binary head over tabular feature vectors."""

    config_class = TabularTransformerConfig

    def __init__(self, config):
        super().__init__(config)

        self.config = config

        # Project raw tabular features into the transformer embedding space.
        self.input_proj = nn.Linear(config.input_dim, config.d_model)

        # Stack of standard encoder layers; GELU activation, batch-first tensors.
        encoder_layer = TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.d_model * 4,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = TransformerEncoder(encoder_layer, config.num_layers)

        # LayerNorm followed by a single-logit head for binary prediction.
        self.head = nn.Sequential(
            nn.LayerNorm(config.d_model),
            nn.Linear(config.d_model, 1),
        )

        self._init_weights()

    def _normalize(self, flist):
        # Z-score the raw features with the dataset statistics stored in the config.
        return (np.asarray(flist) - np.array(self.config.mean)) / (np.array(self.config.std) + 1e-8)

    def _init_weights(self):
        # Xavier-initialize every linear layer and zero its bias.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, lx: list):
        # Score each raw feature vector independently and return one boolean
        # prediction (logit > 0) per input row.
        out = []
        for nx in lx:
            x = tensor(self._normalize(nx), dtype=float32).unsqueeze(0).to(self.device)
            x = self.input_proj(x)   # (1, d_model)
            x = x.unsqueeze(1)       # (1, 1, d_model): batch of one, sequence length one
            x = self.transformer(x)
            x = x.squeeze(1)         # back to (1, d_model)
            out.append(self.head(x).item() > 0)
        return out
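

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module). It assumes
# TabularTransformerConfig accepts the fields used above (input_dim, d_model,
# nhead, num_layers, dropout, mean, std) as keyword arguments; the feature
# values below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = TabularTransformerConfig(
        input_dim=4,
        d_model=64,
        nhead=4,
        num_layers=2,
        dropout=0.1,
        mean=[0.0, 0.0, 0.0, 0.0],
        std=[1.0, 1.0, 1.0, 1.0],
    )
    model = FleshkaTabularTransformer(config).eval()
    # forward() takes a list of raw feature lists and returns one bool per row.
    predictions = model([[0.1, 2.0, -1.3, 0.7], [1.5, 0.2, 0.0, -0.4]])
    print(predictions)  # e.g. [True, False]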