omarmomen commited on
Commit
0fbf7fe
·
1 Parent(s): 02992f5
Files changed (3) hide show
  1. config.json +29 -0
  2. pytorch_model.bin +3 -0
  3. structformer_as_hf_no_parser.py +730 -0
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "StructformerModel"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "structformer_as_hf_no_parser.StructformerConfig",
7
+ "AutoModelForMaskedLM": "structformer_as_hf_no_parser.StructformerModel"
8
+ },
9
+ "conv_size": 9,
10
+ "dropatt": 0.1,
11
+ "dropout": 0.1,
12
+ "hidden_size": 768,
13
+ "model_type": "structformer",
14
+ "n_context_layers": 0,
15
+ "n_parser_layers": 0,
16
+ "nhead": 12,
17
+ "nlayers": 12,
18
+ "ntokens": 32000,
19
+ "pad": 0,
20
+ "pos_emb": true,
21
+ "relations": [
22
+ "head",
23
+ "child"
24
+ ],
25
+ "relative_bias": false,
26
+ "torch_dtype": "float32",
27
+ "transformers_version": "4.18.0",
28
+ "weight_act": "softmax"
29
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bfc5195a3f7885f675085f99e87edcff14fc97e2d8e249aa1c599b9a47113fa
3
+ size 440252407
structformer_as_hf_no_parser.py ADDED
@@ -0,0 +1,730 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import init
5
+ from transformers import PreTrainedModel
6
+ from transformers import PretrainedConfig
7
+ from transformers.modeling_outputs import MaskedLMOutput
8
+ from typing import List
9
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
10
+ from transformers.modeling_outputs import (
11
+ BaseModelOutputWithPastAndCrossAttentions,
12
+ BaseModelOutputWithPoolingAndCrossAttentions,
13
+ MaskedLMOutput,
14
+ SequenceClassifierOutput
15
+ )
16
+
17
+ ##########################################
18
+ # HuggingFace Config
19
+ ##########################################
20
+ class StructformerConfig(PretrainedConfig):
21
+ model_type = "structformer"
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size=768,
26
+ n_context_layers=2,
27
+ nlayers=6,
28
+ ntokens=32000,
29
+ nhead=8,
30
+ dropout=0.1,
31
+ dropatt=0.1,
32
+ relative_bias=False,
33
+ pos_emb=False,
34
+ pad=0,
35
+ n_parser_layers=4,
36
+ conv_size=9,
37
+ relations=('head', 'child'),
38
+ weight_act='softmax',
39
+ **kwargs,
40
+ ):
41
+ self.hidden_size = hidden_size
42
+ self.n_context_layers = n_context_layers
43
+ self.nlayers = nlayers
44
+ self.ntokens = ntokens
45
+ self.nhead = nhead
46
+ self.dropout = dropout
47
+ self.dropatt = dropatt
48
+ self.relative_bias = relative_bias
49
+ self.pos_emb = pos_emb
50
+ self.pad = pad
51
+ self.n_parser_layers = n_parser_layers
52
+ self.conv_size = conv_size
53
+ self.relations = relations
54
+ self.weight_act = weight_act
55
+ super().__init__(**kwargs)
56
+
57
+ ##########################################
58
+ # Custom Layers
59
+ ##########################################
60
+ def _get_activation_fn(activation):
61
+ """Get specified activation function."""
62
+ if activation == "relu":
63
+ return nn.ReLU()
64
+ elif activation == "gelu":
65
+ return nn.GELU()
66
+ elif activation == "leakyrelu":
67
+ return nn.LeakyReLU()
68
+
69
+ raise RuntimeError(
70
+ "activation should be relu/gelu, not {}".format(activation))
71
+
72
+ class Conv1d(nn.Module):
73
+ """1D convolution layer."""
74
+
75
+ def __init__(self, hidden_size, kernel_size, dilation=1):
76
+ """Initialization.
77
+
78
+ Args:
79
+ hidden_size: dimension of input embeddings
80
+ kernel_size: convolution kernel size
81
+ dilation: the spacing between the kernel points
82
+ """
83
+ super(Conv1d, self).__init__()
84
+
85
+ if kernel_size % 2 == 0:
86
+ padding = (kernel_size // 2) * dilation
87
+ self.shift = True
88
+ else:
89
+ padding = ((kernel_size - 1) // 2) * dilation
90
+ self.shift = False
91
+ self.conv = nn.Conv1d(
92
+ hidden_size,
93
+ hidden_size,
94
+ kernel_size,
95
+ padding=padding,
96
+ dilation=dilation)
97
+
98
+ def forward(self, x):
99
+ """Compute convolution.
100
+
101
+ Args:
102
+ x: input embeddings
103
+ Returns:
104
+ conv_output: convolution results
105
+ """
106
+
107
+ if self.shift:
108
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)[:, 1:]
109
+ else:
110
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)
111
+
112
+ class MultiheadAttention(nn.Module):
113
+ """Multi-head self-attention layer."""
114
+
115
+ def __init__(self,
116
+ embed_dim,
117
+ num_heads,
118
+ dropout=0.,
119
+ bias=True,
120
+ v_proj=True,
121
+ out_proj=True,
122
+ relative_bias=True):
123
+ """Initialization.
124
+
125
+ Args:
126
+ embed_dim: dimension of input embeddings
127
+ num_heads: number of self-attention heads
128
+ dropout: dropout rate
129
+ bias: bool, indicate whether include bias for linear transformations
130
+ v_proj: bool, indicate whether project inputs to new values
131
+ out_proj: bool, indicate whether project outputs to new values
132
+ relative_bias: bool, indicate whether use a relative position based
133
+ attention bias
134
+ """
135
+
136
+ super(MultiheadAttention, self).__init__()
137
+ self.embed_dim = embed_dim
138
+
139
+ self.num_heads = num_heads
140
+ self.drop = nn.Dropout(dropout)
141
+ self.head_dim = embed_dim // num_heads
142
+ assert self.head_dim * num_heads == self.embed_dim, ("embed_dim must be "
143
+ "divisible by "
144
+ "num_heads")
145
+
146
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
147
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
148
+ if v_proj:
149
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
150
+ else:
151
+ self.v_proj = nn.Identity()
152
+
153
+ if out_proj:
154
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
155
+ else:
156
+ self.out_proj = nn.Identity()
157
+
158
+ if relative_bias:
159
+ self.relative_bias = nn.Parameter(torch.zeros((self.num_heads, 512)))
160
+ else:
161
+ self.relative_bias = None
162
+
163
+ self._reset_parameters()
164
+
165
+ def _reset_parameters(self):
166
+ """Initialize attention parameters."""
167
+
168
+ init.xavier_uniform_(self.q_proj.weight)
169
+ init.constant_(self.q_proj.bias, 0.)
170
+
171
+ init.xavier_uniform_(self.k_proj.weight)
172
+ init.constant_(self.k_proj.bias, 0.)
173
+
174
+ if isinstance(self.v_proj, nn.Linear):
175
+ init.xavier_uniform_(self.v_proj.weight)
176
+ init.constant_(self.v_proj.bias, 0.)
177
+
178
+ if isinstance(self.out_proj, nn.Linear):
179
+ init.xavier_uniform_(self.out_proj.weight)
180
+ init.constant_(self.out_proj.bias, 0.)
181
+
182
+ def forward(self, query, key_padding_mask=None, attn_mask=None):
183
+ """Compute multi-head self-attention.
184
+
185
+ Args:
186
+ query: input embeddings
187
+ key_padding_mask: 3D mask that prevents attention to certain positions
188
+ attn_mask: 3D mask that rescale the attention weight at each position
189
+ Returns:
190
+ attn_output: self-attention output
191
+ """
192
+
193
+ length, bsz, embed_dim = query.size()
194
+ assert embed_dim == self.embed_dim
195
+
196
+ head_dim = embed_dim // self.num_heads
197
+ assert head_dim * self.num_heads == embed_dim, ("embed_dim must be "
198
+ "divisible by num_heads")
199
+ scaling = float(head_dim)**-0.5
200
+
201
+ q = self.q_proj(query)
202
+ k = self.k_proj(query)
203
+ v = self.v_proj(query)
204
+
205
+ q = q * scaling
206
+
207
+ if attn_mask is not None:
208
+ assert list(attn_mask.size()) == [bsz * self.num_heads,
209
+ query.size(0), query.size(0)]
210
+
211
+ q = q.contiguous().view(length, bsz * self.num_heads,
212
+ head_dim).transpose(0, 1)
213
+ k = k.contiguous().view(length, bsz * self.num_heads,
214
+ head_dim).transpose(0, 1)
215
+ v = v.contiguous().view(length, bsz * self.num_heads,
216
+ head_dim).transpose(0, 1)
217
+
218
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
219
+ assert list(
220
+ attn_output_weights.size()) == [bsz * self.num_heads, length, length]
221
+
222
+ if self.relative_bias is not None:
223
+ pos = torch.arange(length, device=query.device)
224
+ relative_pos = torch.abs(pos[:, None] - pos[None, :]) + 256
225
+ relative_pos = relative_pos[None, :, :].expand(bsz * self.num_heads, -1,
226
+ -1)
227
+
228
+ relative_bias = self.relative_bias.repeat_interleave(bsz, dim=0)
229
+ relative_bias = relative_bias[:, None, :].expand(-1, length, -1)
230
+ relative_bias = torch.gather(relative_bias, 2, relative_pos)
231
+ attn_output_weights = attn_output_weights + relative_bias
232
+
233
+ if key_padding_mask is not None:
234
+ attn_output_weights = attn_output_weights + key_padding_mask
235
+
236
+ if attn_mask is None:
237
+ attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
238
+ else:
239
+ attn_output_weights = torch.sigmoid(attn_output_weights) * attn_mask
240
+
241
+ attn_output_weights = self.drop(attn_output_weights)
242
+
243
+ attn_output = torch.bmm(attn_output_weights, v)
244
+
245
+ assert list(attn_output.size()) == [bsz * self.num_heads, length, head_dim]
246
+ attn_output = attn_output.transpose(0, 1).contiguous().view(
247
+ length, bsz, embed_dim)
248
+ attn_output = self.out_proj(attn_output)
249
+
250
+ return attn_output
251
+
252
+ class TransformerLayer(nn.Module):
253
+ """TransformerEncoderLayer is made up of self-attn and feedforward network."""
254
+
255
+ def __init__(self,
256
+ d_model,
257
+ nhead,
258
+ dim_feedforward=2048,
259
+ dropout=0.1,
260
+ dropatt=0.1,
261
+ activation="leakyrelu",
262
+ relative_bias=True):
263
+ """Initialization.
264
+
265
+ Args:
266
+ d_model: dimension of inputs
267
+ nhead: number of self-attention heads
268
+ dim_feedforward: dimension of hidden layer in feedforward layer
269
+ dropout: dropout rate
270
+ dropatt: drop attention rate
271
+ activation: activation function
272
+ relative_bias: bool, indicate whether use a relative position based
273
+ attention bias
274
+ """
275
+
276
+ super(TransformerLayer, self).__init__()
277
+
278
+ self.self_attn = MultiheadAttention(
279
+ d_model, nhead, dropout=dropatt, relative_bias=relative_bias)
280
+
281
+ # Implementation of Feedforward model
282
+ self.feedforward = nn.Sequential(
283
+ nn.LayerNorm(d_model), nn.Linear(d_model, dim_feedforward),
284
+ _get_activation_fn(activation), nn.Dropout(dropout),
285
+ nn.Linear(dim_feedforward, d_model))
286
+
287
+ self.norm = nn.LayerNorm(d_model)
288
+ self.dropout1 = nn.Dropout(dropout)
289
+ self.dropout2 = nn.Dropout(dropout)
290
+
291
+ self.nhead = nhead
292
+
293
+ def forward(self, src, attn_mask=None, key_padding_mask=None):
294
+ """Pass the input through the encoder layer.
295
+
296
+ Args:
297
+ src: the sequence to the encoder layer (required).
298
+ attn_mask: the mask for the src sequence (optional).
299
+ key_padding_mask: the mask for the src keys per batch (optional).
300
+ Returns:
301
+ src3: the output of transformer layer, share the same shape as src.
302
+ """
303
+ src2 = self.self_attn(
304
+ self.norm(src), attn_mask=attn_mask, key_padding_mask=key_padding_mask)
305
+ src2 = src + self.dropout1(src2)
306
+ src3 = self.feedforward(src2)
307
+ src3 = src2 + self.dropout2(src3)
308
+
309
+ return src3
310
+
311
+ ##########################################
312
+ # Custom Models
313
+ ##########################################
314
+ def cumprod(x, reverse=False, exclusive=False):
315
+ """cumulative product."""
316
+ if reverse:
317
+ x = x.flip([-1])
318
+
319
+ if exclusive:
320
+ x = F.pad(x[:, :, :-1], (1, 0), value=1)
321
+
322
+ cx = x.cumprod(-1)
323
+
324
+ if reverse:
325
+ cx = cx.flip([-1])
326
+ return cx
327
+
328
+ def cumsum(x, reverse=False, exclusive=False):
329
+ """cumulative sum."""
330
+ bsz, _, length = x.size()
331
+ device = x.device
332
+ if reverse:
333
+ if exclusive:
334
+ w = torch.ones([bsz, length, length], device=device).tril(-1)
335
+ else:
336
+ w = torch.ones([bsz, length, length], device=device).tril(0)
337
+ cx = torch.bmm(x, w)
338
+ else:
339
+ if exclusive:
340
+ w = torch.ones([bsz, length, length], device=device).triu(1)
341
+ else:
342
+ w = torch.ones([bsz, length, length], device=device).triu(0)
343
+ cx = torch.bmm(x, w)
344
+ return cx
345
+
346
+ def cummin(x, reverse=False, exclusive=False, max_value=1e9):
347
+ """cumulative min."""
348
+ if reverse:
349
+ if exclusive:
350
+ x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
351
+ x = x.flip([-1]).cummin(-1)[0].flip([-1])
352
+ else:
353
+ if exclusive:
354
+ x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
355
+ x = x.cummin(-1)[0]
356
+ return x
357
+
358
+ class Transformer(nn.Module):
359
+ """Transformer model."""
360
+
361
+ def __init__(self,
362
+ hidden_size,
363
+ nlayers,
364
+ ntokens,
365
+ nhead=8,
366
+ dropout=0.1,
367
+ dropatt=0.1,
368
+ relative_bias=True,
369
+ pos_emb=False,
370
+ pad=0):
371
+ """Initialization.
372
+
373
+ Args:
374
+ hidden_size: dimension of inputs and hidden states
375
+ nlayers: number of layers
376
+ ntokens: number of output categories
377
+ nhead: number of self-attention heads
378
+ dropout: dropout rate
379
+ dropatt: drop attention rate
380
+ relative_bias: bool, indicate whether use a relative position based
381
+ attention bias
382
+ pos_emb: bool, indicate whether use a learnable positional embedding
383
+ pad: pad token index
384
+ """
385
+
386
+ super(Transformer, self).__init__()
387
+
388
+ self.drop = nn.Dropout(dropout)
389
+
390
+ self.emb = nn.Embedding(ntokens, hidden_size)
391
+ if pos_emb:
392
+ self.pos_emb = nn.Embedding(500, hidden_size)
393
+
394
+ self.layers = nn.ModuleList([
395
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
396
+ dropatt=dropatt, relative_bias=relative_bias)
397
+ for _ in range(nlayers)])
398
+
399
+ self.norm = nn.LayerNorm(hidden_size)
400
+
401
+ self.output_layer = nn.Linear(hidden_size, ntokens)
402
+ self.output_layer.weight = self.emb.weight
403
+
404
+ self.init_weights()
405
+
406
+ self.nlayers = nlayers
407
+ self.nhead = nhead
408
+ self.ntokens = ntokens
409
+ self.hidden_size = hidden_size
410
+ self.pad = pad
411
+
412
+ def init_weights(self):
413
+ """Initialize token embedding and output bias."""
414
+ initrange = 0.1
415
+ self.emb.weight.data.uniform_(-initrange, initrange)
416
+ if hasattr(self, 'pos_emb'):
417
+ self.pos_emb.weight.data.uniform_(-initrange, initrange)
418
+ self.output_layer.bias.data.fill_(0)
419
+
420
+ def visibility(self, x, device):
421
+ """Mask pad tokens."""
422
+ visibility = (x != self.pad).float()
423
+ visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
424
+ visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
425
+ return visibility.log()
426
+
427
+ def encode(self, x, pos):
428
+ """Standard transformer encode process."""
429
+ h = self.emb(x)
430
+ if hasattr(self, 'pos_emb'):
431
+ h = h + self.pos_emb(pos)
432
+ h_list = []
433
+ visibility = self.visibility(x, x.device)
434
+
435
+ for i in range(self.nlayers):
436
+ h_list.append(h)
437
+ h = self.layers[i](
438
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
439
+
440
+ output = h
441
+ h_array = torch.stack(h_list, dim=2)
442
+
443
+ return output, h_array
444
+
445
+ def forward(self, x, pos):
446
+ """Pass the input through the encoder layer.
447
+
448
+ Args:
449
+ x: input tokens (required).
450
+ pos: position for each token (optional).
451
+ Returns:
452
+ output: probability distributions for missing tokens.
453
+ state_dict: parsing results and raw output
454
+ """
455
+
456
+ batch_size, length = x.size()
457
+
458
+ raw_output, _ = self.encode(x, pos)
459
+ raw_output = self.norm(raw_output)
460
+ raw_output = self.drop(raw_output)
461
+
462
+ output = self.output_layer(raw_output)
463
+ return output.view(batch_size * length, -1), {'raw_output': raw_output,}
464
+
465
+ class StructFormer(Transformer):
466
+ """StructFormer model."""
467
+
468
+ def __init__(self,
469
+ hidden_size,
470
+ n_context_layers,
471
+ nlayers,
472
+ ntokens,
473
+ nhead=8,
474
+ dropout=0.1,
475
+ dropatt=0.1,
476
+ relative_bias=False,
477
+ pos_emb=False,
478
+ pad=0,
479
+ n_parser_layers=4,
480
+ conv_size=9,
481
+ relations=('head', 'child'),
482
+ weight_act='softmax'):
483
+ """Initialization.
484
+
485
+ Args:
486
+ hidden_size: dimension of inputs and hidden states
487
+ nlayers: number of layers
488
+ ntokens: number of output categories
489
+ nhead: number of self-attention heads
490
+ dropout: dropout rate
491
+ dropatt: drop attention rate
492
+ relative_bias: bool, indicate whether use a relative position based
493
+ attention bias
494
+ pos_emb: bool, indicate whether use a learnable positional embedding
495
+ pad: pad token index
496
+ n_parser_layers: number of parsing layers
497
+ conv_size: convolution kernel size for parser
498
+ relations: relations that are used to compute self attention
499
+ weight_act: relations distribution activation function
500
+ """
501
+
502
+ super(StructFormer, self).__init__(
503
+ hidden_size,
504
+ nlayers,
505
+ ntokens,
506
+ nhead=nhead,
507
+ dropout=dropout,
508
+ dropatt=dropatt,
509
+ relative_bias=relative_bias,
510
+ pos_emb=pos_emb,
511
+ pad=pad)
512
+
513
+
514
+ def encode(self, x, pos):
515
+ h = self.emb(x)
516
+ if hasattr(self, 'pos_emb'):
517
+ h = h + self.pos_emb(pos)
518
+ h_list = []
519
+ visibility = self.visibility(x, x.device)
520
+
521
+ for i in range(self.nlayers):
522
+ h_list.append(h)
523
+ h = self.layers[i](
524
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
525
+
526
+ output = h
527
+ h_array = torch.stack(h_list, dim=2)
528
+
529
+ return output
530
+
531
+
532
+ def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
533
+
534
+ x = input_ids
535
+ batch_size, length = x.size()
536
+
537
+ if position_ids is None:
538
+ pos = torch.arange(length, device=x.device).expand(batch_size, length)
539
+
540
+ raw_output = self.encode(x, pos)
541
+ raw_output = self.norm(raw_output)
542
+ raw_output = self.drop(raw_output)
543
+
544
+ output = self.output_layer(raw_output)
545
+
546
+ loss = None
547
+ if labels is not None:
548
+ loss_fct = nn.CrossEntropyLoss()
549
+ loss = loss_fct(output.view(batch_size * length, -1), labels.reshape(-1))
550
+
551
+ return MaskedLMOutput(
552
+ loss=loss, # shape: 1
553
+ logits=output, # shape: (batch_size * length, ntokens)
554
+ hidden_states=None,
555
+ attentions=None,
556
+ )
557
+
558
+ ##########################################
559
+ # HuggingFace Model
560
+ ##########################################
561
+ class StructformerModel(PreTrainedModel):
562
+ config_class = StructformerConfig
563
+
564
+ def __init__(self, config):
565
+ super().__init__(config)
566
+ self.model = StructFormer(
567
+ hidden_size=config.hidden_size,
568
+ n_context_layers=config.n_context_layers,
569
+ nlayers=config.nlayers,
570
+ ntokens=config.ntokens,
571
+ nhead=config.nhead,
572
+ dropout=config.dropout,
573
+ dropatt=config.dropatt,
574
+ relative_bias=config.relative_bias,
575
+ pos_emb=config.pos_emb,
576
+ pad=config.pad,
577
+ n_parser_layers=config.n_parser_layers,
578
+ conv_size=config.conv_size,
579
+ relations=config.relations,
580
+ weight_act=config.weight_act
581
+ )
582
+
583
+ def forward(self, input_ids, labels=None, **kwargs):
584
+ return self.model(input_ids, labels=labels, **kwargs)
585
+
586
+
587
+ class StructFormerClassification(Transformer):
588
+ """StructFormer model."""
589
+
590
+ def __init__(self,
591
+ hidden_size,
592
+ n_context_layers,
593
+ nlayers,
594
+ ntokens,
595
+ nhead=8,
596
+ dropout=0.1,
597
+ dropatt=0.1,
598
+ relative_bias=False,
599
+ pos_emb=False,
600
+ pad=0,
601
+ n_parser_layers=4,
602
+ conv_size=9,
603
+ relations=('head', 'child'),
604
+ weight_act='softmax',
605
+ config=None,
606
+ ):
607
+
608
+
609
+ super(StructFormerClassification, self).__init__(
610
+ hidden_size,
611
+ nlayers,
612
+ ntokens,
613
+ nhead=nhead,
614
+ dropout=dropout,
615
+ dropatt=dropatt,
616
+ relative_bias=relative_bias,
617
+ pos_emb=pos_emb,
618
+ pad=pad)
619
+
620
+ self.num_labels = config.num_labels
621
+ self.config = config
622
+
623
+ self.classifier = RobertaClassificationHead(config)
624
+
625
+ def encode(self, x, pos):
626
+ h = self.emb(x)
627
+ if hasattr(self, 'pos_emb'):
628
+ h = h + self.pos_emb(pos)
629
+ h_list = []
630
+ visibility = self.visibility(x, x.device)
631
+
632
+ for i in range(self.nlayers):
633
+ h_list.append(h)
634
+ h = self.layers[i](
635
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
636
+
637
+ output = h
638
+ h_array = torch.stack(h_list, dim=2)
639
+
640
+ return output
641
+
642
+
643
+ def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
644
+
645
+ x = input_ids
646
+ batch_size, length = x.size()
647
+
648
+ if position_ids is None:
649
+ pos = torch.arange(length, device=x.device).expand(batch_size, length)
650
+
651
+ raw_output = self.encode(x, pos)
652
+ raw_output = self.norm(raw_output)
653
+ raw_output = self.drop(raw_output)
654
+
655
+ #output = self.output_layer(raw_output)
656
+ logits = self.classifier(raw_output)
657
+
658
+ loss = None
659
+ if labels is not None:
660
+ if self.config.problem_type is None:
661
+ if self.num_labels == 1:
662
+ self.config.problem_type = "regression"
663
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
664
+ self.config.problem_type = "single_label_classification"
665
+ else:
666
+ self.config.problem_type = "multi_label_classification"
667
+
668
+ if self.config.problem_type == "regression":
669
+ loss_fct = MSELoss()
670
+ if self.num_labels == 1:
671
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
672
+ else:
673
+ loss = loss_fct(logits, labels)
674
+ elif self.config.problem_type == "single_label_classification":
675
+ loss_fct = CrossEntropyLoss()
676
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
677
+ elif self.config.problem_type == "multi_label_classification":
678
+ loss_fct = BCEWithLogitsLoss()
679
+ loss = loss_fct(logits, labels)
680
+
681
+
682
+ return SequenceClassifierOutput(
683
+ loss=loss,
684
+ logits=logits,
685
+ hidden_states=None,
686
+ attentions=None,
687
+ )
688
+
689
+
690
+ class StructformerModelForSequenceClassification(PreTrainedModel):
691
+ config_class = StructformerConfig
692
+ def __init__(self, config):
693
+ super().__init__(config)
694
+ self.model = StructFormerClassification(
695
+ hidden_size=config.hidden_size,
696
+ n_context_layers=config.n_context_layers,
697
+ nlayers=config.nlayers,
698
+ ntokens=config.ntokens,
699
+ nhead=config.nhead,
700
+ dropout=config.dropout,
701
+ dropatt=config.dropatt,
702
+ relative_bias=config.relative_bias,
703
+ pos_emb=config.pos_emb,
704
+ pad=config.pad,
705
+ n_parser_layers=config.n_parser_layers,
706
+ conv_size=config.conv_size,
707
+ relations=config.relations,
708
+ weight_act=config.weight_act,
709
+ config=config)
710
+
711
+ def _init_weights(self, module):
712
+ """Initialize the weights"""
713
+ if isinstance(module, nn.Linear):
714
+ # Slightly different from the TF version which uses truncated_normal for initialization
715
+ # cf https://github.com/pytorch/pytorch/pull/5617
716
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
717
+ if module.bias is not None:
718
+ module.bias.data.zero_()
719
+ elif isinstance(module, nn.Embedding):
720
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
721
+ if module.padding_idx is not None:
722
+ module.weight.data[module.padding_idx].zero_()
723
+ elif isinstance(module, nn.LayerNorm):
724
+ if module.bias is not None:
725
+ module.bias.data.zero_()
726
+ module.weight.data.fill_(1.0)
727
+
728
+
729
+ def forward(self, input_ids, labels=None, **kwargs):
730
+ return self.model(input_ids, labels=labels, **kwargs)