Update model.py
Browse files
model.py
CHANGED
@@ -46,7 +46,7 @@ class FleshkaTabularTransformer(PreTrainedModel):
|
|
46 |
# x: [batch_size, input_dim]
|
47 |
out = list()
|
48 |
for nx in lx:
|
49 |
-
x = tensor(self._normalize(nx), dtype=float32).unsqueeze(0)
|
50 |
x = self.input_proj(x) # [batch_size, d_model]
|
51 |
x = x.unsqueeze(1) # [batch_size, 1, d_model] (добавляем seq_len=1)
|
52 |
x = self.transformer(x) # [batch_size, 1, d_model]
|
|
|
46 |
# x: [batch_size, input_dim]
|
47 |
out = list()
|
48 |
for nx in lx:
|
49 |
+
x = tensor(self._normalize(nx), dtype=float32).unsqueeze(0).to(self.device)
|
50 |
x = self.input_proj(x) # [batch_size, d_model]
|
51 |
x = x.unsqueeze(1) # [batch_size, 1, d_model] (добавляем seq_len=1)
|
52 |
x = self.transformer(x) # [batch_size, 1, d_model]
|