from transformers import PreTrainedModel
import torch.nn as nn
import numpy as np
from torch import tensor, float32
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from .model_config import TabularTransformerConfig
class FleshkaTabularTransformer(PreTrainedModel):
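    """Transformer encoder for single tabular feature vectors.

    Each row is normalized with the config's mean/std, projected to d_model,
    run through a TransformerEncoder as a length-1 sequence, and mapped to a
    single logit that forward() thresholds into a boolean prediction.
    """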
config_class = TabularTransformerConfig
def __init__(self, config):
super().__init__(config)
        self.config = config

        # Project raw input features into the transformer model dimension
        self.input_proj = nn.Linear(config.input_dim, config.d_model)
        # Transformer encoder layers
encoder_layers = TransformerEncoderLayer(
d_model=config.d_model,
nhead=config.nhead,
dim_feedforward=config.d_model * 4,
dropout=config.dropout,
activation="gelu",
batch_first=True
)
self.transformer = TransformerEncoder(encoder_layers, config.num_layers)
        # Output head: LayerNorm followed by a single-logit linear layer
self.head = nn.Sequential(
nn.LayerNorm(config.d_model),
nn.Linear(config.d_model, 1)
)
        # Initialize weights via the standard Hugging Face hook (applies _init_weights)
        self.post_init()
    def _normalize(self, flist):
        # Standardize a raw feature vector with the per-feature mean/std stored in the config
        return (flist - np.array(self.config.mean)) / (np.array(self.config.std) + 1e-8)
    def _init_weights(self, module):
        # Hugging Face weight-initialization hook, applied to every submodule by post_init()
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    def forward(self, lx: list):
        # lx: list of raw feature vectors, each of length config.input_dim
        out = []
        for nx in lx:
            # Normalize one sample and move it to the model's device: [1, input_dim]
            x = tensor(self._normalize(nx), dtype=float32).unsqueeze(0).to(self.device)
            x = self.input_proj(x)   # [1, d_model]
            x = x.unsqueeze(1)       # [1, 1, d_model] (sequence length 1)
            x = self.transformer(x)  # [1, 1, d_model]
            x = x.squeeze(1)         # [1, d_model]
            # Threshold the scalar logit to get a boolean prediction
            out.append(self.head(x).item() > 0)
        return out
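

# --- Usage sketch (illustrative addition, not part of the original file) -----
# A minimal inference example. It assumes TabularTransformerConfig accepts the
# fields referenced above (input_dim, d_model, nhead, num_layers, dropout,
# mean, std) as keyword arguments; the real config in model_config.py may
# differ. Because of the relative import at the top, run this as a module
# (python -m <package>.model) rather than as a standalone script.
if __name__ == "__main__":
    config = TabularTransformerConfig(
        input_dim=8,
        d_model=64,
        nhead=4,
        num_layers=2,
        dropout=0.1,
        mean=[0.0] * 8,  # per-feature normalization statistics (placeholder values)
        std=[1.0] * 8,
    )
    model = FleshkaTabularTransformer(config).eval()
    # Two raw feature vectors of length input_dim; forward returns one bool per row
    preds = model([[0.5] * 8, [1.5] * 8])
    print(preds)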