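"""Tabular transformer binary classifier.

Exposes a small TransformerEncoder through the Hugging Face PreTrainedModel
interface: raw feature rows are standardized with statistics stored in the
config, projected to ``d_model``, run through the encoder as length-1
sequences, and thresholded into boolean predictions.
"""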
import numpy as np
import torch
import torch.nn as nn
from torch import tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from transformers import PreTrainedModel

from .model_config import TabularTransformerConfig


class FleshkaTabularTransformer(PreTrainedModel):
    config_class = TabularTransformerConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # Project each raw feature vector into the transformer's hidden size.
        self.input_proj = nn.Linear(config.input_dim, config.d_model)
        # Encoder stack; feedforward width follows the common 4x d_model rule.
        encoder_layers = TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.nhead,
            dim_feedforward=config.d_model * 4,
            dropout=config.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = TransformerEncoder(encoder_layers, config.num_layers)
        # Binary classification head: LayerNorm followed by a single logit.
        self.head = nn.Sequential(
            nn.LayerNorm(config.d_model),
            nn.Linear(config.d_model, 1),
        )
        self._init_weights()

    def _normalize(self, flist):
        # Standardize features with the training-set statistics stored in the
        # config; the small epsilon guards against zero-variance columns.
        return (np.asarray(flist) - np.array(self.config.mean)) / (
            np.array(self.config.std) + 1e-8
        )
    def _init_weights(self):
        # Xavier-initialize every linear layer and zero its bias.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
    def forward(self, lx: list) -> list:
        # Inference-style forward pass: takes a list of raw feature vectors
        # and returns one boolean prediction per vector. Gradients are never
        # needed here because each logit is immediately reduced to a bool.
        out = []
        with torch.no_grad():
            for nx in lx:
                # Cast the normalized input to the projection weights' dtype
                # (e.g. float16 for a half-precision checkpoint).
                x = tensor(
                    self._normalize(nx), dtype=self.input_proj.weight.dtype
                ).unsqueeze(0).to(self.device)
                x = self.input_proj(x)  # (1, d_model)
                x = x.unsqueeze(1)      # (1, 1, d_model): a length-1 sequence
                x = self.transformer(x)
                x = x.squeeze(1)
                # Threshold the single logit at zero (probability 0.5).
                out.append(self.head(x).item() > 0)
        return out
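

# --- Usage sketch (illustrative; not part of the model code) ---
# Assumes TabularTransformerConfig accepts its fields as keyword arguments;
# the values below are placeholders, the real ones live in the repo's
# config.json.
#
#     config = TabularTransformerConfig(
#         input_dim=8, d_model=64, nhead=4, num_layers=2, dropout=0.1,
#         mean=[0.0] * 8, std=[1.0] * 8,
#     )
#     model = FleshkaTabularTransformer(config).eval()
#     rows = [[0.5] * 8, [1.5] * 8]  # two raw feature vectors of length input_dim
#     print(model(rows))             # e.g. [True, False]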