AmanPriyanshu committed on
Commit 61c72d7
1 Parent(s): 0df9f76

Setting-up-repo

bottleneck_t5.py ADDED
@@ -0,0 +1,426 @@
+ import copy
+ import warnings
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss
+ from torch.nn import functional as F
+ from typing import Optional, Tuple, Union
+
+ from transformers import T5Config, T5Tokenizer, T5ForConditionalGeneration
+ from transformers.models.t5.modeling_t5 import (
+     T5LayerNorm,
+     T5LayerFF,
+     T5Attention,
+     T5LayerSelfAttention,
+     T5LayerCrossAttention,
+     T5Block,
+     T5Stack,
+ )
+ from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+ from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+
+ # Mirrors the head-mask deprecation warning emitted by the upstream T5 implementation.
+ _HEAD_MASK_WARNING_MSG = (
+     "The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. "
+     "Currently, `decoder_head_mask` is set to copy `head_mask`."
+ )
+
+ class BottleneckCrossAttentionGate(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.gate = nn.Linear(2 * config.d_model, config.d_model, bias=False)
+         self.act = nn.Sigmoid()
+
+     def forward(self, query_states, latents):
+         latents = latents.unsqueeze(1).expand(query_states.shape)
+         query_latents = torch.cat([query_states, latents], dim=-1)
+         return 2 * self.act(self.gate(query_latents))
+
+ class BottleneckT5Attention(T5Attention):
+     def __init__(self, config: T5Config, has_relative_attention_bias=False):
+         super(T5Attention, self).__init__()
+         self.is_decoder = config.is_decoder
+         self.has_relative_attention_bias = has_relative_attention_bias
+         self.relative_attention_num_buckets = config.relative_attention_num_buckets
+         self.relative_attention_max_distance = config.relative_attention_max_distance
+         self.d_model = config.d_model
+         self.key_value_proj_dim = config.d_kv
+         self.n_heads = config.num_heads
+         self.dropout = config.dropout_rate
+         self.inner_dim = self.n_heads * self.key_value_proj_dim
+
+         # Mesh TensorFlow initialization to avoid scaling before softmax.
+         # Only the value and output projections are kept; the query/key
+         # projections of standard T5 attention are dropped in this bottleneck variant.
+         self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+         self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+         if self.has_relative_attention_bias:
+             self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+         self.pruned_heads = set()
+         self.gradient_checkpointing = False
+
+     def prune_heads(self, heads):
+         if len(heads) == 0:
+             return
+         heads, index = find_pruneable_heads_and_indices(
+             heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+         )
+         # Prune linear layers
+         self.v = prune_linear_layer(self.v, index)
+         self.o = prune_linear_layer(self.o, index, dim=1)
+         # Update hyper params
+         self.n_heads = self.n_heads - len(heads)
+         self.inner_dim = self.key_value_proj_dim * self.n_heads
+         self.pruned_heads = self.pruned_heads.union(heads)
+
+     def forward(
+         self,
+         hidden_states,
+         mask=None,
+         key_value_states=None,
+         position_bias=None,
+         past_key_value=None,
+         layer_head_mask=None,
+         query_length=None,
+         use_cache=False,
+         output_attentions=False,
+     ):
+         """
+         Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
+         """
+         # Input is (batch_size, seq_length, dim)
+         # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
+         # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
+         batch_size, seq_length = hidden_states.shape[:2]
+
+         real_seq_length = seq_length
+
+         if past_key_value is not None:
+             assert (
+                 len(past_key_value) == 2
+             ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
+             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
+
+         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+
+         def shape(states):
+             """projection"""
+             return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
+
+         def unshape(states):
+             """reshape"""
+             return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
+
+         def project(hidden_states, proj_layer, key_value_states, past_key_value):
+             """projects hidden states correctly to key/query states"""
+             if key_value_states is None:
+                 # self-attn
+                 # (batch_size, n_heads, seq_length, dim_per_head)
+                 hidden_states = shape(proj_layer(hidden_states))
+             elif past_key_value is None:
+                 # cross-attn
+                 # (batch_size, n_heads, seq_length, dim_per_head)
+                 hidden_states = shape(proj_layer(key_value_states))
+
+             if past_key_value is not None:
+                 if key_value_states is None:
+                     # self-attn
+                     # (batch_size, n_heads, key_length, dim_per_head)
+                     hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
+                 else:
+                     # cross-attn
+                     hidden_states = past_key_value
+             return hidden_states
+
+         # key/value states
+         key_states = torch.zeros((batch_size, self.n_heads, seq_length, key_length), device=hidden_states.device)
+         value_states = project(
+             hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
+         )
+
+         # compute scores
+         scores = torch.ones((batch_size, self.n_heads, seq_length, key_length), device=hidden_states.device)
+
+         if position_bias is None:
+             if not self.has_relative_attention_bias:
+                 position_bias = torch.zeros(
+                     (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
+                 )
+                 if self.gradient_checkpointing and self.training:
+                     position_bias.requires_grad = True
+             else:
+                 position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
+
+             # if key and values are already calculated
+             # we want only the last query position bias
+             if past_key_value is not None:
+                 position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
+
+             if mask is not None:
+                 position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
+
+         if self.pruned_heads:
+             mask = torch.ones(position_bias.shape[1])
+             mask[list(self.pruned_heads)] = 0
+             position_bias_masked = position_bias[:, mask.bool()]
+         else:
+             position_bias_masked = position_bias
+
+         scores += position_bias_masked
+         attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
+             scores
+         )  # (batch_size, n_heads, seq_length, key_length)
+         attn_weights = nn.functional.dropout(
+             attn_weights, p=self.dropout, training=self.training
+         )  # (batch_size, n_heads, seq_length, key_length)
+
+         # Mask heads if we want to
+         if layer_head_mask is not None:
+             attn_weights = attn_weights * layer_head_mask
+
+         attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim)
+         attn_output = self.o(attn_output)
+
+         present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
+         outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
+
+         if output_attentions:
+             outputs = outputs + (attn_weights,)
+         return outputs
+
+ class BottleneckT5LayerCrossAttention(T5LayerCrossAttention):
+     def __init__(self, config):
+         super(T5LayerCrossAttention, self).__init__()
+         self.EncDecAttention = BottleneckT5Attention(config, has_relative_attention_bias=False)
+         self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+         self.gate = BottleneckCrossAttentionGate(config)
+         self.dropout = nn.Dropout(config.dropout_rate)
+
+     def forward(
+         self,
+         hidden_states,
+         key_value_states,
+         attention_mask=None,
+         position_bias=None,
+         layer_head_mask=None,
+         past_key_value=None,
+         use_cache=False,
+         query_length=None,
+         output_attentions=False,
+     ):
+         normed_hidden_states = self.layer_norm(hidden_states)
+         attention_output = self.EncDecAttention(
+             normed_hidden_states,
+             mask=attention_mask,
+             key_value_states=key_value_states,
+             position_bias=position_bias,
+             layer_head_mask=layer_head_mask,
+             past_key_value=past_key_value,
+             use_cache=use_cache,
+             query_length=query_length,
+             output_attentions=output_attentions,
+         )
+         latents = key_value_states[:, 0]
+         layer_output = hidden_states + self.dropout(self.gate(normed_hidden_states, latents) * attention_output[0])
+         outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
+         return outputs
+
+ class BottleneckT5Block(T5Block):
+     def __init__(self, config, has_relative_attention_bias=False):
+         super(T5Block, self).__init__()
+         self.is_decoder = config.is_decoder
+         self.layer = nn.ModuleList()
+         self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
+         if self.is_decoder:
+             self.layer.append(BottleneckT5LayerCrossAttention(config))
+
+         self.layer.append(T5LayerFF(config))
+
+ class BottleneckT5Stack(T5Stack):
+     def __init__(self, config, embed_tokens=None):
+         super(T5Stack, self).__init__(config)
+
+         self.embed_tokens = embed_tokens
+         self.is_decoder = config.is_decoder
+
+         self.block = nn.ModuleList(
+             [BottleneckT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
+         )
+         self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
+         self.dropout = nn.Dropout(config.dropout_rate)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+         self.gradient_checkpointing = False
+
+ class BottleneckT5LMWithPerturb(T5ForConditionalGeneration):
+     def __init__(self, config: T5Config):
+         super(T5ForConditionalGeneration, self).__init__(config)
+         self.model_dim = config.d_model
+
+         self.shared = nn.Embedding(config.vocab_size, config.d_model)
+         encoder_config = copy.deepcopy(config)
+         encoder_config.is_decoder = False
+         encoder_config.use_cache = False
+         encoder_config.is_encoder_decoder = False
+         self.encoder = T5Stack(encoder_config, self.shared)
+
+         # New in Contra: MHA bottleneck block
+         self.num_heads = config.num_heads
+         self.bottleneck = nn.MultiheadAttention(config.d_model,
+                                                 num_heads=config.num_heads,
+                                                 dropout=config.dropout_rate,
+                                                 bias=False,
+                                                 batch_first=True)
+         self.bottleneck_scale = nn.Parameter(torch.ones(1))
+
+         self.dec_emb = nn.Embedding(config.vocab_size, config.d_model)
+         decoder_config = copy.deepcopy(config)
+         decoder_config.is_decoder = True
+         decoder_config.is_encoder_decoder = False
+         decoder_config.num_layers = config.num_decoder_layers
+         self.decoder = BottleneckT5Stack(decoder_config, self.dec_emb)
+
+         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+         # Model parallel
+         self.model_parallel = False
+         self.device_map = None
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         decoder_input_ids: Optional[torch.LongTensor] = None,
+         decoder_attention_mask: Optional[torch.BoolTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         decoder_head_mask: Optional[torch.FloatTensor] = None,
+         cross_attn_head_mask: Optional[torch.Tensor] = None,
+         encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         perturb_vector: Optional[torch.FloatTensor] = None,
+         encode_only: Optional[bool] = None,
+     ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
+         if head_mask is not None and decoder_head_mask is None:
+             if self.config.num_layers == self.config.num_decoder_layers:
+                 warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
+                 decoder_head_mask = head_mask
+
+         # Encode if needed (training, first prediction pass)
+         if encoder_outputs is None:
+             # Convert encoder inputs in embeddings if needed
+             encoder_outputs = self.encoder(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 inputs_embeds=inputs_embeds,
+                 head_mask=head_mask,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+         elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+             encoder_outputs = BaseModelOutput(
+                 last_hidden_state=encoder_outputs[0],
+                 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+                 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+             )
+
+         hidden_states = encoder_outputs[0]
+
+         # MHA across token embeddings + embedding normalization + broadcast
+         hidden_states = hidden_states.repeat(
+             attention_mask.shape[0] // hidden_states.shape[0],
+             1, 1)  # during contrastive search, attn mask can have higher batch size than hidden_state
+         mask_expanded = attention_mask.float().unsqueeze(-1).expand(hidden_states.shape)
+         mean_pooled_embedding = torch.sum(hidden_states * mask_expanded, 1) / torch.clamp(mask_expanded.sum(1), min=1e-9)
+         unscaled_latent, attn_weights = self.bottleneck(mean_pooled_embedding.unsqueeze(1), hidden_states, hidden_states,
+                                                         need_weights=False,
+                                                         # torch MHA attn_mask has opposite signs to HF T5 masks... sigh
+                                                         attn_mask=attention_mask.float().unsqueeze(1).repeat_interleave(self.num_heads, dim=0))
+         latent = self.bottleneck_scale * F.normalize(unscaled_latent, p=2, dim=2)
+         if encode_only:
+             return latent.squeeze(1)
+         hidden_states = latent.expand(hidden_states.shape)
+
+         if hasattr(self, 'perturb_vector'):
+             hidden_states = self.bottleneck_scale * F.normalize(hidden_states + self.perturb_vector, p=2, dim=2)
+
+         if self.model_parallel:
+             torch.cuda.set_device(self.decoder.first_device)
+
+         if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
+             # get decoder inputs from shifting lm labels to the right
+             decoder_input_ids = self._shift_right(labels)
+
+         # Set device for model parallelism
+         if self.model_parallel:
+             torch.cuda.set_device(self.decoder.first_device)
+             hidden_states = hidden_states.to(self.decoder.first_device)
+             if decoder_input_ids is not None:
+                 decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
+             if attention_mask is not None:
+                 attention_mask = attention_mask.to(self.decoder.first_device)
+             if decoder_attention_mask is not None:
+                 decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
+
+         # Decode
+         decoder_outputs = self.decoder(
+             input_ids=decoder_input_ids,
+             attention_mask=decoder_attention_mask,
+             inputs_embeds=decoder_inputs_embeds,
+             past_key_values=past_key_values,
+             encoder_hidden_states=hidden_states,
+             encoder_attention_mask=attention_mask,
+             head_mask=decoder_head_mask,
+             cross_attn_head_mask=cross_attn_head_mask,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = decoder_outputs[0]
+
+         # Set device for model parallelism
+         if self.model_parallel:
+             torch.cuda.set_device(self.encoder.first_device)
+             self.lm_head = self.lm_head.to(self.encoder.first_device)
+             sequence_output = sequence_output.to(self.lm_head.weight.device)
+
+         if self.config.tie_word_embeddings:
+             # Rescale output before projecting on vocab
+             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+             sequence_output = sequence_output * (self.model_dim**-0.5)
+
+         lm_logits = self.lm_head(sequence_output)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss(ignore_index=-100)
+             loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
+             # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+
+         if not return_dict:
+             output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
+             return ((loss,) + output) if loss is not None else output
+
+         return Seq2SeqLMOutput(
+             loss=loss,
+             logits=lm_logits,
+             past_key_values=decoder_outputs.past_key_values,
+             decoder_hidden_states=decoder_outputs.hidden_states,
+             decoder_attentions=decoder_outputs.attentions,
+             cross_attentions=decoder_outputs.cross_attentions,
+             encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+             encoder_hidden_states=encoder_outputs.hidden_states,
+             encoder_attentions=encoder_outputs.attentions,
+         )
+
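A minimal usage sketch for the class above, assuming a local clone of this repository is on the Python path: the path, sample text, perturbation scale, and generation settings are placeholder assumptions, while `encode_only` and the optional `perturb_vector` attribute follow the `forward` method defined above.

    import torch
    from transformers import T5Tokenizer
    from bottleneck_t5 import BottleneckT5LMWithPerturb

    repo_dir = "."  # assumed: a local clone of this repository
    tokenizer = T5Tokenizer.from_pretrained(repo_dir)
    model = BottleneckT5LMWithPerturb.from_pretrained(repo_dir).eval()

    # Encode text into the unit-norm bottleneck latent (shape: batch x d_model).
    inputs = tokenizer(["An example sentence to compress."], return_tensors="pt")
    with torch.no_grad():
        latent = model(input_ids=inputs.input_ids,
                       attention_mask=inputs.attention_mask,
                       encode_only=True)

    # Optionally steer decoding: forward() adds self.perturb_vector to the latent if set.
    model.perturb_vector = (0.1 * torch.randn_like(latent)).unsqueeze(1)  # assumed scale
    output_ids = model.generate(input_ids=inputs.input_ids,
                                attention_mask=inputs.attention_mask,
                                max_new_tokens=64)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))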
config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "_name_or_path": "./bottleneck-t5",
+   "architectures": [
+     "BottleneckT5LMWithPerturb"
+   ],
+   "auto_map": {
+     "AutoModelForCausalLM": "bottleneck_t5.BottleneckT5LMWithPerturb"
+   },
+   "classifier_dropout": 0.0,
+   "d_ff": 2816,
+   "d_kv": 64,
+   "d_model": 1024,
+   "decoder_start_token_id": 0,
+   "dense_act_fn": "gelu_new",
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "gated-gelu",
+   "initializer_factor": 1.0,
+   "is_encoder_decoder": true,
+   "is_gated_act": true,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "t5",
+   "num_decoder_layers": 24,
+   "num_heads": 16,
+   "num_layers": 24,
+   "output_past": true,
+   "pad_token_id": 0,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "tie_word_embeddings": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.33.3",
+   "use_cache": true,
+   "vocab_size": 32128
+ }
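The "auto_map" entry above registers the custom class with the Auto API (under AutoModelForCausalLM, even though it subclasses T5ForConditionalGeneration), so one plausible way to load the checkpoint is with trust_remote_code; the local path is an assumption.

    from transformers import AutoModelForCausalLM, T5Tokenizer

    # Assumed: "." is a local clone of this repository; trust_remote_code lets
    # transformers import bottleneck_t5.py and resolve the class named in "auto_map".
    model = AutoModelForCausalLM.from_pretrained(".", trust_remote_code=True)
    tokenizer = T5Tokenizer.from_pretrained(".")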
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ff13b4e5cbef285c28eb1f0def157e18e39f556395f704b044b6d55c02547f7c
+ size 3281159797
special_tokens_map.json ADDED
@@ -0,0 +1,107 @@
+ {
+   "additional_special_tokens": [
+     "<extra_id_0>",
+     "<extra_id_1>",
+     "<extra_id_2>",
+     "<extra_id_3>",
+     "<extra_id_4>",
+     "<extra_id_5>",
+     "<extra_id_6>",
+     "<extra_id_7>",
+     "<extra_id_8>",
+     "<extra_id_9>",
+     "<extra_id_10>",
+     "<extra_id_11>",
+     "<extra_id_12>",
+     "<extra_id_13>",
+     "<extra_id_14>",
+     "<extra_id_15>",
+     "<extra_id_16>",
+     "<extra_id_17>",
+     "<extra_id_18>",
+     "<extra_id_19>",
+     "<extra_id_20>",
+     "<extra_id_21>",
+     "<extra_id_22>",
+     "<extra_id_23>",
+     "<extra_id_24>",
+     "<extra_id_25>",
+     "<extra_id_26>",
+     "<extra_id_27>",
+     "<extra_id_28>",
+     "<extra_id_29>",
+     "<extra_id_30>",
+     "<extra_id_31>",
+     "<extra_id_32>",
+     "<extra_id_33>",
+     "<extra_id_34>",
+     "<extra_id_35>",
+     "<extra_id_36>",
+     "<extra_id_37>",
+     "<extra_id_38>",
+     "<extra_id_39>",
+     "<extra_id_40>",
+     "<extra_id_41>",
+     "<extra_id_42>",
+     "<extra_id_43>",
+     "<extra_id_44>",
+     "<extra_id_45>",
+     "<extra_id_46>",
+     "<extra_id_47>",
+     "<extra_id_48>",
+     "<extra_id_49>",
+     "<extra_id_50>",
+     "<extra_id_51>",
+     "<extra_id_52>",
+     "<extra_id_53>",
+     "<extra_id_54>",
+     "<extra_id_55>",
+     "<extra_id_56>",
+     "<extra_id_57>",
+     "<extra_id_58>",
+     "<extra_id_59>",
+     "<extra_id_60>",
+     "<extra_id_61>",
+     "<extra_id_62>",
+     "<extra_id_63>",
+     "<extra_id_64>",
+     "<extra_id_65>",
+     "<extra_id_66>",
+     "<extra_id_67>",
+     "<extra_id_68>",
+     "<extra_id_69>",
+     "<extra_id_70>",
+     "<extra_id_71>",
+     "<extra_id_72>",
+     "<extra_id_73>",
+     "<extra_id_74>",
+     "<extra_id_75>",
+     "<extra_id_76>",
+     "<extra_id_77>",
+     "<extra_id_78>",
+     "<extra_id_79>",
+     "<extra_id_80>",
+     "<extra_id_81>",
+     "<extra_id_82>",
+     "<extra_id_83>",
+     "<extra_id_84>",
+     "<extra_id_85>",
+     "<extra_id_86>",
+     "<extra_id_87>",
+     "<extra_id_88>",
+     "<extra_id_89>",
+     "<extra_id_90>",
+     "<extra_id_91>",
+     "<extra_id_92>",
+     "<extra_id_93>",
+     "<extra_id_94>",
+     "<extra_id_95>",
+     "<extra_id_96>",
+     "<extra_id_97>",
+     "<extra_id_98>",
+     "<extra_id_99>"
+   ],
+   "eos_token": "</s>",
+   "pad_token": "<pad>",
+   "unk_token": "<unk>"
+ }
spiece.model ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d60acb128cf7b7f2536e8f38a5b18a05535c9e14c7a355904270e15b0945ea86
+ size 791656
tokenizer_config.json ADDED
@@ -0,0 +1,113 @@
+ {
+   "additional_special_tokens": [
+     "<extra_id_0>",
+     "<extra_id_1>",
+     "<extra_id_2>",
+     "<extra_id_3>",
+     "<extra_id_4>",
+     "<extra_id_5>",
+     "<extra_id_6>",
+     "<extra_id_7>",
+     "<extra_id_8>",
+     "<extra_id_9>",
+     "<extra_id_10>",
+     "<extra_id_11>",
+     "<extra_id_12>",
+     "<extra_id_13>",
+     "<extra_id_14>",
+     "<extra_id_15>",
+     "<extra_id_16>",
+     "<extra_id_17>",
+     "<extra_id_18>",
+     "<extra_id_19>",
+     "<extra_id_20>",
+     "<extra_id_21>",
+     "<extra_id_22>",
+     "<extra_id_23>",
+     "<extra_id_24>",
+     "<extra_id_25>",
+     "<extra_id_26>",
+     "<extra_id_27>",
+     "<extra_id_28>",
+     "<extra_id_29>",
+     "<extra_id_30>",
+     "<extra_id_31>",
+     "<extra_id_32>",
+     "<extra_id_33>",
+     "<extra_id_34>",
+     "<extra_id_35>",
+     "<extra_id_36>",
+     "<extra_id_37>",
+     "<extra_id_38>",
+     "<extra_id_39>",
+     "<extra_id_40>",
+     "<extra_id_41>",
+     "<extra_id_42>",
+     "<extra_id_43>",
+     "<extra_id_44>",
+     "<extra_id_45>",
+     "<extra_id_46>",
+     "<extra_id_47>",
+     "<extra_id_48>",
+     "<extra_id_49>",
+     "<extra_id_50>",
+     "<extra_id_51>",
+     "<extra_id_52>",
+     "<extra_id_53>",
+     "<extra_id_54>",
+     "<extra_id_55>",
+     "<extra_id_56>",
+     "<extra_id_57>",
+     "<extra_id_58>",
+     "<extra_id_59>",
+     "<extra_id_60>",
+     "<extra_id_61>",
+     "<extra_id_62>",
+     "<extra_id_63>",
+     "<extra_id_64>",
+     "<extra_id_65>",
+     "<extra_id_66>",
+     "<extra_id_67>",
+     "<extra_id_68>",
+     "<extra_id_69>",
+     "<extra_id_70>",
+     "<extra_id_71>",
+     "<extra_id_72>",
+     "<extra_id_73>",
+     "<extra_id_74>",
+     "<extra_id_75>",
+     "<extra_id_76>",
+     "<extra_id_77>",
+     "<extra_id_78>",
+     "<extra_id_79>",
+     "<extra_id_80>",
+     "<extra_id_81>",
+     "<extra_id_82>",
+     "<extra_id_83>",
+     "<extra_id_84>",
+     "<extra_id_85>",
+     "<extra_id_86>",
+     "<extra_id_87>",
+     "<extra_id_88>",
+     "<extra_id_89>",
+     "<extra_id_90>",
+     "<extra_id_91>",
+     "<extra_id_92>",
+     "<extra_id_93>",
+     "<extra_id_94>",
+     "<extra_id_95>",
+     "<extra_id_96>",
+     "<extra_id_97>",
+     "<extra_id_98>",
+     "<extra_id_99>"
+   ],
+   "clean_up_tokenization_spaces": true,
+   "eos_token": "</s>",
+   "extra_ids": 100,
+   "legacy": true,
+   "model_max_length": 512,
+   "pad_token": "<pad>",
+   "sp_model_kwargs": {},
+   "tokenizer_class": "T5Tokenizer",
+   "unk_token": "<unk>"
+ }
transformation_matrix_arxiv.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e6e34e1eb3e82b49ceb8aa6c9cde4825f2a2ce83cf4b4ead27632f77a798bdf2
+ size 8389858
transformation_matrix_msd.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b5b56abfd81f2d0e6110c9f64df34fbb9a41210dedf30e81d98aeeb92c772589
+ size 8389858
transformation_matrix_topicsum.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f9ee8806181b114c70688db269247bbdca4bd2e0d1bcbdb6e698b2c85748f97
+ size 8389858
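The three transformation matrices are added as Git LFS pointers of identical size, one per domain (arxiv, msd, topicsum). A cautious way to inspect one is sketched below; the tensor's shape and how it is meant to be applied to the bottleneck latents are assumptions, as this commit does not document them.

    import torch

    # Assumed: each .pt file holds a tensor that maps the model's 1024-dim bottleneck
    # latents between domains; its exact shape and orientation are not documented here.
    matrix = torch.load("transformation_matrix_arxiv.pt", map_location="cpu")
    print(type(matrix), getattr(matrix, "shape", None))

    # Hypothetical application to a latent obtained via the encode_only path:
    # transformed_latent = latent @ matrix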