Martijn van Beers commited on
Commit
f59e918
·
1 Parent(s): 64ac833

Remove files that shouldn't have been committed

Browse files
Files changed (2) hide show
  1. lib/BERTalt.py +0 -551
  2. lib/roberta2.py.rej +0 -63
lib/BERTalt.py DELETED
@@ -1,551 +0,0 @@
1
- from __future__ import absolute_import
2
-
3
- import torch
4
- from torch import nn
5
- import torch.nn.functional as F
6
- import math
7
- from BERT_explainability.modules.layers_ours import *
8
-
9
- import transformers
10
-
11
- from transformers import BertConfig
12
- from transformers.modeling_outputs import BaseModelOutputWithPooling, BaseModelOutput
13
- from transformers import (
14
- BertPreTrainedModel,
15
- PreTrainedModel,
16
- )
17
-
18
-
19
- ACT2FN = {
20
- "relu": ReLU,
21
- "tanh": Tanh,
22
- "gelu": GELU,
23
- }
24
-
25
-
26
- def get_activation(activation_string):
27
- if activation_string in ACT2FN:
28
- return ACT2FN[activation_string]
29
- else:
30
- raise KeyError("function {} not found in ACT2FN mapping {}".format(activation_string, list(ACT2FN.keys())))
31
-
32
- def compute_rollout_attention(all_layer_matrices, start_layer=0):
33
- # adding residual consideration
34
- num_tokens = all_layer_matrices[0].shape[1]
35
- batch_size = all_layer_matrices[0].shape[0]
36
- eye = torch.eye(num_tokens).expand(batch_size, num_tokens, num_tokens).to(all_layer_matrices[0].device)
37
- all_layer_matrices = [all_layer_matrices[i] + eye for i in range(len(all_layer_matrices))]
38
- all_layer_matrices = [all_layer_matrices[i] / all_layer_matrices[i].sum(dim=-1, keepdim=True)
39
- for i in range(len(all_layer_matrices))]
40
- joint_attention = all_layer_matrices[start_layer]
41
- for i in range(start_layer+1, len(all_layer_matrices)):
42
- joint_attention = all_layer_matrices[i].bmm(joint_attention)
43
- return joint_attention
44
-
45
- class RPBertEmbeddings(BertEmbeddings):
46
- def __init__(self, config):
47
- super().__init__()
48
-
49
- self.add1 = Add()
50
- self.add2 = Add()
51
-
52
- def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
53
- if input_ids is not None:
54
- input_shape = input_ids.size()
55
- else:
56
- input_shape = inputs_embeds.size()[:-1]
57
-
58
- seq_length = input_shape[1]
59
-
60
- if position_ids is None:
61
- position_ids = self.position_ids[:, :seq_length]
62
-
63
- if token_type_ids is None:
64
- token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
65
-
66
- if inputs_embeds is None:
67
- inputs_embeds = self.word_embeddings(input_ids)
68
- position_embeddings = self.position_embeddings(position_ids)
69
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
70
-
71
- # embeddings = inputs_embeds + position_embeddings + token_type_embeddings
72
- embeddings = self.add1([token_type_embeddings, position_embeddings])
73
- embeddings = self.add2([embeddings, inputs_embeds])
74
- embeddings = self.LayerNorm(embeddings)
75
- embeddings = self.dropout(embeddings)
76
- return embeddings
77
-
78
- def relprop(self, cam, **kwargs):
79
- cam = self.dropout.relprop(cam, **kwargs)
80
- cam = self.LayerNorm.relprop(cam, **kwargs)
81
-
82
- # [inputs_embeds, position_embeddings, token_type_embeddings]
83
- (cam) = self.add2.relprop(cam, **kwargs)
84
-
85
- return cam
86
-
87
- class RPBertEncoder(transformers.modeling_bert.BertEncoder):
88
- def __init__(self, config):
89
- super().__init__()
90
- self.config = config
91
- self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
92
-
93
- def relprop(self, cam, **kwargs):
94
- # assuming output_hidden_states is False
95
- for layer_module in reversed(self.layer):
96
- cam = layer_module.relprop(cam, **kwargs)
97
- return cam
98
-
99
-
100
- # not adding relprop since this is only pooling at the end of the network, does not impact tokens importance
101
- class RPBertPooler(transformers.modeling_bert.BertPooler):
102
- def __init__(self, config):
103
- super().__init__()
104
- self.pool = IndexSelect()
105
-
106
- def forward(self, hidden_states):
107
- # We "pool" the model by simply taking the hidden state corresponding
108
- # to the first token.
109
- self._seq_size = hidden_states.shape[1]
110
-
111
- # first_token_tensor = hidden_states[:, 0]
112
- first_token_tensor = self.pool(hidden_states, 1, torch.tensor(0, device=hidden_states.device))
113
- first_token_tensor = first_token_tensor.squeeze(1)
114
- pooled_output = self.dense(first_token_tensor)
115
- pooled_output = self.activation(pooled_output)
116
- return pooled_output
117
-
118
- def relprop(self, cam, **kwargs):
119
- cam = self.activation.relprop(cam, **kwargs)
120
- #print(cam.sum())
121
- cam = self.dense.relprop(cam, **kwargs)
122
- #print(cam.sum())
123
- cam = cam.unsqueeze(1)
124
- cam = self.pool.relprop(cam, **kwargs)
125
- #print(cam.sum())
126
-
127
- return cam
128
-
129
- class BertAttention(transformers.modeling_bert.BertAttention):
130
- def __init__(self, config):
131
- super().__init__()
132
- self.clone = Clone()
133
-
134
- def forward(
135
- self,
136
- hidden_states,
137
- attention_mask=None,
138
- head_mask=None,
139
- encoder_hidden_states=None,
140
- encoder_attention_mask=None,
141
- output_attentions=False,
142
- ):
143
- h1, h2 = self.clone(hidden_states, 2)
144
- self_outputs = self.self(
145
- h1,
146
- attention_mask,
147
- head_mask,
148
- encoder_hidden_states,
149
- encoder_attention_mask,
150
- output_attentions,
151
- )
152
- attention_output = self.output(self_outputs[0], h2)
153
- outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
154
- return outputs
155
-
156
- def relprop(self, cam, **kwargs):
157
- # assuming that we don't ouput the attentions (outputs = (attention_output,)), self_outputs=(context_layer,)
158
- (cam1, cam2) = self.output.relprop(cam, **kwargs)
159
- #print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
160
- cam1 = self.self.relprop(cam1, **kwargs)
161
- #print(cam1.sum(), cam2.sum(), (cam1 + cam2).sum())
162
-
163
- return self.clone.relprop((cam1, cam2), **kwargs)
164
-
165
- class BertSelfAttention(transformers.modeling_bert.BertSelfAttention):
166
- def __init__(self, config):
167
- super().__init__()
168
-
169
- self.matmul1 = MatMul()
170
- self.matmul2 = MatMul()
171
- self.softmax = Softmax(dim=-1)
172
- self.add = Add()
173
- self.mul = Mul()
174
- self.head_mask = None
175
- self.attention_mask = None
176
- self.clone = Clone()
177
-
178
- self.attn_cam = None
179
- self.attn = None
180
- self.attn_gradients = None
181
-
182
- def get_attn(self):
183
- return self.attn
184
-
185
- def save_attn(self, attn):
186
- self.attn = attn
187
-
188
- def save_attn_cam(self, cam):
189
- self.attn_cam = cam
190
-
191
- def get_attn_cam(self):
192
- return self.attn_cam
193
-
194
- def save_attn_gradients(self, attn_gradients):
195
- self.attn_gradients = attn_gradients
196
-
197
- def get_attn_gradients(self):
198
- return self.attn_gradients
199
-
200
- def transpose_for_scores_relprop(self, x):
201
- return x.permute(0, 2, 1, 3).flatten(2)
202
-
203
- def forward(
204
- self,
205
- hidden_states,
206
- attention_mask=None,
207
- head_mask=None,
208
- encoder_hidden_states=None,
209
- encoder_attention_mask=None,
210
- output_attentions=False,
211
- ):
212
- self.head_mask = head_mask
213
- self.attention_mask = attention_mask
214
-
215
- h1, h2, h3 = self.clone(hidden_states, 3)
216
- mixed_query_layer = self.query(h1)
217
-
218
- # If this is instantiated as a cross-attention module, the keys
219
- # and values come from an encoder; the attention mask needs to be
220
- # such that the encoder's padding tokens are not attended to.
221
- if encoder_hidden_states is not None:
222
- mixed_key_layer = self.key(encoder_hidden_states)
223
- mixed_value_layer = self.value(encoder_hidden_states)
224
- attention_mask = encoder_attention_mask
225
- else:
226
- mixed_key_layer = self.key(h2)
227
- mixed_value_layer = self.value(h3)
228
-
229
- query_layer = self.transpose_for_scores(mixed_query_layer)
230
- key_layer = self.transpose_for_scores(mixed_key_layer)
231
- value_layer = self.transpose_for_scores(mixed_value_layer)
232
-
233
- # Take the dot product between "query" and "key" to get the raw attention scores.
234
- attention_scores = self.matmul1([query_layer, key_layer.transpose(-1, -2)])
235
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
236
- if attention_mask is not None:
237
- # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
238
- attention_scores = self.add([attention_scores, attention_mask])
239
-
240
- # Normalize the attention scores to probabilities.
241
- attention_probs = self.softmax(attention_scores)
242
-
243
- self.save_attn(attention_probs)
244
- attention_probs.register_hook(self.save_attn_gradients)
245
-
246
- # This is actually dropping out entire tokens to attend to, which might
247
- # seem a bit unusual, but is taken from the original Transformer paper.
248
- attention_probs = self.dropout(attention_probs)
249
-
250
- # Mask heads if we want to
251
- if head_mask is not None:
252
- attention_probs = attention_probs * head_mask
253
-
254
- context_layer = self.matmul2([attention_probs, value_layer])
255
-
256
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
257
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
258
- context_layer = context_layer.view(*new_context_layer_shape)
259
-
260
- outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
261
- return outputs
262
-
263
- def relprop(self, cam, **kwargs):
264
- # Assume output_attentions == False
265
- cam = self.transpose_for_scores(cam)
266
-
267
- # [attention_probs, value_layer]
268
- (cam1, cam2) = self.matmul2.relprop(cam, **kwargs)
269
- cam1 /= 2
270
- cam2 /= 2
271
- if self.head_mask is not None:
272
- # [attention_probs, head_mask]
273
- (cam1, _)= self.mul.relprop(cam1, **kwargs)
274
-
275
-
276
- self.save_attn_cam(cam1)
277
-
278
- cam1 = self.dropout.relprop(cam1, **kwargs)
279
-
280
- cam1 = self.softmax.relprop(cam1, **kwargs)
281
-
282
- if self.attention_mask is not None:
283
- # [attention_scores, attention_mask]
284
- (cam1, _) = self.add.relprop(cam1, **kwargs)
285
-
286
- # [query_layer, key_layer.transpose(-1, -2)]
287
- (cam1_1, cam1_2) = self.matmul1.relprop(cam1, **kwargs)
288
- cam1_1 /= 2
289
- cam1_2 /= 2
290
-
291
- # query
292
- cam1_1 = self.transpose_for_scores_relprop(cam1_1)
293
- cam1_1 = self.query.relprop(cam1_1, **kwargs)
294
-
295
- # key
296
- cam1_2 = self.transpose_for_scores_relprop(cam1_2.transpose(-1, -2))
297
- cam1_2 = self.key.relprop(cam1_2, **kwargs)
298
-
299
- # value
300
- cam2 = self.transpose_for_scores_relprop(cam2)
301
- cam2 = self.value.relprop(cam2, **kwargs)
302
-
303
- cam = self.clone.relprop((cam1_1, cam1_2, cam2), **kwargs)
304
-
305
- return cam
306
-
307
-
308
- class BertSelfOutput(transformers.modeling_bert.BertSelfOutput):
309
- def __init__(self, config):
310
- super().__init__()
311
- self.add = Add()
312
-
313
- def forward(self, hidden_states, input_tensor):
314
- hidden_states = self.dense(hidden_states)
315
- hidden_states = self.dropout(hidden_states)
316
- add = self.add([hidden_states, input_tensor])
317
- hidden_states = self.LayerNorm(add)
318
- return hidden_states
319
-
320
- def relprop(self, cam, **kwargs):
321
- cam = self.LayerNorm.relprop(cam, **kwargs)
322
- # [hidden_states, input_tensor]
323
- (cam1, cam2) = self.add.relprop(cam, **kwargs)
324
- cam1 = self.dropout.relprop(cam1, **kwargs)
325
- cam1 = self.dense.relprop(cam1, **kwargs)
326
-
327
- return (cam1, cam2)
328
-
329
-
330
- class BertIntermediate(transformers.modeling_bert.BertIntermediate):
331
- def relprop(self, cam, **kwargs):
332
- cam = self.intermediate_act_fn.relprop(cam, **kwargs) # FIXME only ReLU
333
- #print(cam.sum())
334
- cam = self.dense.relprop(cam, **kwargs)
335
- #print(cam.sum())
336
- return cam
337
-
338
-
339
- class BertOutput(transformers.modeling_bert.BertOutput):
340
- def __init__(self, config):
341
- super().__init__()
342
- self.add = Add()
343
-
344
- def forward(self, hidden_states, input_tensor):
345
- hidden_states = self.dense(hidden_states)
346
- hidden_states = self.dropout(hidden_states)
347
- add = self.add([hidden_states, input_tensor])
348
- hidden_states = self.LayerNorm(add)
349
- return hidden_states
350
-
351
- def relprop(self, cam, **kwargs):
352
- # print("in", cam.sum())
353
- cam = self.LayerNorm.relprop(cam, **kwargs)
354
- #print(cam.sum())
355
- # [hidden_states, input_tensor]
356
- (cam1, cam2)= self.add.relprop(cam, **kwargs)
357
- # print("add", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
358
- cam1 = self.dropout.relprop(cam1, **kwargs)
359
- #print(cam1.sum())
360
- cam1 = self.dense.relprop(cam1, **kwargs)
361
- # print("dense", cam1.sum())
362
-
363
- # print("out", cam1.sum() + cam2.sum(), cam1.sum(), cam2.sum())
364
- return (cam1, cam2)
365
-
366
-
367
- class RPBertLayer(nn.Module):
368
- def __init__(self, config):
369
- super().__init__()
370
- self.attention = BertAttention(config)
371
- self.intermediate = BertIntermediate(config)
372
- self.output = BertOutput(config)
373
- self.clone = Clone()
374
-
375
- def forward(
376
- self,
377
- hidden_states,
378
- attention_mask=None,
379
- head_mask=None,
380
- output_attentions=False,
381
- ):
382
- self_attention_outputs = self.attention(
383
- hidden_states,
384
- attention_mask,
385
- head_mask,
386
- output_attentions=output_attentions,
387
- )
388
- attention_output = self_attention_outputs[0]
389
- outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
390
-
391
- ao1, ao2 = self.clone(attention_output, 2)
392
- intermediate_output = self.intermediate(ao1)
393
- layer_output = self.output(intermediate_output, ao2)
394
-
395
- outputs = (layer_output,) + outputs
396
- return outputs
397
-
398
- def relprop(self, cam, **kwargs):
399
- (cam1, cam2) = self.output.relprop(cam, **kwargs)
400
- # print("output", cam1.sum(), cam2.sum(), cam1.sum() + cam2.sum())
401
- cam1 = self.intermediate.relprop(cam1, **kwargs)
402
- # print("intermediate", cam1.sum())
403
- cam = self.clone.relprop((cam1, cam2), **kwargs)
404
- # print("clone", cam.sum())
405
- cam = self.attention.relprop(cam, **kwargs)
406
- # print("attention", cam.sum())
407
- return cam
408
-
409
-
410
- class BertModel(BertPreTrainedModel):
411
- def __init__(self, config):
412
- super().__init__(config)
413
- self.config = config
414
-
415
- self.embeddings = BertEmbeddings(config)
416
- self.encoder = BertEncoder(config)
417
- self.pooler = BertPooler(config)
418
-
419
- self.init_weights()
420
-
421
- def get_input_embeddings(self):
422
- return self.embeddings.word_embeddings
423
-
424
- def set_input_embeddings(self, value):
425
- self.embeddings.word_embeddings = value
426
-
427
- def forward(
428
- self,
429
- input_ids=None,
430
- attention_mask=None,
431
- token_type_ids=None,
432
- position_ids=None,
433
- head_mask=None,
434
- inputs_embeds=None,
435
- encoder_hidden_states=None,
436
- encoder_attention_mask=None,
437
- output_attentions=None,
438
- output_hidden_states=None,
439
- return_dict=None,
440
- ):
441
- r"""
442
- encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
443
- Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
444
- if the model is configured as a decoder.
445
- encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
446
- Mask to avoid performing attention on the padding token indices of the encoder input. This mask
447
- is used in the cross-attention if the model is configured as a decoder.
448
- Mask values selected in ``[0, 1]``:
449
- ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
450
- """
451
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
452
- output_hidden_states = (
453
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
454
- )
455
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
456
-
457
- if input_ids is not None and inputs_embeds is not None:
458
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
459
- elif input_ids is not None:
460
- input_shape = input_ids.size()
461
- elif inputs_embeds is not None:
462
- input_shape = inputs_embeds.size()[:-1]
463
- else:
464
- raise ValueError("You have to specify either input_ids or inputs_embeds")
465
-
466
- device = input_ids.device if input_ids is not None else inputs_embeds.device
467
-
468
- if attention_mask is None:
469
- attention_mask = torch.ones(input_shape, device=device)
470
- if token_type_ids is None:
471
- token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
472
-
473
- # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
474
- # ourselves in which case we just need to make it broadcastable to all heads.
475
- extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
476
-
477
- # If a 2D or 3D attention mask is provided for the cross-attention
478
- # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
479
- if self.config.is_decoder and encoder_hidden_states is not None:
480
- encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
481
- encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
482
- if encoder_attention_mask is None:
483
- encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
484
- encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
485
- else:
486
- encoder_extended_attention_mask = None
487
-
488
- # Prepare head mask if needed
489
- # 1.0 in head_mask indicate we keep the head
490
- # attention_probs has shape bsz x n_heads x N x N
491
- # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
492
- # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
493
- head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
494
-
495
- embedding_output = self.embeddings(
496
- input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
497
- )
498
-
499
- encoder_outputs = self.encoder(
500
- embedding_output,
501
- attention_mask=extended_attention_mask,
502
- head_mask=head_mask,
503
- encoder_hidden_states=encoder_hidden_states,
504
- encoder_attention_mask=encoder_extended_attention_mask,
505
- output_attentions=output_attentions,
506
- output_hidden_states=output_hidden_states,
507
- return_dict=return_dict,
508
- )
509
- sequence_output = encoder_outputs[0]
510
- pooled_output = self.pooler(sequence_output)
511
-
512
- if not return_dict:
513
- return (sequence_output, pooled_output) + encoder_outputs[1:]
514
-
515
- return BaseModelOutputWithPooling(
516
- last_hidden_state=sequence_output,
517
- pooler_output=pooled_output,
518
- hidden_states=encoder_outputs.hidden_states,
519
- attentions=encoder_outputs.attentions,
520
- )
521
-
522
- def relprop(self, cam, **kwargs):
523
- cam = self.pooler.relprop(cam, **kwargs)
524
- # print("111111111111",cam.sum())
525
- cam = self.encoder.relprop(cam, **kwargs)
526
- # print("222222222222222", cam.sum())
527
- # print("conservation: ", cam.sum())
528
- return cam
529
-
530
-
531
- transformers.modeling_bert.BertEmbeddings = RPBertEmbeddings
532
- transformers.modeling_bert.BertEncoder = RPBertEncoder
533
-
534
- if __name__ == '__main__':
535
- class Config:
536
- def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
537
- self.hidden_size = hidden_size
538
- self.num_attention_heads = num_attention_heads
539
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
540
-
541
- model = BertSelfAttention(Config(1024, 4, 0.1))
542
- x = torch.rand(2, 20, 1024)
543
- x.requires_grad_()
544
-
545
- model.eval()
546
-
547
- y = model.forward(x)
548
-
549
- relprop = model.relprop(torch.rand(2, 20, 1024), (torch.rand(2, 20, 1024),))
550
-
551
- print(relprop[1][0].shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
lib/roberta2.py.rej DELETED
@@ -1,63 +0,0 @@
1
- --- modeling_roberta.py 2022-06-28 11:59:19.974278244 +0200
2
- +++ roberta2.py 2022-06-28 14:13:05.765050058 +0200
3
- @@ -23,14 +23,14 @@
4
- from torch import nn
5
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
6
-
7
- -from ...activations import ACT2FN, gelu
8
- -from ...file_utils import (
9
- +from transformers.activations import ACT2FN, gelu
10
- +from transformers.file_utils import (
11
- add_code_sample_docstrings,
12
- add_start_docstrings,
13
- add_start_docstrings_to_model_forward,
14
- replace_return_docstrings,
15
- )
16
- -from ...modeling_outputs import (
17
- +from transformers.modeling_outputs import (
18
- BaseModelOutputWithPastAndCrossAttentions,
19
- BaseModelOutputWithPoolingAndCrossAttentions,
20
- CausalLMOutputWithCrossAttentions,
21
- @@ -40,14 +40,14 @@
22
- SequenceClassifierOutput,
23
- TokenClassifierOutput,
24
- )
25
- -from ...modeling_utils import (
26
- +from transformers.modeling_utils import (
27
- PreTrainedModel,
28
- apply_chunking_to_forward,
29
- find_pruneable_heads_and_indices,
30
- prune_linear_layer,
31
- )
32
- -from ...utils import logging
33
- -from .configuration_roberta import RobertaConfig
34
- +from transformers.utils import logging
35
- +from transformers.models.roberta.configuration_roberta import RobertaConfig
36
-
37
-
38
- logger = logging.get_logger(__name__)
39
- @@ -183,6 +183,24 @@
40
-
41
- self.is_decoder = config.is_decoder
42
-
43
- + def get_attn(self):
44
- + return self.attn
45
- +
46
- + def save_attn(self, attn):
47
- + self.attn = attn
48
- +
49
- + def save_attn_cam(self, cam):
50
- + self.attn_cam = cam
51
- +
52
- + def get_attn_cam(self):
53
- + return self.attn_cam
54
- +
55
- + def save_attn_gradients(self, attn_gradients):
56
- + self.attn_gradients = attn_gradients
57
- +
58
- + def get_attn_gradients(self):
59
- + return self.attn_gradients
60
- +
61
- def transpose_for_scores(self, x):
62
- new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
63
- x = x.view(*new_x_shape)