Update structformer.py
Browse files- structformer.py +3 -3
structformer.py
CHANGED
@@ -356,7 +356,7 @@ class Transformer(nn.Module):
|
|
356 |
self.pos_emb = nn.Embedding(500, hidden_size)
|
357 |
|
358 |
self.layers = nn.ModuleList([
|
359 |
-
|
360 |
dropatt=dropatt, relative_bias=relative_bias)
|
361 |
for _ in range(nlayers)])
|
362 |
|
@@ -474,12 +474,12 @@ class StructFormer(Transformer):
|
|
474 |
pad=pad)
|
475 |
|
476 |
self.parser_layers = nn.ModuleList([
|
477 |
-
nn.Sequential(
|
478 |
nn.LayerNorm(hidden_size, elementwise_affine=False),
|
479 |
nn.Tanh()) for i in range(n_parser_layers)])
|
480 |
|
481 |
self.distance_ff = nn.Sequential(
|
482 |
-
|
483 |
nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
|
484 |
nn.Linear(hidden_size, 1))
|
485 |
|
|
|
356 |
self.pos_emb = nn.Embedding(500, hidden_size)
|
357 |
|
358 |
self.layers = nn.ModuleList([
|
359 |
+
TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
|
360 |
dropatt=dropatt, relative_bias=relative_bias)
|
361 |
for _ in range(nlayers)])
|
362 |
|
|
|
474 |
pad=pad)
|
475 |
|
476 |
self.parser_layers = nn.ModuleList([
|
477 |
+
nn.Sequential(Conv1d(hidden_size, conv_size),
|
478 |
nn.LayerNorm(hidden_size, elementwise_affine=False),
|
479 |
nn.Tanh()) for i in range(n_parser_layers)])
|
480 |
|
481 |
self.distance_ff = nn.Sequential(
|
482 |
+
Conv1d(hidden_size, 2),
|
483 |
nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
|
484 |
nn.Linear(hidden_size, 1))
|
485 |
|