tcheda committed on
Commit
81b992e
·
verified ·
1 Parent(s): 6e9b26e

"Test upload 1"

Browse files
Files changed (4) hide show
  1. config.json +37 -0
  2. configuration_mot.py +168 -0
  3. model.safetensors +3 -0
  4. modeling_mot.py +1621 -0
config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "jaszczur/mixture_of_tokens",
3
+ "activation_function": "gelu_new",
4
+ "architectures": [
5
+ "MoTModel"
6
+ ],
7
+ "attn_pdrop": 0.1,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_mot.MoTConfig",
10
+ "AutoModel": "modeling_mot.MoTModel"
11
+ },
12
+ "bos_token_id": 50256,
13
+ "embd_pdrop": 0.1,
14
+ "emit_softmax_over_experts": false,
15
+ "eos_token_id": 50256,
16
+ "expert_size": 256,
17
+ "group_size": 32,
18
+ "init_scale": 1.0,
19
+ "initializer_range": 0.02,
20
+ "layer_norm_epsilon": 1e-05,
21
+ "model_type": "mot",
22
+ "n_embd": 512,
23
+ "n_expert": 256,
24
+ "n_head": 8,
25
+ "n_inner": 65536,
26
+ "n_layer": 8,
27
+ "n_positions": 1024,
28
+ "reorder_and_upcast_attn": false,
29
+ "resid_pdrop": 0.1,
30
+ "scale_attn_by_inverse_layer_idx": false,
31
+ "scale_attn_weights": true,
32
+ "torch_dtype": "float32",
33
+ "transformers_version": "4.42.0",
34
+ "use_cache": true,
35
+ "use_discrete_routing": false,
36
+ "vocab_size": 50257
37
+ }
configuration_mot.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """MixtureOfTokens configuration"""
17
+
18
+ from transformers import PretrainedConfig
19
+ from transformers.utils import logging
20
+
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
class MoTConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of a [`MoTModel`]. It is used to
    instantiate a MixtureOfTokens model according to the specified arguments, defining the model architecture.
    Instantiating a configuration with the defaults will yield a similar configuration to that of the MixtureOfTokens
    [jaszczur/mixture_of_tokens](https://huggingface.co/jaszczur/mixture_of_tokens) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 50257):
            Vocabulary size of the MixtureOfTokens model. Defines the number of different tokens that can be
            represented by the `input_ids` passed when calling [`MoTModel`].
        n_positions (`int`, *optional*, defaults to 1024):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        n_embd (`int`, *optional*, defaults to 768):
            Dimensionality of the embeddings and hidden states.
        n_layer (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        n_head (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        n_inner (`int`, *optional*):
            Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
        n_expert (`int`, *optional*, defaults to 32):
            The number of experts.
        group_size (`int`, *optional*, defaults to 32):
            The number of tokens that are grouped and mixed together before being processed by the experts.
        expert_size (`int`, *optional*):
            The dimensionality of an expert. When `n_expert` is set, `MoTMLP` ignores this value and derives it as
            the (padded) inner dimension divided by `n_expert`; it is only used to derive `n_expert` when
            `n_expert` is `None`.
        init_scale (`float`, *optional*, defaults to 1.0):
            The scaling factor for the initialization of MoTMLP weights. Inactive when creating through
            `from_pretrained`.
        activation_function (`str`, *optional*, defaults to `"gelu_new"`):
            Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
        resid_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        embd_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the embeddings.
        attn_pdrop (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon to use in the layer normalization layers.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        scale_attn_weights (`bool`, *optional*, defaults to `True`):
            Scale attention weights by dividing by sqrt(head_dim), i.e. the square root of the per-head size.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
        bos_token_id (`int`, *optional*, defaults to 50256):
            Id of the beginning of sentence token in the vocabulary.
        eos_token_id (`int`, *optional*, defaults to 50256):
            Id of the end of sentence token in the vocabulary.
        scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
            Whether to additionally scale attention weights by `1 / (layer_idx + 1)`.
        reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
            Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
            dot-product/softmax to float() when training with mixed precision.
        emit_softmax_over_experts (`bool`, *optional*, defaults to `False`):
            Determines the method of redistributing aggregated tokens in the MoT MLP. By default the model uses the
            merge weights. This flag switches it to taking a softmax over the experts.
        use_discrete_routing (`bool`, *optional*, defaults to `False`):
            Discretize the mixing, sending only to the expert with the highest weight. Inference-only.

    Example:

    ```python
    >>> from transformers import MoTConfig, MoTModel

    >>> # Initializing a MoT configuration
    >>> configuration = MoTConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = MoTModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mot"
    keys_to_ignore_at_inference = ["past_key_values"]
    # Lets generic code read the GPT-2 style field names through the standard attribute names.
    attribute_map = {
        "hidden_size": "n_embd",
        "max_position_embeddings": "n_positions",
        "num_attention_heads": "n_head",
        "num_hidden_layers": "n_layer",
    }

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_inner=None,
        n_expert=32,
        group_size=32,
        expert_size=None,
        init_scale=1.0,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        scale_attn_weights=True,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        emit_softmax_over_experts=False,
        use_discrete_routing=False,
        **kwargs,
    ):
        # Store every architecture hyper-parameter verbatim; no validation or derivation happens here.
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_inner = n_inner
        self.n_expert = n_expert
        self.group_size = group_size
        self.expert_size = expert_size
        self.init_scale = init_scale
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.scale_attn_weights = scale_attn_weights
        self.use_cache = use_cache
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn
        self.emit_softmax_over_experts = emit_softmax_over_experts
        self.use_discrete_routing = use_discrete_routing

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

        # The special token ids and any remaining keyword arguments are forwarded to the base class.
        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98e0ab57892628979b24fcb28fed51eec69eecd6f8c9e6e1153a143f0219560e
3
+ size 2290399320
modeling_mot.py ADDED
@@ -0,0 +1,1621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch MixtureOfTokens model."""
17
+
18
+ import math
19
+ import warnings
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ import torch.utils.checkpoint
25
+ from torch import nn
26
+ from torch.cuda.amp import autocast
27
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
28
+ from torch.nn.init import trunc_normal_
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.modeling_outputs import (
32
+ BaseModelOutputWithPastAndCrossAttentions,
33
+ CausalLMOutputWithCrossAttentions,
34
+ QuestionAnsweringModelOutput,
35
+ SequenceClassifierOutputWithPast,
36
+ TokenClassifierOutput,
37
+ )
38
+ from transformers.modeling_utils import PreTrainedModel
39
+ from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
40
+ from transformers.utils import (
41
+ add_code_sample_docstrings,
42
+ add_start_docstrings,
43
+ add_start_docstrings_to_model_forward,
44
+ logging,
45
+ )
46
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
47
+ from .configuration_mot import MoTConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+ _CHECKPOINT_FOR_DOC = "jaszczur/mixture_of_tokens"
53
+ _CONFIG_FOR_DOC = "MoTConfig"
54
+
55
+
56
def with_batch_size_alignment(forward_fn):
    """Let a ``forward(self, x)`` method accept inputs whose grouped axis is not a multiple of ``group_size``.

    The decorated method must live on an object exposing ``sparsity_dim``, ``group_size`` and ``pad(x)``.
    When ``x.size(self.sparsity_dim)`` is not divisible by ``self.group_size`` the input is padded with
    ``self.pad`` before calling ``forward_fn`` and the result is cut back to the original size; otherwise
    ``forward_fn`` is called directly.

    Args:
        forward_fn: The ``forward(self, x)`` method to wrap.

    Returns:
        The wrapping function, carrying the name and metadata of ``forward_fn``.
    """
    import functools  # local import keeps this block independent of the file-level import list

    @functools.wraps(forward_fn)
    def _forward(self, x):
        # Assumed ordering of ``x``: (batch, seq_len, dmodel).
        size = x.size(self.sparsity_dim)
        if size % self.group_size == 0:
            return forward_fn(self, x)

        if self.sparsity_dim == 1:
            # ``self.pad`` is applied to the tensor as given (MoTMLP.pad extends dim 0), so the
            # sequence axis is moved to the front for padding and moved back afterwards.
            padded = self.pad(x.transpose(0, 1)).transpose(0, 1)
            return forward_fn(self, padded)[:, :size, :]

        padded = self.pad(x)
        return forward_fn(self, padded)[:size, :, :]

    return _forward
76
+
77
+
78
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->MoT
79
class MoTAttention(nn.Module):
    """
    Multi-head attention layer with a causal mask for self-attention and an optional cross-attention mode.

    The query/key/value projections are fused into `c_attn` (self-attention) or split into `q_attn` + `c_attn`
    (cross-attention); `c_proj` is the output projection. The implementation is kept in sync with GPT-2's
    attention, so code changes should be made upstream-compatible.
    """

    def __init__(self, config, is_cross_attention=False, layer_idx=None):
        """
        Args:
            config: Model configuration; read for sizes, dropout rates and attention scaling flags.
            is_cross_attention: When `True`, a separate `q_attn` projection is created and `c_attn` only
                produces keys and values.
            layer_idx: Index of the layer, used only when `scale_attn_by_inverse_layer_idx` is enabled.
        """
        super().__init__()
        self.config = config
        max_positions = config.max_position_embeddings
        # Lower-triangular boolean mask of shape (1, 1, max_positions, max_positions); not saved in checkpoints.
        self.register_buffer(
            "bias",
            torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
                1, 1, max_positions, max_positions
            ),
            persistent=False,
        )
        self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)

        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = self.embed_dim
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.num_heads})."
            )

        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention

        # Layer-wise attention scaling, reordering, and upcasting
        self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
        self.layer_idx = layer_idx
        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn

        if self.is_cross_attention:
            # Keys and values come from the encoder states, queries from the decoder states.
            self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
            self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
        else:
            # One fused projection producing query, key and value.
            self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
        self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.is_causal = True

        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from `c_attn` / `c_proj` and update the head bookkeeping."""
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
        # The same indices are kept in each of the query, key and value thirds of the fused projection.
        index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
        self.num_heads = self.num_heads - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        """
        Scaled dot-product attention in the input dtype.

        Returns:
            A tuple `(attn_output, attn_weights)`; the weights are returned after softmax, dropout and the
            optional head mask.
        """
        attn_weights = torch.matmul(query, key.transpose(-1, -2))

        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )

        # Layer-wise attention scaling
        if self.scale_attn_by_inverse_layer_idx:
            attn_weights = attn_weights / float(self.layer_idx + 1)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            # Select the mask rows of the last `query_length` positions so cached keys are handled.
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        """
        Variant of `_attn` that computes the attention scores in float32 with autocast disabled and folds the
        scaling factor into the matrix product.

        Raises:
            RuntimeError: If the softmax output is not float32.
        """
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm`
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            # beta=0 means the preallocated buffer's contents are ignored.
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights

    def _split_heads(self, tensor, num_heads, attn_head_size):
        """
        Splits hidden_size dim into attn_head_size and num_heads
        """
        new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
        tensor = tensor.view(new_shape)
        return tensor.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def _merge_heads(self, tensor, num_heads, attn_head_size):
        """
        Merges attn_head_size dim and num_attn_heads dim into hidden_size
        """
        tensor = tensor.permute(0, 2, 1, 3).contiguous()
        new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
        return tensor.view(new_shape)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
        """
        Run attention over `hidden_states`, optionally attending to `encoder_hidden_states` and to cached
        keys/values.

        Returns:
            `(attn_output, present)` and, when `output_attentions` is set, additionally the attention weights.
            `present` is the `(key, value)` pair when `use_cache` is `True`, otherwise `None`.

        Raises:
            ValueError: If `encoder_hidden_states` is given but the layer was not built for cross-attention.
        """
        if encoder_hidden_states is not None:
            if not hasattr(self, "q_attn"):
                raise ValueError(
                    "If class is used as cross attention, the weights `q_attn` have to be defined. "
                    "Please make sure to instantiate class with `MoTAttention(..., is_cross_attention=True)`."
                )

            query = self.q_attn(hidden_states)
            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
            # In cross-attention the encoder mask replaces the decoder mask.
            attention_mask = encoder_attention_mask
        else:
            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)

        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)

        if layer_past is not None:
            # Prepend the cached keys/values along the sequence axis.
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

        if use_cache is True:
            present = (key, value)
        else:
            present = None

        if self.reorder_and_upcast_attn:
            attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
        else:
            attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)

        return outputs  # a, present, (attentions)
299
+
300
+
301
+ class MoTMLP(nn.Module):
302
+ r"""
303
+ Implementation of the Mixture of Tokens Sparse MLP module.
304
+ """
305
+
306
def __init__(self, inner_dim: int, config: MoTConfig, sparsity_dim: int = 0, init_type: str = "kaiming_uniform"):
    """Build the expert weights and the controller of the Mixture of Tokens MLP.

    Args:
        inner_dim: Total feed-forward width shared by all experts; ``config.n_inner`` is used when ``None``.
        config: Model configuration; read for sizes, expert layout, activation, init scale and routing flags.
        sparsity_dim: Axis of the ``(batch, seq_len, dmodel)`` input along which tokens are grouped
            (0 = across the batch, 1 = within a sequence).
        init_type: Weight initialisation scheme, ``"kaiming_uniform"`` or ``"truncated_normal"``.

    Raises:
        ValueError: If neither ``config.n_expert`` nor ``config.expert_size`` is set, or ``init_type`` is unknown.
    """
    super().__init__()

    self.d_model: int = config.n_embd
    self.d_ff: int = config.n_inner if inner_dim is None else inner_dim
    self.n_expert: int = config.n_expert
    self.group_size: int = config.group_size
    self.sparsity_dim: int = sparsity_dim
    self.expert_size: int = config.expert_size
    # `MoTConfig` does not declare `temperature`, so the attribute only exists when the base config
    # class (or an explicit keyword argument) happens to provide it. Fall back to 1.0, i.e. a plain
    # softmax, instead of failing with AttributeError.
    self.temperature: float = getattr(config, "temperature", 1.0)
    self.act = ACT2FN[config.activation_function]

    self.init_type: str = init_type
    self.init_scale: float = config.init_scale

    self.emit_softmax_over_experts: bool = config.emit_softmax_over_experts
    self.use_discrete_routing: bool = config.use_discrete_routing

    # `n_expert` takes precedence: when it is set, `expert_size` is derived from it and any configured
    # `expert_size` is overwritten. In both branches `d_ff` is rounded up to the next multiple.
    if self.n_expert is not None:
        if self.d_ff % self.n_expert:
            self.d_ff += self.n_expert - (self.d_ff % self.n_expert)
            warnings.warn("d_ff should be divisible by n_expert, padding d_ff to be divisible by n_expert")
        self.expert_size = self.d_ff // self.n_expert
    elif self.expert_size is not None:
        if self.d_ff % self.expert_size:
            self.d_ff += self.expert_size - (self.d_ff % self.expert_size)
            warnings.warn("d_ff should be divisible by expert_size, padding d_ff to be divisible by expert_size")
        self.n_expert = self.d_ff // self.expert_size
    else:
        raise ValueError("Either expert_size or n_expert should be provided")

    # First expert layer: one (d_model x expert_size) matrix per expert.
    self.lin1 = nn.Parameter(
        self.get_init_weight(
            (self.n_expert, self.d_model, self.expert_size),
            fan_in=self.d_model,
            init_type=self.init_type,
            scale=self.init_scale,
        )
    )

    # Second expert layer: one (expert_size x d_model) matrix per expert.
    self.lin2 = nn.Parameter(
        self.get_init_weight(
            (self.n_expert, self.expert_size, self.d_model),
            fan_in=self.expert_size,
            init_type=self.init_type,
            scale=self.init_scale,
        )
    )

    # Controller: maps every token to one logit per expert.
    self.controller = nn.Parameter(
        self.get_init_weight(
            (self.d_model, self.n_expert),
            fan_in=self.d_model,
            init_type=self.init_type,
            scale=self.init_scale,
        )
    )
    self.dropout = nn.Dropout(config.resid_pdrop)
364
+
365
@staticmethod
def argmax_one_hot(x: torch.Tensor, dim: int):
    """Return a 0/1 tensor marking the positions of the maximum of ``x`` along ``dim``.

    Every position equal to the maximum is marked, so ties yield several ones. The result has the
    dtype and device of ``x``.
    """
    peak = x.max(dim=dim, keepdim=True)[0]
    one = torch.Tensor([1.0]).to(dtype=x.dtype, device=x.device)
    zero = torch.Tensor([0.0]).to(dtype=x.dtype, device=x.device)
    # Alternative considered by the original author: keep the value itself instead of 1.0.
    return torch.where(x == peak, one, zero)
373
+
374
def get_init_weight(self, shape, fan_in, init_type, scale, dtype=torch.float32):
    """Create an initial weight tensor with the initialiser named by ``init_type``.

    Args:
        shape: Shape of the tensor to create.
        fan_in: Fan-in forwarded to the initialiser.
        init_type: ``"kaiming_uniform"`` or ``"truncated_normal"``.
        scale: Scaling factor forwarded to the initialiser.
        dtype: dtype of the created tensor.

    Raises:
        ValueError: If ``init_type`` is not one of the supported names.
    """
    if init_type not in ("kaiming_uniform", "truncated_normal"):
        raise ValueError(f"Unknown init_type: {init_type}")

    if init_type == "kaiming_uniform":
        initializer = self.init_kaiming_uniform
    else:
        initializer = self.init_truncated_normal
    return initializer(shape=shape, fan_in=fan_in, scale=scale, dtype=dtype)
381
+
382
@staticmethod
def init_kaiming_uniform(shape, fan_in, scale, dtype=torch.float32):
    """Sample a tensor of ``shape`` uniformly from ``[-r, r]`` with ``r = scale * sqrt(3 / fan_in)``."""
    bound = scale * (3 / fan_in) ** 0.5
    weights = torch.zeros(shape, dtype=dtype)
    weights.uniform_(-bound, bound)
    return weights
386
+
387
@staticmethod
def init_truncated_normal(shape, fan_in, scale, dtype=torch.float32):
    """Sample a tensor of ``shape`` from a truncated normal with std ``sqrt(scale / fan_in)``.

    The truncation bounds are the absolute values ``-2 * scale`` and ``2 * scale``; they are not
    expressed in units of the standard deviation.
    """
    weights = torch.zeros(shape, dtype=dtype)
    return trunc_normal_(
        weights,
        mean=0.0,
        std=(scale / fan_in) ** 0.5,
        a=-2 * scale,
        b=2 * scale,
    )
394
+
395
@staticmethod
def stable_softmax_temperature(x: torch.Tensor, temperature: float, dim: int = -1) -> torch.Tensor:
    """Apply a softmax along ``dim`` to ``x`` divided by ``temperature``."""
    scaled_logits = x / temperature
    return F.softmax(scaled_logits, dim=dim)
398
+
399
def pad(self, x):
    """Zero-pad ``x`` along dim 0 up to the next multiple of ``self.group_size``.

    Args:
        x: Tensor whose first axis should become divisible by ``self.group_size``.

    Returns:
        ``x`` followed by all-zero entries along dim 0, with the dtype and device of ``x``. When the
        size is already a multiple, zero entries are appended (the values are unchanged).
    """
    size = x.size(0)
    # Integer ceiling division: exact for any size and needs no temporary tensor.
    new_batch_size = -(-size // self.group_size) * self.group_size
    padding_size = new_batch_size - size
    logger.debug("Padding batch size from %d to %d", size, new_batch_size)

    # Build the zero padding directly from the trailing shape so inputs of any rank are supported.
    padding_sequences = x.new_zeros((padding_size, *x.shape[1:]))

    return torch.cat([x, padding_sequences], dim=0)
411
+
412
@with_batch_size_alignment
def forward(self, x):
    """Run the Mixture of Tokens MLP on ``x``.

    Tokens are grouped, merge/emit weights are computed from the grouped tokens, the groups are passed
    through ``merge_map_emit``, the grouping is undone and dropout is applied to the result.
    """
    grouped = self.group_tokens(x)
    merge_weights, emit_weights = self.calculate_mixed_tokens_with_weights(grouped)
    updates = self.merge_map_emit(grouped, merge_weights, emit_weights)
    ungrouped = self.redistribute_tokens(updates)
    return self.dropout(ungrouped)
420
+
421
def group_tokens(self, x):
    """
    Reshape the input so the axis to split into groups is on position 1, and then group over said axis.
    e.g.:
    - if we group tokens from different sequences in a batch (sparsity_dim = 0), the batch dimension is
      moved to position 1.
    - if we group tokens within one sequence (sparsity_dim = 1), the dimension to split into groups is
      already on position 1, hence it is left as is.

    free_dimension is the dimension on position 0 after the reshape
    split_dimension is the dimension on position 1 - the one to split into groups

    :param x: normal input tensor of shape (batch, seq_len, dmodel)
    :return: x of shape (free_dimension, split_dimension // group_size, group_size, dmodel)
    :raises NotImplementedError: if sparsity_dim is neither 0 nor 1
    """
    assert len(x.shape) == 3, "incorrect shape of a tensor, expected a 3D tensor"
    assert (
        x.size(-1) == self.d_model
    ), f"expected the last dimension of input tensor to be d_model = {self.d_model}"

    if self.sparsity_dim == 0:
        x = x.transpose(0, 1)
    elif self.sparsity_dim != 1:
        raise NotImplementedError

    # Position 1 holds the axis that is split into groups (not the free axis, which sits on position 0).
    split_dimension = x.size(1)
    assert (
        split_dimension % self.group_size == 0
    ), f"split dimension = {split_dimension} should be divisible by group size = {self.group_size}"

    # Splitting a single axis in two is a pure view, also for the transposed tensor.
    return x.view(x.size(0), -1, self.group_size, self.d_model)
451
+
452
def redistribute_tokens(self, x):
    """
    Undo ``group_tokens``: merge the two group axes back into one and, when grouping was done across
    the batch (sparsity_dim = 0), swap the first two axes back.

    :param x: grouped 4D tensor whose last dimension is d_model
    :return: 3D tensor with the group axes merged
    :raises NotImplementedError: if sparsity_dim is neither 0 nor 1
    """
    assert x.dim() == 4, "incorrect shape of a tensor, expected a 4D tensor"

    ungrouped = x.view(x.size(0), -1, self.d_model)
    if self.sparsity_dim == 1:
        return ungrouped
    if self.sparsity_dim == 0:
        return ungrouped.transpose(0, 1)
    raise NotImplementedError
465
+
466
def calculate_mixed_tokens_with_weights(self, x):
    """
    Compute the merge and emit weights for grouped tokens from the controller matrix.

    Merge weights are a temperature-scaled softmax of the controller logits over the tokens of a group
    (axis -2). Emit weights are the very same tensor unless the temperature is a learnable parameter or
    ``emit_softmax_over_experts`` is set; in the latter case the softmax is taken over the experts
    (axis -1). With ``use_discrete_routing`` both are replaced by one-hot argmax tensors along their
    respective softmax axis.

    :param x: grouped input of shape (free_dimension, split_dimension // group_size, group_size, dmodel)
    :return: tuple (merge_weights, emit_weights), each of shape
        (free_dimension, split_dimension // group_size, group_size, n_expert)
    """
    # One logit per (token, expert) pair.
    logits = x @ self.controller
    temperature = self.temperature

    group_axis = -2
    emit_axis = -1 if self.emit_softmax_over_experts else group_axis

    merge_weights = self.stable_softmax_temperature(logits, temperature, dim=group_axis)

    # Separate emit weights are only needed when they can differ from the merge weights.
    separate_emit = self.emit_softmax_over_experts or isinstance(temperature, torch.nn.Parameter)
    emit_weights = (
        self.stable_softmax_temperature(logits, temperature, dim=emit_axis) if separate_emit else merge_weights
    )

    if not self.use_discrete_routing:
        return merge_weights, emit_weights

    return (
        self.argmax_one_hot(merge_weights, dim=group_axis),
        self.argmax_one_hot(emit_weights, dim=emit_axis),
    )
496
+
497
def merge_map_emit(self, x, merge_weights, emit_weights):
    """
    Mix the tokens of each group per expert, run the expert feed-forward, and emit the result back.

    :param x: input reshaped to (free_dimension, split_dimension // group_size, group_size, dmodel)
    :param merge_weights: weights for merging tokens within a group, shape (free_dimension, split_dimension // group_size, group_size, n_expert)
    :param emit_weights: weights for emitting tokens within a group, shape (free_dimension, split_dimension // group_size, group_size, n_expert)
    :return: tensor of token updates of shape (free_dimension, split_dimension // group_size, group_size, dmodel)
    """
    # Merge: one mixed token per (group, expert) -> (..., n_expert, dmodel)
    mixed = torch.matmul(merge_weights.transpose(-1, -2), x)

    # Put the expert axis first so every expert is one batch entry of the bmm calls below.
    per_expert = mixed.view(-1, self.n_expert, self.d_model).transpose(0, 1)

    # Map: per-expert two-layer feed-forward (lin1 -> activation -> lin2).
    hidden = self.act(torch.bmm(per_expert, self.lin1))
    expert_out = torch.bmm(hidden, self.lin2)

    # Bring the expert axis back next to dmodel so it lines up with the last axis of emit_weights.
    expert_out = expert_out.view(expert_out.size(0), emit_weights.size(0), -1, self.d_model)
    expert_out = expert_out.permute(1, 2, 0, 3)

    # Emit: every token receives a weighted combination of the expert outputs of its group.
    return torch.matmul(emit_weights, expert_out)
524
+
525
+
526
# Registry of attention implementations; MoTBlock looks up the class to instantiate
# by `config._attn_implementation`. Only the eager implementation is registered.
MoT_ATTENTION_CLASSES = {"eager": MoTAttention}
529
+
530
+
531
+ # Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->MoT
532
class MoTBlock(nn.Module):
    """
    A single pre-LayerNorm transformer layer: self-attention, optional cross-attention, and a
    `MoTMLP` feed-forward, each wrapped in a residual connection.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        hidden_size = config.hidden_size
        eps = config.layer_norm_epsilon
        # Fall back to the 4x hidden-size width when the config leaves `n_inner` unset.
        inner_dim = 4 * hidden_size if config.n_inner is None else config.n_inner
        attention_class = MoT_ATTENTION_CLASSES[config._attn_implementation]

        # Submodules are created in a fixed order: ln_1, attn, ln_2, (cross-attention), mlp.
        self.ln_1 = nn.LayerNorm(hidden_size, eps=eps)
        self.attn = attention_class(config=config, layer_idx=layer_idx)
        self.ln_2 = nn.LayerNorm(hidden_size, eps=eps)

        if config.add_cross_attention:
            self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx)
            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=eps)

        self.mlp = MoTMLP(inner_dim, config)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        """
        Run the layer and return `(hidden_states, present, (attentions, cross_attentions))`;
        `present` is dropped from the tuple when `use_cache` is falsy.

        Raises:
            ValueError: if `encoder_hidden_states` is given but the block was built without
                cross-attention layers.
        """
        # --- self-attention sub-layer ---
        attn_outputs = self.attn(
            self.ln_1(hidden_states),
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        # attn_outputs: attention output, present, (attentions)
        hidden_states = attn_outputs[0] + hidden_states
        extras = attn_outputs[1:]

        # --- optional cross-attention sub-layer ---
        if encoder_hidden_states is not None:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            cross_attn_outputs = self.crossattention(
                self.ln_cross_attn(hidden_states),
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            hidden_states = hidden_states + cross_attn_outputs[0]
            # Append the cross-attention weights when attention weights are being returned.
            extras = extras + cross_attn_outputs[2:]

        # --- feed-forward sub-layer ---
        hidden_states = hidden_states + self.mlp(self.ln_2(hidden_states))

        if not use_cache:
            # Without caching, the leading `present` entry is not part of the output.
            extras = extras[1:]

        return (hidden_states,) + extras
609
+
610
+
611
+ # Mostly copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel with GPT2->MoT,gpt2->mot,OpenAI GPT-2->MixtureOfTokens, but removed references to TF and FlashAttention2
612
class MoTPreTrainedModel(PreTrainedModel):
    """
    Abstract base class that takes care of weight initialization and provides the common
    interface for downloading and loading pretrained models.
    """

    config_class = MoTConfig
    base_model_prefix = "transformer"
    is_parallelizable = True
    supports_gradient_checkpointing = True
    _no_split_modules = ["MoTBlock"]
    _skip_keys_device_placement = "past_key_values"

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    # Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._init_weights with GPT2->MoT,gpt2->mot,OpenAI GPT-2->MixtureOfTokens
    def _init_weights(self, module):
        """Initialize the weights of `module` according to its type."""
        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (nn.Linear, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for param_name, param in module.named_parameters():
            if param_name != "c_proj.weight":
                continue
            # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
            scaled_std = self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
            param.data.normal_(mean=0.0, std=scaled_std)
656
+
657
+ MOT_START_DOCSTRING = r"""
658
+
659
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
660
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
661
+ etc.)
662
+
663
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
664
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
665
+ and behavior.
666
+
667
+ Parameters:
668
+ config ([`MoTConfig`]): Model configuration class with all the parameters of the model.
669
+ Initializing with a config file does not load the weights associated with the model, only the
670
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
671
+ """
672
+
673
+ MOT_INPUTS_DOCSTRING = r"""
674
+ Args:
675
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
676
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
677
+ `past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
678
+ sequence tokens in the vocabulary.
679
+
680
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
681
+ `input_ids`.
682
+
683
+ Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
684
+ [`PreTrainedTokenizer.__call__`] for details.
685
+
686
+ [What are input IDs?](../glossary#input-ids)
687
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layer`):
688
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
689
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
690
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
691
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
692
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
693
+
694
+ - 1 for tokens that are **not masked**,
695
+ - 0 for tokens that are **masked**.
696
+
697
+ If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
698
+ `past_key_values`. In other words, the `attention_mask` always has to have the length:
699
+ `len(past_key_values) + len(input_ids)`
700
+
701
+ [What are attention masks?](../glossary#attention-mask)
702
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
703
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
704
+ 1]`:
705
+
706
+ - 0 corresponds to a *sentence A* token,
707
+ - 1 corresponds to a *sentence B* token.
708
+
709
+ [What are token type IDs?](../glossary#token-type-ids)
710
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
711
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
712
+ config.max_position_embeddings - 1]`.
713
+
714
+ [What are position IDs?](../glossary#position-ids)
715
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
716
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
717
+
718
+ - 1 indicates the head is **not masked**,
719
+ - 0 indicates the head is **masked**.
720
+
721
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
722
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
723
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
724
+ model's internal embedding lookup matrix.
725
+
726
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
727
+ `past_key_values`).
728
+ use_cache (`bool`, *optional*):
729
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
730
+ `past_key_values`).
731
+ output_attentions (`bool`, *optional*):
732
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
733
+ tensors for more detail.
734
+ output_hidden_states (`bool`, *optional*):
735
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
736
+ more detail.
737
+ return_dict (`bool`, *optional*):
738
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
739
+ """
740
+ PARALLELIZE_DOCSTRING = r"""
741
+ This is an experimental feature and is subject to change at a moment's notice.
742
+
743
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
744
+ it will evenly distribute blocks across all devices.
745
+
746
+ Args:
747
+ device_map (`Dict[int, list]`, optional, defaults to None):
748
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
749
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
750
+ have fewer attention modules mapped to it than other devices. For reference, the mot models have the
751
+ following number of attention modules:
752
+
753
+ - mot: 12
754
+ - mot-medium: 24
755
+ - mot-large: 36
756
+ - mot-xl: 48
757
+
758
+ Example:
759
+
760
+ ```python
761
+ # Here is an example of a device map on a machine with 4 GPUs using mot-xl, which has a total of 48 attention modules:
762
+ model = MoTLMHeadModel.from_pretrained("jaszczur/mixture_of_tokens")
763
+ device_map = {
764
+ 0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
765
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
766
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
767
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
768
+ }
769
+ model.parallelize(device_map)
770
+ ```
771
+ """
772
+ DEPARALLELIZE_DOCSTRING = r"""
773
+ Moves the model to cpu from a model parallel state.
774
+
775
+ Example:
776
+
777
+ ```python
778
+ # On a 4 GPU machine with mot-large:
779
+ model = MoTLMHeadModel.from_pretrained("jaszczur/mixture_of_tokens")
780
+ device_map = {
781
+ 0: [0, 1, 2, 3, 4, 5, 6, 7],
782
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
783
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
784
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
785
+ }
786
+ model.parallelize(device_map) # Splits the model across several devices
787
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
788
+ ```
789
+ """
790
+
791
+
792
@add_start_docstrings(
    "The bare MOT Model transformer outputting raw hidden-states without any specific head on top.",
    MOT_START_DOCSTRING,
)
class MoTModel(MoTPreTrainedModel):
    # NOTE: no docstrings are added to the class or to the decorated methods below, because the
    # decorators build `__doc__` from the docstring constants.

    def __init__(self, config):
        super().__init__(config)

        embed_dim = config.hidden_size
        self.embed_dim = embed_dim

        # Token and position embedding tables.
        self.wte = nn.Embedding(config.vocab_size, embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([MoTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        warnings.warn(
            "`MoTModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
            " model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
            " ...}",
            FutureWarning,
        )
        # Build a default map over all visible CUDA devices when none is given, then validate it.
        if device_map is None:
            device_map = get_device_map(len(self.h), range(torch.cuda.device_count()))
        self.device_map = device_map
        assert_device_map(self.device_map, len(self.h))
        self.model_parallel = True

        device_ids = self.device_map.keys()
        self.first_device = "cpu" if "cpu" in device_ids else "cuda:" + str(min(device_ids))
        self.last_device = "cuda:" + str(max(device_ids))

        # Embeddings go to the first device, blocks to their mapped device, final norm to the last.
        self.wte = self.wte.to(self.first_device)
        self.wpe = self.wpe.to(self.first_device)
        for device_id, block_indices in self.device_map.items():
            for block_index in block_indices:
                target = "cuda:" + str(device_id)
                self.h[block_index] = self.h[block_index].to(target)
        self.ln_f = self.ln_f.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"

        # Move every submodule back to the CPU, then release cached CUDA memory.
        self.wte = self.wte.to("cpu")
        self.wpe = self.wpe.to("cpu")
        for index, block in enumerate(self.h):
            self.h[index] = block.to("cpu")
        self.ln_f = self.ln_f.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prune attention heads. `heads_to_prune` maps a layer index to the list of heads to prune in that layer.
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    @add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
        # ---- resolve flags that default to the config ----
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if use_cache is None:
            use_cache = self.config.use_cache
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # ---- exactly one of input_ids / inputs_embeds must be provided ----
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])

        # ---- cache bookkeeping and default position ids ----
        if past_key_values is None:
            past_length = 0
            past_key_values = (None,) * len(self.h)
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(
                past_length, input_shape[-1] + past_length, dtype=torch.long, device=device
            ).unsqueeze(0)

        # ---- MoTAttention mask ----
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            # Reshape the 2D mask to [batch_size, 1, 1, to_seq_length] so that it broadcasts to
            # [batch_size, num_heads, from_seq_length, to_seq_length]. Only the broadcast
            # dimensions are prepared here.
            attention_mask = attention_mask.view(batch_size, -1)[:, None, None, :]

            # The incoming mask is 1.0 for positions to attend and 0.0 for masked positions.
            # Convert it to an additive mask: 0.0 where attending and the dtype's smallest value
            # where masked, so that adding it to the raw scores before the softmax effectively
            # removes the masked positions.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min

        # ---- cross-attention mask ----
        # A 2D or 3D mask is made broadcastable to [batch_size, num_heads, seq_length, seq_length].
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    (encoder_batch_size, encoder_sequence_length), device=device
                )
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # ---- head mask ----
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        # ---- embeddings ----
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        hidden_states = inputs_embeds + self.wpe(position_ids)
        if token_type_ids is not None:
            # Token type ids are embedded with the token embedding table.
            hidden_states = hidden_states + self.wte(token_type_ids)
        hidden_states = self.drop(hidden_states)

        output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)

        checkpointing = self.gradient_checkpointing and self.training
        if checkpointing and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        # ---- output accumulators (None when the corresponding output is not requested) ----
        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None

        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            # Model parallel: keep every per-layer input on the device holding hidden_states.
            if self.model_parallel:
                current_device = hidden_states.device
                torch.cuda.set_device(current_device)
                # Ensure layer_past is on same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(current_device) for past_state in layer_past)
                if attention_mask is not None:
                    attention_mask = attention_mask.to(current_device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(current_device)

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional arguments; `None` is passed in the `layer_past` slot.
                outputs = self._gradient_checkpointing_func(
                    block.__call__,
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                    use_cache,
                    output_attentions,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents += (outputs[1],)

            if output_attentions:
                # The attention entries sit one slot later when `present` is part of the outputs.
                self_attn_index = 2 if use_cache else 1
                all_self_attentions += (outputs[self_attn_index],)
                if self.config.add_cross_attention:
                    all_cross_attentions += (outputs[self_attn_index + 1],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for device_id, block_indices in self.device_map.items():
                    if i == block_indices[-1] and "cuda:" + str(device_id) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(device_id + 1))

        hidden_states = self.ln_f(hidden_states)
        hidden_states = hidden_states.view(output_shape)

        # Add last hidden state
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            candidates = [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
            return tuple(value for value in candidates if value is not None)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
1066
+
1067
+
1068
+ @add_start_docstrings(
1069
+ """
1070
+ The MOT Model transformer with a language modeling head on top (linear layer with weights tied to the input
1071
+ embeddings).
1072
+ """,
1073
+ MOT_START_DOCSTRING,
1074
+ )
1075
+ class MoTLMHeadModel(MoTPreTrainedModel):
1076
+ _tied_weights_keys = ["lm_head.weight"]
1077
+
1078
def __init__(self, config):
    """Build the MoT backbone and a bias-free linear language-modeling head on top of it."""
    super().__init__(config)

    self.transformer = MoTModel(config)
    # Projects hidden states of width `n_embd` onto the vocabulary.
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

    # Model-parallel state; both values are set by `parallelize`.
    self.model_parallel = False
    self.device_map = None

    # Initialize weights and apply final processing
    self.post_init()
1089
+
1090
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
    warnings.warn(
        "`MoTLMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
        " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
        " `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
        " 0, 'transformer.h.1': 1, ...}",
        FutureWarning,
    )
    # Without an explicit map, build one over all visible CUDA devices.
    if device_map is None:
        device_map = get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
    self.device_map = device_map
    assert_device_map(self.device_map, len(self.transformer.h))

    # Distribute the backbone, then place the LM head on the backbone's first device.
    self.transformer.parallelize(self.device_map)
    self.lm_head = self.lm_head.to(self.transformer.first_device)
    self.model_parallel = True
1108
+
1109
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
def deparallelize(self):
    warnings.warn(
        "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
        FutureWarning,
    )
    # Undo the backbone's device placement first, then move backbone and head to the CPU.
    self.transformer.deparallelize()
    cpu = "cpu"
    self.transformer = self.transformer.to(cpu)
    self.lm_head = self.lm_head.to(cpu)
    self.model_parallel = False
    # Release cached CUDA memory.
    torch.cuda.empty_cache()
1120
+
1121
def get_output_embeddings(self):
    """Return the language-modeling head module."""
    output_embeddings = self.lm_head
    return output_embeddings
1123
+
1124
def set_output_embeddings(self, new_embeddings):
    """Replace the language-modeling head module with `new_embeddings`."""
    setattr(self, "lm_head", new_embeddings)
1126
+
1127
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
    """
    Assemble the keyword arguments for one generation step.

    When a cache is present, `input_ids` (and `token_type_ids`) are trimmed to the tokens that
    are not covered by it. Position ids are derived from `attention_mask` when it is given and no
    `position_ids` were passed; in every other case `position_ids` is returned as `None`.

    Returns:
        dict with either `inputs_embeds` or `input_ids`, plus `past_key_values`, `use_cache`,
        `position_ids`, `attention_mask` and `token_type_ids`.
    """
    token_type_ids = kwargs.get("token_type_ids", None)
    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    # Omit tokens covered by past_key_values
    if past_key_values:
        past_length = past_key_values[0][0].shape[2]
        current_length = input_ids.shape[1]

        # Some generation methods already pass only the last input ID; in that case fall back to
        # the old behavior and keep only the final ID.
        prefix_to_drop = past_length if current_length > past_length else current_length - 1

        input_ids = input_ids[:, prefix_to_drop:]
        if token_type_ids is not None:
            token_type_ids = token_type_ids[:, -input_ids.shape[1] :]

    if attention_mask is None or position_ids is not None:
        # NOTE(review): explicitly passed `position_ids` are discarded on this path.
        position_ids = None
    else:
        # create position_ids on the fly for batch generation
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    use_embeds = inputs_embeds is not None and past_key_values is None
    model_inputs = {"inputs_embeds": inputs_embeds} if use_embeds else {"input_ids": input_ids}

    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["position_ids"] = position_ids
    model_inputs["attention_mask"] = attention_mask
    model_inputs["token_type_ids"] = token_type_ids

    return model_inputs
1173
+
1174
@add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
    checkpoint=_CHECKPOINT_FOR_DOC,
    output_type=CausalLMOutputWithCrossAttentions,
    config_class=_CONFIG_FOR_DOC,
)
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for language modeling. The labels **are shifted** inside the model, so `labels = input_ids` can be
        passed directly. The logits at every position but the last are scored against the label of the following
        position, using a default `CrossEntropyLoss`.
    """
    if return_dict is None:
        return_dict = self.config.use_return_dict

    transformer_outputs = self.transformer(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states = transformer_outputs[0]

    # With model parallelism enabled, bring the hidden states onto the device holding the LM head.
    if self.model_parallel:
        torch.cuda.set_device(self.transformer.first_device)
        hidden_states = hidden_states.to(self.lm_head.weight.device)

    lm_logits = self.lm_head(hidden_states)

    loss = None
    if labels is not None:
        # Labels must live on the same device as the logits they are compared against.
        labels = labels.to(lm_logits.device)
        # Position t predicts token t + 1: drop the last logit and the first label.
        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)
        loss = CrossEntropyLoss()(flat_logits, flat_labels)

    if return_dict:
        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

    output = (lm_logits,) + transformer_outputs[1:]
    return output if loss is None else (loss,) + output
1252
+
1253
@staticmethod
def _reorder_cache(
    past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
    """
    Re-order the `past_key_values` cache along dim 0 according to `beam_idx`.

    Needed when [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called, so that each cached
    state stays aligned with the beam selected at every generation step. Returns a new tuple of per-layer tuples.
    """
    reordered = []
    for layer_past in past_key_values:
        # The index tensor is moved to each state's device before selecting.
        layer_states = tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past)
        reordered.append(layer_states)
    return tuple(reordered)
1266
+
1267
+
1268
@add_start_docstrings(
    """
    The MOT Model transformer with a sequence classification head on top (linear layer).

    [`MoTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-1) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    MOT_START_DOCSTRING,
)
class MoTForSequenceClassification(MoTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = MoTModel(config)
        # Bias-free projection from the hidden size to one score per label.
        self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=SequenceClassifierOutputWithPast,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        # Raised explicitly (not asserted) so the check survives `python -O`.
        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # Index of the token before the first pad token in each row.
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1
                logger.warning(
                    f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
                    "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
                )

        # Keep one row of scores per sequence: the one at the selected token position.
        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            # Move labels next to the logits, as the other heads in this file do.
            labels = labels.to(pooled_logits.device)

            # Infer the problem type once from the label count and label dtype, and cache it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
1409
+
1410
+
1411
@add_start_docstrings(
    """
    MOT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
    Named-Entity-Recognition (NER) tasks.
    """,
    MOT_START_DOCSTRING,
)
class MoTForTokenClassification(MoTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = MoTModel(config)

        # Dropout rate: `classifier_dropout` if set, else `hidden_dropout` if set, else 0.1.
        dropout_rate = getattr(config, "classifier_dropout", None)
        if dropout_rate is None:
            dropout_rate = getattr(config, "hidden_dropout", None)
        if dropout_rate is None:
            dropout_rate = 0.1
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
    # fmt: off
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=TokenClassifierOutput,
        config_class=_CONFIG_FOR_DOC,
        expected_loss=0.25,
        expected_output=[
            "Lead",
            "Lead",
            "Lead",
            "Position",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
            "Lead",
        ],
    )
    # fmt: on
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss (Cross-Entropy over every token position). Indices
            should be in `[0, ..., config.num_labels - 1]`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Per-token scores: dropout over the backbone output, then the linear classifier.
        logits = self.classifier(self.dropout(transformer_outputs[0]))

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Flatten all token positions into one batch of classification targets.
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            loss = CrossEntropyLoss()(flat_logits, flat_labels)

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=transformer_outputs.hidden_states,
                attentions=transformer_outputs.attentions,
            )

        # Tuple output skips index 1 of the backbone outputs.
        output = (logits,) + transformer_outputs[2:]
        return output if loss is None else (loss,) + output
1520
+
1521
+
1522
@add_start_docstrings(
    """
    The MixtureOfTokens transformer with a span classification head on top for extractive question-answering tasks like
    SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
    """,
    MOT_START_DOCSTRING,
)
class MoTForQuestionAnswering(MoTPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.transformer = MoTModel(config)
        # Two outputs per token: one start logit and one end logit.
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    @add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=QuestionAnsweringModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        start_positions: Optional[torch.LongTensor] = None,
        end_positions: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Split the two-channel projection into separate start and end logits per token.
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Positions may arrive with an extra trailing dimension (e.g. on multi-GPU); drop it.
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # Always move the targets next to the logits, whether or not a squeeze was needed.
            start_positions = start_positions.to(start_logits.device)
            end_positions = end_positions.to(end_logits.device)

            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            # Out-of-range positions were clamped to `ignored_index`, which the loss skips.
            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )