RichardWang committed
Commit
6549220
1 Parent(s): aa6806f
Files changed (3)
  1. config.json +4 -2
  2. modeling_tsp.py +76 -49
  3. pytorch_model.bin +2 -2
config.json CHANGED
@@ -1,6 +1,6 @@
 {
   "architectures": [
-    "TSPModelForPretraining"
+    "TSPModelForPreTraining"
   ],
   "auto_map": {
     "AutoConfig": "configuration_tsp.TSPConfig",
@@ -11,6 +11,7 @@
     "AutoModelForTokenClassification": "modeling_tsp.TSPModelForTokenClassification"
   },
   "dropout_prob": 0.1,
+  "electra_generator_size_divisor": 4,
   "embedding_size": 128,
   "hidden_size": 256,
   "intermediate_size": 1024,
@@ -21,6 +22,7 @@
   "pad_token_id": 0,
   "position_embedding_type": "absolute",
   "torch_dtype": "float32",
-  "transformers_version": "4.17.0",
+  "transformers_version": "4.19.0.dev0",
+  "use_electra": true,
   "vocab_size": 30522
 }
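The renamed architecture and the two new fields are what the updated modeling_tsp.py below consumes. As a minimal sketch of what they imply (the repo id is hypothetical, and trust_remote_code=True is assumed so the auto_map entries resolve to the custom TSP classes):

from copy import deepcopy
from transformers import AutoConfig

# "user/tsp-small" is a hypothetical repo id used only for illustration.
config = AutoConfig.from_pretrained("user/tsp-small", trust_remote_code=True)
assert config.use_electra is True
assert config.electra_generator_size_divisor == 4

# Mirrors TSPModelForPreTraining.__init__ below: the MLM generator is a
# shrunken copy of the discriminator backbone.
mlm_config = deepcopy(config)
mlm_config.hidden_size //= config.electra_generator_size_divisor        # 256 -> 64
mlm_config.intermediate_size //= config.electra_generator_size_divisor  # 1024 -> 256
mlm_config.num_attention_heads //= config.electra_generator_size_divisor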
modeling_tsp.py CHANGED
@@ -9,12 +9,12 @@ import torch
 from torch import nn
 import torch.nn.functional as F
 from transformers import PreTrainedModel
-from .configuration_tsp import TSPConfig
+from configuration_tsp import TSPConfig


 class TSPPreTrainedModel(PreTrainedModel):
     config_class = TSPConfig
-    base_model_prefix = "tsp"
+    base_model_prefix = "backbone"

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -32,20 +32,21 @@ class TSPPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+
 # ====================================
 # Pretraining Model
 # ====================================


-class TSPModelForPretraining(TSPPreTrainedModel):
-    def __init__(self, config, use_electra=False):
+class TSPModelForPreTraining(TSPPreTrainedModel):
+    def __init__(self, config):
         super().__init__(config)
         self.backbone = TSPModel(config)
-        if use_electra:
+        if config.use_electra:
             mlm_config = deepcopy(config)
-            mlm_config.hidden_size /= config.generator_size_divisor
-            mlm_config.intermediate_size /= config.generator_size_divisor
-            mlm_config.num_attention_heads /= config.generator_size_divisor
+            mlm_config.hidden_size //= config.electra_generator_size_divisor
+            mlm_config.intermediate_size //= config.electra_generator_size_divisor
+            mlm_config.num_attention_heads //= config.electra_generator_size_divisor
             self.mlm_backbone = TSPModel(mlm_config)
             self.mlm_head = MaskedLMHead(
                 mlm_config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
@@ -55,7 +56,10 @@ class TSPModelForPretraining(TSPPreTrainedModel):
             self.rtd_head = ReplacedTokenDiscriminationHead(config)
         else:
             self.mlm_backbone = self.backbone
-            self.mlm_head = MaskedLMHead(config)
+            self.mlm_head = MaskedLMHead(
+                config, word_embeddings=self.mlm_backbone.embeddings.word_embeddings
+            )
+        self.tsp_head = TextStructurePredictionHead(config)
         self.apply(self._init_weights)

     def forward(self, *args, **kwargs):
@@ -63,40 +67,6 @@ class TSPModelForPretraining(TSPPreTrainedModel):
             "Refer to the implementation of text structrue prediction task for how to use the model."
         )

-    def mlm_forward(
-        self,
-        corrupted_ids, # <int>(B,L), partially masked/replaced input token ids
-        attention_mask, # <int>(B,L), 1 / 0 for tokens that are not attended/ attended
-        token_type_ids, # <int>(B,L), 0 / 1 corresponds to a segment A / B token
-        mlm_selected=None, # <bool>(B,L), True at mlm selected positiosns. Calculate logits at mlm selected positions if not None.
-    ):
-        hidden_states = self.mlm_backbone(
-            input_ids=corrupted_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-        ) # (B,L,D)
-        return self.mlm_head(
-            hidden_states, is_selected=mlm_selected
-        ) # (#mlm selected, vocab size)/ (B,L,vocab size)
-
-    def rtd_forward(
-        self,
-        corrupted_ids, # <int>(B,L), partially replaced input token ids
-        attention_mask, # <int>(B,L), 1 / 0 for tokens that are not attended/ attended
-        token_type_ids, # <int>(B,L), 0 / 1 corresponds to a segment A / B token
-    ):
-        hidden_states = self.rtd_backbone(
-            input_ids=corrupted_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-        ) # (B,L,D)
-        return self.rtd_backbone(hidden_states) # (B,L)
-
-    def tsp_forward(
-        self, hidden_states, # (B,L,D)
-    ):
-        raise NotImplementedError
-

 class MaskedLMHead(nn.Module):
     def __init__(self, config, word_embeddings=None):
@@ -135,6 +105,22 @@ class ReplacedTokenDiscriminationHead(nn.Module):
         return x.squeeze(-1) # (B,L)


+class TextStructurePredictionHead(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.linear1 = nn.Linear(config.hidden_size * 2, config.hidden_size * 2)
+        self.norm = nn.LayerNorm(config.hidden_size * 2)
+        self.linear2 = nn.Linear(config.hidden_size * 2, 6)
+
+    def forward(
+        self, x, # (...,2D)
+    ):
+        x = self.linear1(x) # (...,2D)
+        x = F.gelu(x) # (...,2D)
+        x = self.norm(x) # (...,2D)
+        return self.linear2(x) # (...,C)
+
+
 # ====================================
 # Finetuning Model
 # ====================================
@@ -164,8 +150,8 @@ class TSPModelForTokenClassification(TSPPreTrainedModel):
 class TokenClassificationHead(nn.Module):
     def __init__(self, config, num_classes):
         super().__init__()
-        self.dropout = nn.Dropout(c.dropout_prob)
-        self.classifier = nn.Linear(c.hidden_size, num_classes)
+        self.dropout = nn.Dropout(config.dropout_prob)
+        self.classifier = nn.Linear(config.hidden_size, num_classes)

     def forward(self, x): # (B,L,D)
         x = self.dropout(x) # (B,L,D)
@@ -213,6 +199,7 @@ class TSPModelForQuestionAnswering(TSPPreTrainedModel):
         super().__init__()
         self.backbone = TSPModel(config)
         self.head = SequenceClassififcationHead(config, num_classes)
+        self.apply(self._init_weights)

     def forward(
         self,
@@ -345,9 +332,6 @@ class SquadHead(nn.Module):


 class TSPModel(TSPPreTrainedModel):
-    config_class = TSPConfig
-    base_model_prefix = "tsp"
-
     def __init__(self, config):
         super().__init__(config)
         self.embeddings = Embeddings(config)
@@ -405,9 +389,9 @@ class Embeddings(nn.Module):
     ):
         B, L = input_ids.shape
         embeddings = self.word_embeddings(input_ids) # (B,L,E)
+        embeddings += self.token_type_embeddings(token_type_ids)
         if hasattr(self, "position_embeddings"):
             embeddings += self.position_embeddings.weight[None, :L, :]
-        embeddings += self.token_type_embeddings(token_type_ids)
         embeddings = self.norm(embeddings) # (B,L,E)
         embeddings = self.dropout(embeddings) # (B,L,E)
         return embeddings # (B,L,E)
@@ -453,6 +437,8 @@ class MultiHeadSelfAttention(nn.Module):
         self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)
         self.H = config.num_attention_heads
         self.d = config.hidden_size // self.H
+        if config.position_embedding_type == "rotary":
+            self.rotray_position_embeds = RotaryEmbedding(self.d)

     def forward(
         self,
@@ -463,6 +449,8 @@ class MultiHeadSelfAttention(nn.Module):
         query, key, value = (
             self.mix_proj(x).view(B, L, H, 3 * d).transpose(1, 2).split(d, dim=-1)
         ) # (B,H,L,d),(B,H,L,d),(B,H,L,d)
+        if hasattr(self, "rotray_position_embeds"):
+            query, key = self.rotray_position_embeds(query, key)
         output = self.attention(query, key, value, attention_mask) # (B,H,L,d)
         output = self.o_proj(output.transpose(1, 2).reshape(B, L, D)) # (B,L,D)
         return output # (B,L,D)
@@ -503,4 +491,43 @@ class FeedForwardNetwork(nn.Module):
         return x # (B,L,D)


+class RotaryEmbedding(nn.Module):
+    seq_len_cached = 0
+    cos_cached = None
+    sin_cached = None
+
+    def __init__(self, dim):
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+    def _forward(self, x): # (B,H,L,d)
+        # Get rotary embeddings on the fly
+        ## create
+        seq_len = x.shape[2]
+        if seq_len > RotaryEmbedding.seq_len_cached:
+            RotaryEmbedding.seq_len_cached = seq_len
+            t = torch.arange(seq_len, device=x.device, dtype=self.inv_freq.dtype)
+            freqs = t.view(-1, 1) @ self.inv_freq.view(1, -1)
+            emb = torch.cat((freqs, freqs), dim=-1).to(x.device) # (L,d)
+            RotaryEmbedding.cos_cached = emb.cos()[None, None, :, :]
+            RotaryEmbedding.sin_cached = emb.sin()[None, None, :, :]
+        ## take
+        if seq_len == RotaryEmbedding.seq_len_cached:
+            cos, sin = RotaryEmbedding.cos_cached, RotaryEmbedding.sin_cached
+        else:
+            cos, sin = (
+                RotaryEmbedding.cos_cached[:, :, :seq_len, :], # (1,1,L,d)
+                RotaryEmbedding.sin_cached[:, :, :seq_len, :], # (1,1,L,d)
+            )
+
+        # Apply rotary embeddings
+        sections = [x.shape[-1] // 2, x.shape[-1] - x.shape[-1] // 2]
+        x1, x2 = x.split(sections, dim=-1)
+        half_rotated_x = torch.cat((-x2, x1), dim=-1)
+        return (x * cos) + (half_rotated_x * sin)

+    def forward(
+        self, query, key, # (B,H,L,d) # (B,H,L,d)
+    ):
+        return self._forward(query), self._forward(key)
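For reference, a small usage sketch of the RotaryEmbedding added above, as MultiHeadSelfAttention applies it when position_embedding_type is "rotary". The shapes and the pairwise half-rotation follow the committed code; the tensor sizes are arbitrary, and the import assumes modeling_tsp.py and configuration_tsp.py from this repo are on the Python path:

import torch
from modeling_tsp import RotaryEmbedding  # class added in this commit

B, H, L, d = 2, 4, 16, 64  # arbitrary batch, heads, length, head dim (d must be even)
query = torch.randn(B, H, L, d)
key = torch.randn(B, H, L, d)

rope = RotaryEmbedding(d)
q_rot, k_rot = rope(query, key)  # each still (B,H,L,d)

# Each (x_i, x_{i+d/2}) pair is rotated by a position-dependent angle, so
# per-position norms are preserved and q·k comes to depend on relative offsets.
assert torch.allclose(q_rot.norm(dim=-1), query.norm(dim=-1), atol=1e-4)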
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9bad83b6706d009b6c0b06cd6d42664b8c67c22f2d3215f88b09f83c9eb93604
-size 69713927
+oid sha256:8c8decb4de84befc5103d4b4b7c9ed0d61fc598ad859c30163e92107f76ea731
+size 57777425