DigitalEduTransformers / model_config.py
SnowFlash383935's picture
Update model_config.py
2af1c2d verified
from transformers import PretrainedConfig
class TabularTransformerConfig(PretrainedConfig):
model_type = "transformer"
def __init__(
self,
mean=None,
std=None,
input_dim=7, # Размер входных признаков
d_model=512, # Размер скрытого слоя
nhead=16, # Количество голов внимания
num_layers=32, # Количество слоев трансформера
dropout=0.2, # Dropout
**kwargs
):
super().__init__(**kwargs)
self.mean = mean
self.std = std
self.input_dim = input_dim
self.d_model = d_model
self.nhead = nhead
self.num_layers = num_layers
self.dropout = dropout