"Test upload 1"
Browse files- config.json +37 -0
- configuration_mot.py +168 -0
- model.safetensors +3 -0
- modeling_mot.py +1621 -0
config.json
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "jaszczur/mixture_of_tokens",
|
3 |
+
"activation_function": "gelu_new",
|
4 |
+
"architectures": [
|
5 |
+
"MoTModel"
|
6 |
+
],
|
7 |
+
"attn_pdrop": 0.1,
|
8 |
+
"auto_map": {
|
9 |
+
"AutoConfig": "configuration_mot.MoTConfig",
|
10 |
+
"AutoModel": "modeling_mot.MoTModel"
|
11 |
+
},
|
12 |
+
"bos_token_id": 50256,
|
13 |
+
"embd_pdrop": 0.1,
|
14 |
+
"emit_softmax_over_experts": false,
|
15 |
+
"eos_token_id": 50256,
|
16 |
+
"expert_size": 256,
|
17 |
+
"group_size": 32,
|
18 |
+
"init_scale": 1.0,
|
19 |
+
"initializer_range": 0.02,
|
20 |
+
"layer_norm_epsilon": 1e-05,
|
21 |
+
"model_type": "mot",
|
22 |
+
"n_embd": 512,
|
23 |
+
"n_expert": 256,
|
24 |
+
"n_head": 8,
|
25 |
+
"n_inner": 65536,
|
26 |
+
"n_layer": 8,
|
27 |
+
"n_positions": 1024,
|
28 |
+
"reorder_and_upcast_attn": false,
|
29 |
+
"resid_pdrop": 0.1,
|
30 |
+
"scale_attn_by_inverse_layer_idx": false,
|
31 |
+
"scale_attn_weights": true,
|
32 |
+
"torch_dtype": "float32",
|
33 |
+
"transformers_version": "4.42.0",
|
34 |
+
"use_cache": true,
|
35 |
+
"use_discrete_routing": false,
|
36 |
+
"vocab_size": 50257
|
37 |
+
}
|
configuration_mot.py
ADDED
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The OpenAI Team Authors and HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""MixtureOfTokens configuration"""
|
17 |
+
|
18 |
+
from transformers import PretrainedConfig
|
19 |
+
from transformers.utils import logging
|
20 |
+
|
21 |
+
|
22 |
+
logger = logging.get_logger(__name__)
|
23 |
+
|
24 |
+
|
25 |
+
class MoTConfig(PretrainedConfig):
|
26 |
+
"""
|
27 |
+
This is the configuration class to store the configuration of a [`MoTModel`]. It is used to
|
28 |
+
instantiate a MixtureOfTokens model according to the specified arguments, defining the model architecture. Instantiating a
|
29 |
+
configuration with the defaults will yield a similar configuration to that of the MixtureOfTokens
|
30 |
+
[mot](https://huggingface.co/mot) architecture.
|
31 |
+
|
32 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
33 |
+
documentation from [`PretrainedConfig`] for more information.
|
34 |
+
|
35 |
+
|
36 |
+
Args:
|
37 |
+
vocab_size (`int`, *optional*, defaults to 50257):
|
38 |
+
Vocabulary size of the MixtureOfTokens model. Defines the number of different tokens that can be represented by the
|
39 |
+
`inputs_ids` passed when calling [`MoTModel`].
|
40 |
+
n_positions (`int`, *optional*, defaults to 1024):
|
41 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
42 |
+
just in case (e.g., 512 or 1024 or 2048).
|
43 |
+
n_embd (`int`, *optional*, defaults to 768):
|
44 |
+
Dimensionality of the embeddings and hidden states.
|
45 |
+
n_layer (`int`, *optional*, defaults to 12):
|
46 |
+
Number of hidden layers in the Transformer encoder.
|
47 |
+
n_head (`int`, *optional*, defaults to 12):
|
48 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
49 |
+
n_inner (`int`, *optional*):
|
50 |
+
Dimensionality of the inner feed-forward layers. `None` will set it to 4 times n_embd
|
51 |
+
n_expert (`int`, *optional*, defaults to 32):
|
52 |
+
The number of experts.
|
53 |
+
group_size (`int`, *optional*, defaults to 32):
|
54 |
+
The number of tokens per expert.
|
55 |
+
expert_size (`int`, *optional*):
|
56 |
+
The dimensionality of an expert. `None` will set it to n_inner / n_head.
|
57 |
+
init_scale (`float`, *optional*, defaults to 1.0):
|
58 |
+
The scaling factor for the initialization of MoTMLP weights. Inactive when creating through `from_pretrained`.
|
59 |
+
activation_function (`str`, *optional*, defaults to `"gelu_new"`):
|
60 |
+
Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
|
61 |
+
resid_pdrop (`float`, *optional*, defaults to 0.1):
|
62 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
63 |
+
embd_pdrop (`float`, *optional*, defaults to 0.1):
|
64 |
+
The dropout ratio for the embeddings.
|
65 |
+
attn_pdrop (`float`, *optional*, defaults to 0.1):
|
66 |
+
The dropout ratio for the attention.
|
67 |
+
layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
|
68 |
+
The epsilon to use in the layer normalization layers.
|
69 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
70 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
71 |
+
scale_attn_weights (`bool`, *optional*, defaults to `True`):
|
72 |
+
Scale attention weights by dividing by sqrt(hidden_size)..
|
73 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
74 |
+
Whether or not the model should return the last key/values attentions (not used by all models).
|
75 |
+
bos_token_id (`int`, *optional*, defaults to 50256):
|
76 |
+
Id of the beginning of sentence token in the vocabulary.
|
77 |
+
eos_token_id (`int`, *optional*, defaults to 50256):
|
78 |
+
Id of the end of sentence token in the vocabulary.
|
79 |
+
scale_attn_by_inverse_layer_idx (`bool`, *optional*, defaults to `False`):
|
80 |
+
Whether to additionally scale attention weights by `1 / layer_idx + 1`.
|
81 |
+
reorder_and_upcast_attn (`bool`, *optional*, defaults to `False`):
|
82 |
+
Whether to scale keys (K) prior to computing attention (dot-product) and upcast attention
|
83 |
+
dot-product/softmax to float() when training with mixed precision.
|
84 |
+
emit_softmax_over_experts (`bool`, *optional*, defaults to `False`):
|
85 |
+
Determines the method of redistributing aggregated tokens in the MoT MLP. By default the model uses the merge weights.
|
86 |
+
This flag switches it to taking a softmax over the experts.
|
87 |
+
use_discrete_routing (`bool`, *optional*, defaults to `False`):
|
88 |
+
Discretize the mixing, sending only to the expert with the highest weight. Inference-only.
|
89 |
+
|
90 |
+
Example:
|
91 |
+
|
92 |
+
```python
|
93 |
+
>>> from transformers import MoTConfig, MoTModel
|
94 |
+
|
95 |
+
>>> # Initializing a MoT configuration
|
96 |
+
>>> configuration = MoTConfig()
|
97 |
+
|
98 |
+
>>> # Initializing a model (with random weights) from the configuration
|
99 |
+
>>> model = MoTModel(configuration)
|
100 |
+
|
101 |
+
>>> # Accessing the model configuration
|
102 |
+
>>> configuration = model.config
|
103 |
+
```"""
|
104 |
+
|
105 |
+
model_type = "mot"
|
106 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
107 |
+
attribute_map = {
|
108 |
+
"hidden_size": "n_embd",
|
109 |
+
"max_position_embeddings": "n_positions",
|
110 |
+
"num_attention_heads": "n_head",
|
111 |
+
"num_hidden_layers": "n_layer",
|
112 |
+
}
|
113 |
+
|
114 |
+
def __init__(
|
115 |
+
self,
|
116 |
+
vocab_size=50257,
|
117 |
+
n_positions=1024,
|
118 |
+
n_embd=768,
|
119 |
+
n_layer=12,
|
120 |
+
n_head=12,
|
121 |
+
n_inner=None,
|
122 |
+
n_expert=32,
|
123 |
+
group_size=32,
|
124 |
+
expert_size=None,
|
125 |
+
init_scale=1.0,
|
126 |
+
activation_function="gelu_new",
|
127 |
+
resid_pdrop=0.1,
|
128 |
+
embd_pdrop=0.1,
|
129 |
+
attn_pdrop=0.1,
|
130 |
+
layer_norm_epsilon=1e-5,
|
131 |
+
initializer_range=0.02,
|
132 |
+
scale_attn_weights=True,
|
133 |
+
use_cache=True,
|
134 |
+
bos_token_id=50256,
|
135 |
+
eos_token_id=50256,
|
136 |
+
scale_attn_by_inverse_layer_idx=False,
|
137 |
+
reorder_and_upcast_attn=False,
|
138 |
+
emit_softmax_over_experts=False,
|
139 |
+
use_discrete_routing=False,
|
140 |
+
**kwargs,
|
141 |
+
):
|
142 |
+
self.vocab_size = vocab_size
|
143 |
+
self.n_positions = n_positions
|
144 |
+
self.n_embd = n_embd
|
145 |
+
self.n_layer = n_layer
|
146 |
+
self.n_head = n_head
|
147 |
+
self.n_inner = n_inner
|
148 |
+
self.n_expert = n_expert
|
149 |
+
self.group_size = group_size
|
150 |
+
self.expert_size = expert_size
|
151 |
+
self.init_scale = init_scale
|
152 |
+
self.activation_function = activation_function
|
153 |
+
self.resid_pdrop = resid_pdrop
|
154 |
+
self.embd_pdrop = embd_pdrop
|
155 |
+
self.attn_pdrop = attn_pdrop
|
156 |
+
self.layer_norm_epsilon = layer_norm_epsilon
|
157 |
+
self.initializer_range = initializer_range
|
158 |
+
self.scale_attn_weights = scale_attn_weights
|
159 |
+
self.use_cache = use_cache
|
160 |
+
self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
|
161 |
+
self.reorder_and_upcast_attn = reorder_and_upcast_attn
|
162 |
+
self.emit_softmax_over_experts = emit_softmax_over_experts
|
163 |
+
self.use_discrete_routing = use_discrete_routing
|
164 |
+
|
165 |
+
self.bos_token_id = bos_token_id
|
166 |
+
self.eos_token_id = eos_token_id
|
167 |
+
|
168 |
+
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:98e0ab57892628979b24fcb28fed51eec69eecd6f8c9e6e1153a143f0219560e
|
3 |
+
size 2290399320
|
modeling_mot.py
ADDED
@@ -0,0 +1,1621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2024 The OpenAI Team Authors and HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch MixtureOfTokens model."""
|
17 |
+
|
18 |
+
import math
|
19 |
+
import warnings
|
20 |
+
from typing import Optional, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import torch.nn.functional as F
|
24 |
+
import torch.utils.checkpoint
|
25 |
+
from torch import nn
|
26 |
+
from torch.cuda.amp import autocast
|
27 |
+
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
28 |
+
from torch.nn.init import trunc_normal_
|
29 |
+
|
30 |
+
from transformers.activations import ACT2FN
|
31 |
+
from transformers.modeling_outputs import (
|
32 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
33 |
+
CausalLMOutputWithCrossAttentions,
|
34 |
+
QuestionAnsweringModelOutput,
|
35 |
+
SequenceClassifierOutputWithPast,
|
36 |
+
TokenClassifierOutput,
|
37 |
+
)
|
38 |
+
from transformers.modeling_utils import PreTrainedModel
|
39 |
+
from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
40 |
+
from transformers.utils import (
|
41 |
+
add_code_sample_docstrings,
|
42 |
+
add_start_docstrings,
|
43 |
+
add_start_docstrings_to_model_forward,
|
44 |
+
logging,
|
45 |
+
)
|
46 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
47 |
+
from .configuration_mot import MoTConfig
|
48 |
+
|
49 |
+
|
50 |
+
logger = logging.get_logger(__name__)
|
51 |
+
|
52 |
+
_CHECKPOINT_FOR_DOC = "jaszczur/mixture_of_tokens"
|
53 |
+
_CONFIG_FOR_DOC = "MoTConfig"
|
54 |
+
|
55 |
+
|
56 |
+
def with_batch_size_alignment(forward_fn):
|
57 |
+
def _forward(self, x):
|
58 |
+
"""assumed ordering (batch, seq_len, dmodel)"""
|
59 |
+
size = x.size(self.sparsity_dim)
|
60 |
+
if size % self.group_size != 0:
|
61 |
+
if self.sparsity_dim == 1:
|
62 |
+
x = x.transpose(0, 1)
|
63 |
+
|
64 |
+
x = self.pad(x)
|
65 |
+
|
66 |
+
if self.sparsity_dim == 1:
|
67 |
+
x = forward_fn(self, x.transpose(0, 1))
|
68 |
+
return x[:, :size, :]
|
69 |
+
else:
|
70 |
+
x = forward_fn(self, x)
|
71 |
+
return x[:size, :, :]
|
72 |
+
else:
|
73 |
+
return forward_fn(self, x)
|
74 |
+
|
75 |
+
return _forward
|
76 |
+
|
77 |
+
|
78 |
+
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->MoT
|
79 |
+
class MoTAttention(nn.Module):
|
80 |
+
def __init__(self, config, is_cross_attention=False, layer_idx=None):
|
81 |
+
super().__init__()
|
82 |
+
self.config = config
|
83 |
+
max_positions = config.max_position_embeddings
|
84 |
+
self.register_buffer(
|
85 |
+
"bias",
|
86 |
+
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
|
87 |
+
1, 1, max_positions, max_positions
|
88 |
+
),
|
89 |
+
persistent=False,
|
90 |
+
)
|
91 |
+
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
|
92 |
+
|
93 |
+
self.embed_dim = config.hidden_size
|
94 |
+
self.num_heads = config.num_attention_heads
|
95 |
+
self.head_dim = self.embed_dim // self.num_heads
|
96 |
+
self.split_size = self.embed_dim
|
97 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
98 |
+
raise ValueError(
|
99 |
+
f"`embed_dim` must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
100 |
+
f" {self.num_heads})."
|
101 |
+
)
|
102 |
+
|
103 |
+
self.scale_attn_weights = config.scale_attn_weights
|
104 |
+
self.is_cross_attention = is_cross_attention
|
105 |
+
|
106 |
+
# Layer-wise attention scaling, reordering, and upcasting
|
107 |
+
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
|
108 |
+
self.layer_idx = layer_idx
|
109 |
+
self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
|
110 |
+
|
111 |
+
if self.is_cross_attention:
|
112 |
+
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
|
113 |
+
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
|
114 |
+
else:
|
115 |
+
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
|
116 |
+
self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
|
117 |
+
|
118 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
119 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
120 |
+
self.is_causal = True
|
121 |
+
|
122 |
+
self.pruned_heads = set()
|
123 |
+
|
124 |
+
def prune_heads(self, heads):
|
125 |
+
if len(heads) == 0:
|
126 |
+
return
|
127 |
+
heads, index = find_pruneable_heads_and_indices(heads, self.num_heads, self.head_dim, self.pruned_heads)
|
128 |
+
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
129 |
+
|
130 |
+
# Prune conv1d layers
|
131 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
132 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
133 |
+
|
134 |
+
# Update hyper params
|
135 |
+
self.split_size = (self.split_size // self.num_heads) * (self.num_heads - len(heads))
|
136 |
+
self.num_heads = self.num_heads - len(heads)
|
137 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
138 |
+
|
139 |
+
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
|
140 |
+
attn_weights = torch.matmul(query, key.transpose(-1, -2))
|
141 |
+
|
142 |
+
if self.scale_attn_weights:
|
143 |
+
attn_weights = attn_weights / torch.full(
|
144 |
+
[], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
|
145 |
+
)
|
146 |
+
|
147 |
+
# Layer-wise attention scaling
|
148 |
+
if self.scale_attn_by_inverse_layer_idx:
|
149 |
+
attn_weights = attn_weights / float(self.layer_idx + 1)
|
150 |
+
|
151 |
+
if not self.is_cross_attention:
|
152 |
+
# if only "normal" attention layer implements causal mask
|
153 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
154 |
+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
155 |
+
mask_value = torch.finfo(attn_weights.dtype).min
|
156 |
+
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
157 |
+
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
158 |
+
mask_value = torch.full([], mask_value, dtype=attn_weights.dtype, device=attn_weights.device)
|
159 |
+
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
|
160 |
+
|
161 |
+
if attention_mask is not None:
|
162 |
+
# Apply the attention mask
|
163 |
+
attn_weights = attn_weights + attention_mask
|
164 |
+
|
165 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
166 |
+
|
167 |
+
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
|
168 |
+
attn_weights = attn_weights.type(value.dtype)
|
169 |
+
attn_weights = self.attn_dropout(attn_weights)
|
170 |
+
|
171 |
+
# Mask heads if we want to
|
172 |
+
if head_mask is not None:
|
173 |
+
attn_weights = attn_weights * head_mask
|
174 |
+
|
175 |
+
attn_output = torch.matmul(attn_weights, value)
|
176 |
+
|
177 |
+
return attn_output, attn_weights
|
178 |
+
|
179 |
+
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
|
180 |
+
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
|
181 |
+
bsz, num_heads, q_seq_len, dk = query.size()
|
182 |
+
_, _, k_seq_len, _ = key.size()
|
183 |
+
|
184 |
+
# Preallocate attn_weights for `baddbmm`
|
185 |
+
attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)
|
186 |
+
|
187 |
+
# Compute Scale Factor
|
188 |
+
scale_factor = 1.0
|
189 |
+
if self.scale_attn_weights:
|
190 |
+
scale_factor /= float(value.size(-1)) ** 0.5
|
191 |
+
|
192 |
+
if self.scale_attn_by_inverse_layer_idx:
|
193 |
+
scale_factor /= float(self.layer_idx + 1)
|
194 |
+
|
195 |
+
# Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
|
196 |
+
with autocast(enabled=False):
|
197 |
+
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
|
198 |
+
attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
|
199 |
+
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
|
200 |
+
|
201 |
+
if not self.is_cross_attention:
|
202 |
+
# if only "normal" attention layer implements causal mask
|
203 |
+
query_length, key_length = query.size(-2), key.size(-2)
|
204 |
+
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
|
205 |
+
mask_value = torch.finfo(attn_weights.dtype).min
|
206 |
+
# Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
|
207 |
+
# Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
|
208 |
+
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
|
209 |
+
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
|
210 |
+
|
211 |
+
if attention_mask is not None:
|
212 |
+
# Apply the attention mask
|
213 |
+
attn_weights = attn_weights + attention_mask
|
214 |
+
|
215 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
216 |
+
|
217 |
+
# Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
|
218 |
+
if attn_weights.dtype != torch.float32:
|
219 |
+
raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
|
220 |
+
attn_weights = attn_weights.type(value.dtype)
|
221 |
+
attn_weights = self.attn_dropout(attn_weights)
|
222 |
+
|
223 |
+
# Mask heads if we want to
|
224 |
+
if head_mask is not None:
|
225 |
+
attn_weights = attn_weights * head_mask
|
226 |
+
|
227 |
+
attn_output = torch.matmul(attn_weights, value)
|
228 |
+
|
229 |
+
return attn_output, attn_weights
|
230 |
+
|
231 |
+
def _split_heads(self, tensor, num_heads, attn_head_size):
|
232 |
+
"""
|
233 |
+
Splits hidden_size dim into attn_head_size and num_heads
|
234 |
+
"""
|
235 |
+
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
|
236 |
+
tensor = tensor.view(new_shape)
|
237 |
+
return tensor.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
238 |
+
|
239 |
+
def _merge_heads(self, tensor, num_heads, attn_head_size):
|
240 |
+
"""
|
241 |
+
Merges attn_head_size dim and num_attn_heads dim into hidden_size
|
242 |
+
"""
|
243 |
+
tensor = tensor.permute(0, 2, 1, 3).contiguous()
|
244 |
+
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
|
245 |
+
return tensor.view(new_shape)
|
246 |
+
|
247 |
+
def forward(
|
248 |
+
self,
|
249 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
250 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
251 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
252 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
253 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
254 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
255 |
+
use_cache: Optional[bool] = False,
|
256 |
+
output_attentions: Optional[bool] = False,
|
257 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
258 |
+
if encoder_hidden_states is not None:
|
259 |
+
if not hasattr(self, "q_attn"):
|
260 |
+
raise ValueError(
|
261 |
+
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
262 |
+
"Please make sure to instantiate class with `MoTAttention(..., is_cross_attention=True)`."
|
263 |
+
)
|
264 |
+
|
265 |
+
query = self.q_attn(hidden_states)
|
266 |
+
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
267 |
+
attention_mask = encoder_attention_mask
|
268 |
+
else:
|
269 |
+
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
270 |
+
|
271 |
+
query = self._split_heads(query, self.num_heads, self.head_dim)
|
272 |
+
key = self._split_heads(key, self.num_heads, self.head_dim)
|
273 |
+
value = self._split_heads(value, self.num_heads, self.head_dim)
|
274 |
+
|
275 |
+
if layer_past is not None:
|
276 |
+
past_key, past_value = layer_past
|
277 |
+
key = torch.cat((past_key, key), dim=-2)
|
278 |
+
value = torch.cat((past_value, value), dim=-2)
|
279 |
+
|
280 |
+
if use_cache is True:
|
281 |
+
present = (key, value)
|
282 |
+
else:
|
283 |
+
present = None
|
284 |
+
|
285 |
+
if self.reorder_and_upcast_attn:
|
286 |
+
attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask)
|
287 |
+
else:
|
288 |
+
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
|
289 |
+
|
290 |
+
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
|
291 |
+
attn_output = self.c_proj(attn_output)
|
292 |
+
attn_output = self.resid_dropout(attn_output)
|
293 |
+
|
294 |
+
outputs = (attn_output, present)
|
295 |
+
if output_attentions:
|
296 |
+
outputs += (attn_weights,)
|
297 |
+
|
298 |
+
return outputs # a, present, (attentions)
|
299 |
+
|
300 |
+
|
301 |
+
class MoTMLP(nn.Module):
|
302 |
+
r"""
|
303 |
+
Implementation of the Mixture of Tokens Sparse MLP module.
|
304 |
+
"""
|
305 |
+
|
306 |
+
def __init__(self, inner_dim: int, config: MoTConfig, sparsity_dim: int = 0, init_type: str = "kaiming_uniform"):
|
307 |
+
super().__init__()
|
308 |
+
|
309 |
+
self.d_model: int = config.n_embd
|
310 |
+
self.d_ff: int = config.n_inner if inner_dim is None else inner_dim
|
311 |
+
self.n_expert: int = config.n_expert
|
312 |
+
self.group_size: int = config.group_size
|
313 |
+
self.sparsity_dim: int = sparsity_dim
|
314 |
+
self.expert_size: int = config.expert_size
|
315 |
+
self.temperature: float = config.temperature
|
316 |
+
self.act = ACT2FN[config.activation_function]
|
317 |
+
|
318 |
+
self.init_type: str = init_type
|
319 |
+
self.init_scale: float = config.init_scale
|
320 |
+
|
321 |
+
self.emit_softmax_over_experts: bool = config.emit_softmax_over_experts
|
322 |
+
self.use_discrete_routing: bool = config.use_discrete_routing
|
323 |
+
|
324 |
+
if self.n_expert is not None:
|
325 |
+
if self.d_ff % self.n_expert:
|
326 |
+
self.d_ff += self.n_expert - (self.d_ff % self.n_expert)
|
327 |
+
warnings.warn("d_ff should be divisible by n_expert, padding d_ff to be divisible by n_expert")
|
328 |
+
self.expert_size = self.d_ff // self.n_expert
|
329 |
+
elif self.expert_size is not None:
|
330 |
+
if self.d_ff % self.expert_size:
|
331 |
+
self.d_ff += self.expert_size - (self.d_ff % self.expert_size)
|
332 |
+
warnings.warn("d_ff should be divisible by expert_size, padding d_ff to be divisible by expert_size")
|
333 |
+
self.n_expert = self.d_ff // self.expert_size
|
334 |
+
else:
|
335 |
+
raise ValueError("Either expert_size or n_expert should be provided")
|
336 |
+
|
337 |
+
self.lin1 = nn.Parameter(
|
338 |
+
self.get_init_weight(
|
339 |
+
(self.n_expert, self.d_model, self.expert_size),
|
340 |
+
fan_in=self.d_model,
|
341 |
+
init_type=self.init_type,
|
342 |
+
scale=self.init_scale,
|
343 |
+
)
|
344 |
+
)
|
345 |
+
|
346 |
+
self.lin2 = nn.Parameter(
|
347 |
+
self.get_init_weight(
|
348 |
+
(self.n_expert, self.expert_size, self.d_model),
|
349 |
+
fan_in=self.expert_size,
|
350 |
+
init_type=self.init_type,
|
351 |
+
scale=self.init_scale,
|
352 |
+
)
|
353 |
+
)
|
354 |
+
|
355 |
+
self.controller = nn.Parameter(
|
356 |
+
self.get_init_weight(
|
357 |
+
(self.d_model, self.n_expert),
|
358 |
+
fan_in=self.d_model,
|
359 |
+
init_type=self.init_type,
|
360 |
+
scale=self.init_scale,
|
361 |
+
)
|
362 |
+
)
|
363 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
364 |
+
|
365 |
+
@staticmethod
|
366 |
+
def argmax_one_hot(x: torch.Tensor, dim: int):
|
367 |
+
max_values, _ = x.max(dim=dim, keepdim=True)
|
368 |
+
return torch.where(
|
369 |
+
condition=x == max_values,
|
370 |
+
input=torch.Tensor([1.0]).to(dtype=x.dtype, device=x.device),
|
371 |
+
other=torch.Tensor([0.0]).to(dtype=x.dtype, device=x.device),
|
372 |
+
) # potentially make it the value itself? torch.where(x == max_values, x, 0.0)
|
373 |
+
|
374 |
+
def get_init_weight(self, shape, fan_in, init_type, scale, dtype=torch.float32):
|
375 |
+
if init_type == "kaiming_uniform":
|
376 |
+
return self.init_kaiming_uniform(shape=shape, fan_in=fan_in, scale=scale, dtype=dtype)
|
377 |
+
elif init_type == "truncated_normal":
|
378 |
+
return self.init_truncated_normal(shape=shape, fan_in=fan_in, scale=scale, dtype=dtype)
|
379 |
+
else:
|
380 |
+
raise ValueError(f"Unknown init_type: {init_type}")
|
381 |
+
|
382 |
+
@staticmethod
|
383 |
+
def init_kaiming_uniform(shape, fan_in, scale, dtype=torch.float32):
|
384 |
+
range_ = scale * (3 / fan_in) ** 0.5
|
385 |
+
return torch.zeros(shape, dtype=dtype).uniform_(-range_, range_)
|
386 |
+
|
387 |
+
@staticmethod
|
388 |
+
def init_truncated_normal(shape, fan_in, scale, dtype=torch.float32):
|
389 |
+
std = (scale / fan_in) ** 0.5
|
390 |
+
low = -2 * scale
|
391 |
+
high = 2 * scale
|
392 |
+
t = torch.zeros(shape, dtype=dtype)
|
393 |
+
return trunc_normal_(t, mean=0.0, std=std, a=low, b=high)
|
394 |
+
|
395 |
+
@staticmethod
|
396 |
+
def stable_softmax_temperature(x: torch.Tensor, temperature: float, dim: int = -1) -> torch.Tensor:
|
397 |
+
return F.softmax(x / temperature, dim=dim)
|
398 |
+
|
399 |
+
def pad(self, x):
|
400 |
+
size = x.size(0)
|
401 |
+
ceiling = torch.ceil(torch.tensor(size / self.group_size).float())
|
402 |
+
new_batch_size = self.group_size * ceiling.int().item()
|
403 |
+
padding_size = new_batch_size - size
|
404 |
+
logger.debug("Padding batch size from %d to %d", size, new_batch_size)
|
405 |
+
|
406 |
+
# Create a zero sequence for padding
|
407 |
+
zero_sequence = torch.zeros_like(x[0:1])
|
408 |
+
padding_sequences = zero_sequence.repeat(padding_size, 1, 1)
|
409 |
+
|
410 |
+
return torch.cat([x, padding_sequences], dim=0)
|
411 |
+
|
412 |
+
@with_batch_size_alignment
|
413 |
+
def forward(self, x):
|
414 |
+
x = self.group_tokens(x)
|
415 |
+
merge_weights, emit_weights = self.calculate_mixed_tokens_with_weights(x)
|
416 |
+
x = self.merge_map_emit(x, merge_weights, emit_weights)
|
417 |
+
x = self.redistribute_tokens(x)
|
418 |
+
x = self.dropout(x)
|
419 |
+
return x
|
420 |
+
|
421 |
+
def group_tokens(self, x):
|
422 |
+
"""
|
423 |
+
Reshape code so the axis to split into groups is on position 1, and then group over said axis.
|
424 |
+
e.g.:
|
425 |
+
- if we group tokens from different sequences in a batch (sparsity = 0), we need to put the batch dimension to position 1.
|
426 |
+
- if we group tokens within one sequence, the dimension to split into groups is already on position 1, hence we leave it as is.
|
427 |
+
|
428 |
+
free_dimension is the dimension on position 0 after reshape
|
429 |
+
split_dimension is the dimension on position 1 - the one to split into groups
|
430 |
+
|
431 |
+
:param x: normal input tensor of shape (batch, seq_len, dmodel)
|
432 |
+
:return: x of shape (free_dimension, split_dimension // group_size, group_size , dmodel)
|
433 |
+
"""
|
434 |
+
assert len(x.shape) == 3, "incorrect shape of a tensor, expected a 3D tensor"
|
435 |
+
assert (
|
436 |
+
x.size(-1) == self.d_model
|
437 |
+
), f"expected the last dimension of input tensor to be d_model = {self.d_model}"
|
438 |
+
|
439 |
+
if self.sparsity_dim == 0:
|
440 |
+
x = x.transpose(0, 1)
|
441 |
+
elif self.sparsity_dim != 1:
|
442 |
+
raise NotImplementedError
|
443 |
+
|
444 |
+
free_dimension = x.size(1)
|
445 |
+
assert (
|
446 |
+
free_dimension % self.group_size == 0
|
447 |
+
), f"free dimension = {free_dimension} should be divisible by group size = {self.group_size}"
|
448 |
+
|
449 |
+
x = x.view(x.size(0), -1, self.group_size, self.d_model)
|
450 |
+
return x
|
451 |
+
|
452 |
+
def redistribute_tokens(self, x):
|
453 |
+
"""
|
454 |
+
An inverse operation to group_tokens.
|
455 |
+
"""
|
456 |
+
assert len(x.shape) == 4, "incorrect shape of a tensor, expected a 4D tensor"
|
457 |
+
|
458 |
+
x = x.view(x.size(0), -1, self.d_model)
|
459 |
+
if self.sparsity_dim == 0:
|
460 |
+
x = x.transpose(0, 1)
|
461 |
+
elif self.sparsity_dim != 1:
|
462 |
+
raise NotImplementedError
|
463 |
+
|
464 |
+
return x
|
465 |
+
|
466 |
+
def calculate_mixed_tokens_with_weights(self, x):
|
467 |
+
"""
|
468 |
+
This function calculates merge and emit weights based on the input tensor, using a controller matrix.
|
469 |
+
The merge weights determine the aggregation of tokens within a group, and emit weights govern the redistribution
|
470 |
+
of the aggregated token back to the original tokens. Temperature scaling is applied to the logits, and optional
|
471 |
+
discrete routing can be used to obtain one-hot representations of the weights.
|
472 |
+
"""
|
473 |
+
# shape of x is (free_dimension, split_dimension // group_size, group_size, dmodel)
|
474 |
+
merge_logits = torch.matmul(x, self.controller)
|
475 |
+
# self.update_cache_for_logging("merge_logits", merge_logits)
|
476 |
+
|
477 |
+
# shape of merge_logits is (free_dimension, aggr_dimension // group_size, group_size, n_expert)
|
478 |
+
temp_merge = self.temperature
|
479 |
+
temp_emit = self.temperature
|
480 |
+
|
481 |
+
merge_softmax_dim = -2
|
482 |
+
emit_softmax_dim = -1 if self.emit_softmax_over_experts else -2
|
483 |
+
|
484 |
+
merge_weights = self.stable_softmax_temperature(merge_logits, temp_merge, dim=merge_softmax_dim)
|
485 |
+
|
486 |
+
# by default we use the same weights for emitting and merging, but if the temperature is learnable or we want to take softmax over experts for emitting, we will use different weights
|
487 |
+
if isinstance(temp_merge, torch.nn.Parameter) or self.emit_softmax_over_experts:
|
488 |
+
emit_weights = self.stable_softmax_temperature(merge_logits, temp_emit, dim=emit_softmax_dim)
|
489 |
+
else:
|
490 |
+
emit_weights = merge_weights
|
491 |
+
|
492 |
+
if self.use_discrete_routing:
|
493 |
+
merge_weights = self.argmax_one_hot(merge_weights, dim=merge_softmax_dim)
|
494 |
+
emit_weights = self.argmax_one_hot(emit_weights, dim=emit_softmax_dim)
|
495 |
+
return merge_weights, emit_weights
|
496 |
+
|
497 |
+
def merge_map_emit(self, x, merge_weights, emit_weights):
|
498 |
+
"""
|
499 |
+
:param x: input reshaped to (free_dimension, split_dimension // group_size, group_size, dmodel)
|
500 |
+
:param merge_weights: weights for merging tokens within a group, shape (free_dimension, split_dimension // group_size, group_size, n_expert)
|
501 |
+
:param emit_weights: weights for emitting tokens within a group, shape (free_dimension, split_dimension // group_size, group_size, n_expert)
|
502 |
+
:return: tensor of token updates of shape (free_dimension, split_dimension // group_size, group_size, dmodel)
|
503 |
+
"""
|
504 |
+
x = torch.matmul(
|
505 |
+
merge_weights.transpose(-1, -2),
|
506 |
+
x,
|
507 |
+
)
|
508 |
+
# x shape is (free_dimension, split_dimension // group_size, n_expert, dmodel) ||| lin1 shape is (n_expert, dmodel, expert_size)
|
509 |
+
x = torch.bmm(x.view(-1, self.n_expert, self.d_model).transpose(0, 1), self.lin1)
|
510 |
+
x = self.act(x)
|
511 |
+
# x shape is (n_expert, free_dimension * aggr_dimension // group_size, expert_size) ||| lin2 shape is (n_expert, expert_size, dmodel)
|
512 |
+
x = torch.bmm(x, self.lin2)
|
513 |
+
# x shape is (n_expert, free_dimension * aggr_dimension // group_size, dmodel)
|
514 |
+
|
515 |
+
# merge_weights shape is (free_dimension, aggr_dimension // group_size, group_size, n_expert)
|
516 |
+
# view x to be (n_expert, free_dimension, aggr_dimension // group_size, dmodel)
|
517 |
+
# permute it to be (free_dimension, aggr_dimension // group_size, n_expert, dmodel)
|
518 |
+
x = torch.matmul(
|
519 |
+
emit_weights,
|
520 |
+
x.view(x.size(0), emit_weights.size(0), -1, self.d_model).permute(1, 2, 0, 3),
|
521 |
+
)
|
522 |
+
|
523 |
+
return x
|
524 |
+
|
525 |
+
|
526 |
+
MoT_ATTENTION_CLASSES = {
|
527 |
+
"eager": MoTAttention,
|
528 |
+
}
|
529 |
+
|
530 |
+
|
531 |
+
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Block with GPT2->MoT
|
532 |
+
class MoTBlock(nn.Module):
|
533 |
+
def __init__(self, config, layer_idx=None):
|
534 |
+
super().__init__()
|
535 |
+
hidden_size = config.hidden_size
|
536 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
537 |
+
attention_class = MoT_ATTENTION_CLASSES[config._attn_implementation]
|
538 |
+
|
539 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
540 |
+
self.attn = attention_class(config=config, layer_idx=layer_idx)
|
541 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
542 |
+
|
543 |
+
if config.add_cross_attention:
|
544 |
+
self.crossattention = attention_class(config=config, is_cross_attention=True, layer_idx=layer_idx)
|
545 |
+
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
546 |
+
|
547 |
+
self.mlp = MoTMLP(inner_dim, config)
|
548 |
+
|
549 |
+
def forward(
|
550 |
+
self,
|
551 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
552 |
+
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
553 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
554 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
555 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
556 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
557 |
+
use_cache: Optional[bool] = False,
|
558 |
+
output_attentions: Optional[bool] = False,
|
559 |
+
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
560 |
+
residual = hidden_states
|
561 |
+
hidden_states = self.ln_1(hidden_states)
|
562 |
+
attn_outputs = self.attn(
|
563 |
+
hidden_states,
|
564 |
+
layer_past=layer_past,
|
565 |
+
attention_mask=attention_mask,
|
566 |
+
head_mask=head_mask,
|
567 |
+
use_cache=use_cache,
|
568 |
+
output_attentions=output_attentions,
|
569 |
+
)
|
570 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
571 |
+
outputs = attn_outputs[1:]
|
572 |
+
# residual connection
|
573 |
+
hidden_states = attn_output + residual
|
574 |
+
|
575 |
+
if encoder_hidden_states is not None:
|
576 |
+
# add one self-attention block for cross-attention
|
577 |
+
if not hasattr(self, "crossattention"):
|
578 |
+
raise ValueError(
|
579 |
+
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
|
580 |
+
"cross-attention layers by setting `config.add_cross_attention=True`"
|
581 |
+
)
|
582 |
+
residual = hidden_states
|
583 |
+
hidden_states = self.ln_cross_attn(hidden_states)
|
584 |
+
cross_attn_outputs = self.crossattention(
|
585 |
+
hidden_states,
|
586 |
+
attention_mask=attention_mask,
|
587 |
+
head_mask=head_mask,
|
588 |
+
encoder_hidden_states=encoder_hidden_states,
|
589 |
+
encoder_attention_mask=encoder_attention_mask,
|
590 |
+
output_attentions=output_attentions,
|
591 |
+
)
|
592 |
+
attn_output = cross_attn_outputs[0]
|
593 |
+
# residual connection
|
594 |
+
hidden_states = residual + attn_output
|
595 |
+
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
596 |
+
|
597 |
+
residual = hidden_states
|
598 |
+
hidden_states = self.ln_2(hidden_states)
|
599 |
+
feed_forward_hidden_states = self.mlp(hidden_states)
|
600 |
+
# residual connection
|
601 |
+
hidden_states = residual + feed_forward_hidden_states
|
602 |
+
|
603 |
+
if use_cache:
|
604 |
+
outputs = (hidden_states,) + outputs
|
605 |
+
else:
|
606 |
+
outputs = (hidden_states,) + outputs[1:]
|
607 |
+
|
608 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
609 |
+
|
610 |
+
|
611 |
+
# Mostly copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel with GPT2->MoT,gpt2->mot,OpenAI GPT-2->MixtureOfTokens, but removed references to TF and FlashAttention2
|
612 |
+
class MoTPreTrainedModel(PreTrainedModel):
|
613 |
+
"""
|
614 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
615 |
+
models.
|
616 |
+
"""
|
617 |
+
|
618 |
+
config_class = MoTConfig
|
619 |
+
base_model_prefix = "transformer"
|
620 |
+
is_parallelizable = True
|
621 |
+
supports_gradient_checkpointing = True
|
622 |
+
_no_split_modules = ["MoTBlock"]
|
623 |
+
_skip_keys_device_placement = "past_key_values"
|
624 |
+
|
625 |
+
def __init__(self, *inputs, **kwargs):
|
626 |
+
super().__init__(*inputs, **kwargs)
|
627 |
+
|
628 |
+
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2PreTrainedModel._init_weights with GPT2->MoT,gpt2->mot,OpenAI GPT-2->MixtureOfTokens
|
629 |
+
def _init_weights(self, module):
|
630 |
+
"""Initialize the weights."""
|
631 |
+
if isinstance(module, (nn.Linear, Conv1D)):
|
632 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
633 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
634 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
635 |
+
if module.bias is not None:
|
636 |
+
module.bias.data.zero_()
|
637 |
+
elif isinstance(module, nn.Embedding):
|
638 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
639 |
+
if module.padding_idx is not None:
|
640 |
+
module.weight.data[module.padding_idx].zero_()
|
641 |
+
elif isinstance(module, nn.LayerNorm):
|
642 |
+
module.bias.data.zero_()
|
643 |
+
module.weight.data.fill_(1.0)
|
644 |
+
|
645 |
+
# Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
|
646 |
+
# > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
|
647 |
+
# > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
|
648 |
+
# > -- GPT-2 :: https://openai.com/blog/better-language-models/
|
649 |
+
#
|
650 |
+
# Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
|
651 |
+
for name, p in module.named_parameters():
|
652 |
+
if name == "c_proj.weight":
|
653 |
+
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
|
654 |
+
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
|
655 |
+
|
656 |
+
|
657 |
+
MOT_START_DOCSTRING = r"""
|
658 |
+
|
659 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
660 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
661 |
+
etc.)
|
662 |
+
|
663 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
664 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
665 |
+
and behavior.
|
666 |
+
|
667 |
+
Parameters:
|
668 |
+
config ([`MoTConfig`]): Model configuration class with all the parameters of the model.
|
669 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
670 |
+
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
671 |
+
"""
|
672 |
+
|
673 |
+
MOT_INPUTS_DOCSTRING = r"""
|
674 |
+
Args:
|
675 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
|
676 |
+
`input_ids_length` = `sequence_length` if `past_key_values` is `None` else
|
677 |
+
`past_key_values[0][0].shape[-2]` (`sequence_length` of input past key value states). Indices of input
|
678 |
+
sequence tokens in the vocabulary.
|
679 |
+
|
680 |
+
If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
|
681 |
+
`input_ids`.
|
682 |
+
|
683 |
+
Indices can be obtained using [`GPT2Tokenizer`]. See [`PreTrainedTokenizer.encode`] and
|
684 |
+
[`PreTrainedTokenizer.__call__`] for details.
|
685 |
+
|
686 |
+
[What are input IDs?](../glossary#input-ids)
|
687 |
+
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
|
688 |
+
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
689 |
+
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
690 |
+
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
691 |
+
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
692 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
693 |
+
|
694 |
+
- 1 for tokens that are **not masked**,
|
695 |
+
- 0 for tokens that are **masked**.
|
696 |
+
|
697 |
+
If `past_key_values` is used, `attention_mask` needs to contain the masking strategy that was used for
|
698 |
+
`past_key_values`. In other words, the `attention_mask` always has to have the length:
|
699 |
+
`len(past_key_values) + len(input_ids)`
|
700 |
+
|
701 |
+
[What are attention masks?](../glossary#attention-mask)
|
702 |
+
token_type_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
|
703 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
|
704 |
+
1]`:
|
705 |
+
|
706 |
+
- 0 corresponds to a *sentence A* token,
|
707 |
+
- 1 corresponds to a *sentence B* token.
|
708 |
+
|
709 |
+
[What are token type IDs?](../glossary#token-type-ids)
|
710 |
+
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
711 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
|
712 |
+
config.max_position_embeddings - 1]`.
|
713 |
+
|
714 |
+
[What are position IDs?](../glossary#position-ids)
|
715 |
+
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
|
716 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
|
717 |
+
|
718 |
+
- 1 indicates the head is **not masked**,
|
719 |
+
- 0 indicates the head is **masked**.
|
720 |
+
|
721 |
+
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
722 |
+
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
|
723 |
+
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
|
724 |
+
model's internal embedding lookup matrix.
|
725 |
+
|
726 |
+
If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
|
727 |
+
`past_key_values`).
|
728 |
+
use_cache (`bool`, *optional*):
|
729 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
730 |
+
`past_key_values`).
|
731 |
+
output_attentions (`bool`, *optional*):
|
732 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
733 |
+
tensors for more detail.
|
734 |
+
output_hidden_states (`bool`, *optional*):
|
735 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
736 |
+
more detail.
|
737 |
+
return_dict (`bool`, *optional*):
|
738 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
739 |
+
"""
|
740 |
+
PARALLELIZE_DOCSTRING = r"""
|
741 |
+
This is an experimental feature and is a subject to change at a moment's notice.
|
742 |
+
|
743 |
+
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
|
744 |
+
it will evenly distribute blocks across all devices.
|
745 |
+
|
746 |
+
Args:
|
747 |
+
device_map (`Dict[int, list]`, optional, defaults to None):
|
748 |
+
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
|
749 |
+
automatically mapped to the first device (for esoteric reasons). That means that the first device should
|
750 |
+
have fewer attention modules mapped to it than other devices. For reference, the mot models have the
|
751 |
+
following number of attention modules:
|
752 |
+
|
753 |
+
- mot: 12
|
754 |
+
- mot-medium: 24
|
755 |
+
- mot-large: 36
|
756 |
+
- mot-xl: 48
|
757 |
+
|
758 |
+
Example:
|
759 |
+
|
760 |
+
```python
|
761 |
+
# Here is an example of a device map on a machine with 4 GPUs using mot-xl, which has a total of 48 attention modules:
|
762 |
+
model = MoTLMHeadModel.from_pretrained("jaszczur/mixture_of_tokens")
|
763 |
+
device_map = {
|
764 |
+
0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
|
765 |
+
1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
|
766 |
+
2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
|
767 |
+
3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
|
768 |
+
}
|
769 |
+
model.parallelize(device_map)
|
770 |
+
```
|
771 |
+
"""
|
772 |
+
DEPARALLELIZE_DOCSTRING = r"""
|
773 |
+
Moves the model to cpu from a model parallel state.
|
774 |
+
|
775 |
+
Example:
|
776 |
+
|
777 |
+
```python
|
778 |
+
# On a 4 GPU machine with mot-large:
|
779 |
+
model = MoTLMHeadModel.from_pretrained("jaszczur/mixture_of_tokens")
|
780 |
+
device_map = {
|
781 |
+
0: [0, 1, 2, 3, 4, 5, 6, 7],
|
782 |
+
1: [8, 9, 10, 11, 12, 13, 14, 15],
|
783 |
+
2: [16, 17, 18, 19, 20, 21, 22, 23],
|
784 |
+
3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
|
785 |
+
}
|
786 |
+
model.parallelize(device_map) # Splits the model across several devices
|
787 |
+
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
|
788 |
+
```
|
789 |
+
"""
|
790 |
+
|
791 |
+
|
792 |
+
@add_start_docstrings(
|
793 |
+
"The bare MOT Model transformer outputting raw hidden-states without any specific head on top.",
|
794 |
+
MOT_START_DOCSTRING,
|
795 |
+
)
|
796 |
+
class MoTModel(MoTPreTrainedModel):
|
797 |
+
def __init__(self, config):
|
798 |
+
super().__init__(config)
|
799 |
+
|
800 |
+
self.embed_dim = config.hidden_size
|
801 |
+
|
802 |
+
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
|
803 |
+
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
|
804 |
+
|
805 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
806 |
+
self.h = nn.ModuleList([MoTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
|
807 |
+
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
808 |
+
|
809 |
+
# Model parallel
|
810 |
+
self.model_parallel = False
|
811 |
+
self.device_map = None
|
812 |
+
self.gradient_checkpointing = False
|
813 |
+
|
814 |
+
# Initialize weights and apply final processing
|
815 |
+
self.post_init()
|
816 |
+
|
817 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
818 |
+
def parallelize(self, device_map=None):
|
819 |
+
# Check validity of device_map
|
820 |
+
warnings.warn(
|
821 |
+
"`MoTModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your"
|
822 |
+
" model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
|
823 |
+
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'h.0': 0, 'h.1': 1,"
|
824 |
+
" ...}",
|
825 |
+
FutureWarning,
|
826 |
+
)
|
827 |
+
self.device_map = (
|
828 |
+
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
|
829 |
+
)
|
830 |
+
assert_device_map(self.device_map, len(self.h))
|
831 |
+
self.model_parallel = True
|
832 |
+
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
|
833 |
+
self.last_device = "cuda:" + str(max(self.device_map.keys()))
|
834 |
+
self.wte = self.wte.to(self.first_device)
|
835 |
+
self.wpe = self.wpe.to(self.first_device)
|
836 |
+
# Load onto devices
|
837 |
+
for k, v in self.device_map.items():
|
838 |
+
for block in v:
|
839 |
+
cuda_device = "cuda:" + str(k)
|
840 |
+
self.h[block] = self.h[block].to(cuda_device)
|
841 |
+
# ln_f to last
|
842 |
+
self.ln_f = self.ln_f.to(self.last_device)
|
843 |
+
|
844 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
845 |
+
def deparallelize(self):
|
846 |
+
warnings.warn(
|
847 |
+
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
|
848 |
+
FutureWarning,
|
849 |
+
)
|
850 |
+
self.model_parallel = False
|
851 |
+
self.device_map = None
|
852 |
+
self.first_device = "cpu"
|
853 |
+
self.last_device = "cpu"
|
854 |
+
self.wte = self.wte.to("cpu")
|
855 |
+
self.wpe = self.wpe.to("cpu")
|
856 |
+
for index in range(len(self.h)):
|
857 |
+
self.h[index] = self.h[index].to("cpu")
|
858 |
+
self.ln_f = self.ln_f.to("cpu")
|
859 |
+
torch.cuda.empty_cache()
|
860 |
+
|
861 |
+
def get_input_embeddings(self):
|
862 |
+
return self.wte
|
863 |
+
|
864 |
+
def set_input_embeddings(self, new_embeddings):
|
865 |
+
self.wte = new_embeddings
|
866 |
+
|
867 |
+
def _prune_heads(self, heads_to_prune):
|
868 |
+
"""
|
869 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
870 |
+
"""
|
871 |
+
for layer, heads in heads_to_prune.items():
|
872 |
+
self.h[layer].attn.prune_heads(heads)
|
873 |
+
|
874 |
+
@add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
|
875 |
+
@add_code_sample_docstrings(
|
876 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
877 |
+
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
878 |
+
config_class=_CONFIG_FOR_DOC,
|
879 |
+
)
|
880 |
+
def forward(
|
881 |
+
self,
|
882 |
+
input_ids: Optional[torch.LongTensor] = None,
|
883 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
884 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
885 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
886 |
+
position_ids: Optional[torch.LongTensor] = None,
|
887 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
888 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
889 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
890 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
891 |
+
use_cache: Optional[bool] = None,
|
892 |
+
output_attentions: Optional[bool] = None,
|
893 |
+
output_hidden_states: Optional[bool] = None,
|
894 |
+
return_dict: Optional[bool] = None,
|
895 |
+
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
|
896 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
897 |
+
output_hidden_states = (
|
898 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
899 |
+
)
|
900 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
901 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
902 |
+
|
903 |
+
if input_ids is not None and inputs_embeds is not None:
|
904 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
905 |
+
elif input_ids is not None:
|
906 |
+
self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
|
907 |
+
input_shape = input_ids.size()
|
908 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
909 |
+
batch_size = input_ids.shape[0]
|
910 |
+
elif inputs_embeds is not None:
|
911 |
+
input_shape = inputs_embeds.size()[:-1]
|
912 |
+
batch_size = inputs_embeds.shape[0]
|
913 |
+
else:
|
914 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
915 |
+
|
916 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
917 |
+
|
918 |
+
if token_type_ids is not None:
|
919 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
920 |
+
|
921 |
+
if past_key_values is None:
|
922 |
+
past_length = 0
|
923 |
+
past_key_values = tuple([None] * len(self.h))
|
924 |
+
else:
|
925 |
+
past_length = past_key_values[0][0].size(-2)
|
926 |
+
if position_ids is None:
|
927 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
928 |
+
position_ids = position_ids.unsqueeze(0)
|
929 |
+
|
930 |
+
# MoTAttention mask.
|
931 |
+
if attention_mask is not None:
|
932 |
+
if batch_size <= 0:
|
933 |
+
raise ValueError("batch_size has to be defined and > 0")
|
934 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
935 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
936 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
937 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
938 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
939 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
940 |
+
attention_mask = attention_mask[:, None, None, :]
|
941 |
+
|
942 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
943 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
944 |
+
# positions we want to attend and the dtype's smallest value for masked positions.
|
945 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
946 |
+
# effectively the same as removing these entirely.
|
947 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
948 |
+
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
949 |
+
|
950 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
951 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
952 |
+
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
953 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
954 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
955 |
+
if encoder_attention_mask is None:
|
956 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
957 |
+
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
958 |
+
else:
|
959 |
+
encoder_attention_mask = None
|
960 |
+
|
961 |
+
# Prepare head mask if needed
|
962 |
+
# 1.0 in head_mask indicate we keep the head
|
963 |
+
# attention_probs has shape bsz x n_heads x N x N
|
964 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
965 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
966 |
+
|
967 |
+
if inputs_embeds is None:
|
968 |
+
inputs_embeds = self.wte(input_ids)
|
969 |
+
position_embeds = self.wpe(position_ids)
|
970 |
+
hidden_states = inputs_embeds + position_embeds
|
971 |
+
|
972 |
+
if token_type_ids is not None:
|
973 |
+
token_type_embeds = self.wte(token_type_ids)
|
974 |
+
hidden_states = hidden_states + token_type_embeds
|
975 |
+
|
976 |
+
hidden_states = self.drop(hidden_states)
|
977 |
+
|
978 |
+
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
979 |
+
|
980 |
+
if self.gradient_checkpointing and self.training:
|
981 |
+
if use_cache:
|
982 |
+
logger.warning_once(
|
983 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
984 |
+
)
|
985 |
+
use_cache = False
|
986 |
+
|
987 |
+
presents = () if use_cache else None
|
988 |
+
all_self_attentions = () if output_attentions else None
|
989 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
990 |
+
all_hidden_states = () if output_hidden_states else None
|
991 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
992 |
+
# Model parallel
|
993 |
+
if self.model_parallel:
|
994 |
+
torch.cuda.set_device(hidden_states.device)
|
995 |
+
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
996 |
+
if layer_past is not None:
|
997 |
+
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
998 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
999 |
+
if attention_mask is not None:
|
1000 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
1001 |
+
if isinstance(head_mask, torch.Tensor):
|
1002 |
+
head_mask = head_mask.to(hidden_states.device)
|
1003 |
+
if output_hidden_states:
|
1004 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
1005 |
+
|
1006 |
+
if self.gradient_checkpointing and self.training:
|
1007 |
+
outputs = self._gradient_checkpointing_func(
|
1008 |
+
block.__call__,
|
1009 |
+
hidden_states,
|
1010 |
+
None,
|
1011 |
+
attention_mask,
|
1012 |
+
head_mask[i],
|
1013 |
+
encoder_hidden_states,
|
1014 |
+
encoder_attention_mask,
|
1015 |
+
use_cache,
|
1016 |
+
output_attentions,
|
1017 |
+
)
|
1018 |
+
else:
|
1019 |
+
outputs = block(
|
1020 |
+
hidden_states,
|
1021 |
+
layer_past=layer_past,
|
1022 |
+
attention_mask=attention_mask,
|
1023 |
+
head_mask=head_mask[i],
|
1024 |
+
encoder_hidden_states=encoder_hidden_states,
|
1025 |
+
encoder_attention_mask=encoder_attention_mask,
|
1026 |
+
use_cache=use_cache,
|
1027 |
+
output_attentions=output_attentions,
|
1028 |
+
)
|
1029 |
+
|
1030 |
+
hidden_states = outputs[0]
|
1031 |
+
if use_cache is True:
|
1032 |
+
presents = presents + (outputs[1],)
|
1033 |
+
|
1034 |
+
if output_attentions:
|
1035 |
+
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
1036 |
+
if self.config.add_cross_attention:
|
1037 |
+
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
1038 |
+
|
1039 |
+
# Model Parallel: If it's the last layer for that device, put things on the next device
|
1040 |
+
if self.model_parallel:
|
1041 |
+
for k, v in self.device_map.items():
|
1042 |
+
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
1043 |
+
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
1044 |
+
|
1045 |
+
hidden_states = self.ln_f(hidden_states)
|
1046 |
+
|
1047 |
+
hidden_states = hidden_states.view(output_shape)
|
1048 |
+
# Add last hidden state
|
1049 |
+
if output_hidden_states:
|
1050 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
1051 |
+
|
1052 |
+
if not return_dict:
|
1053 |
+
return tuple(
|
1054 |
+
v
|
1055 |
+
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
1056 |
+
if v is not None
|
1057 |
+
)
|
1058 |
+
|
1059 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
1060 |
+
last_hidden_state=hidden_states,
|
1061 |
+
past_key_values=presents,
|
1062 |
+
hidden_states=all_hidden_states,
|
1063 |
+
attentions=all_self_attentions,
|
1064 |
+
cross_attentions=all_cross_attentions,
|
1065 |
+
)
|
1066 |
+
|
1067 |
+
|
1068 |
+
@add_start_docstrings(
|
1069 |
+
"""
|
1070 |
+
The MOT Model transformer with a language modeling head on top (linear layer with weights tied to the input
|
1071 |
+
embeddings).
|
1072 |
+
""",
|
1073 |
+
MOT_START_DOCSTRING,
|
1074 |
+
)
|
1075 |
+
class MoTLMHeadModel(MoTPreTrainedModel):
|
1076 |
+
_tied_weights_keys = ["lm_head.weight"]
|
1077 |
+
|
1078 |
+
def __init__(self, config):
|
1079 |
+
super().__init__(config)
|
1080 |
+
self.transformer = MoTModel(config)
|
1081 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
1082 |
+
|
1083 |
+
# Model parallel
|
1084 |
+
self.model_parallel = False
|
1085 |
+
self.device_map = None
|
1086 |
+
|
1087 |
+
# Initialize weights and apply final processing
|
1088 |
+
self.post_init()
|
1089 |
+
|
1090 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
1091 |
+
def parallelize(self, device_map=None):
|
1092 |
+
warnings.warn(
|
1093 |
+
"`MoTLMHeadModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
|
1094 |
+
" your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
|
1095 |
+
" `device_map` but it needs to be a dictionary module_name to device, so for instance {'transformer.h.0':"
|
1096 |
+
" 0, 'transformer.h.1': 1, ...}",
|
1097 |
+
FutureWarning,
|
1098 |
+
)
|
1099 |
+
self.device_map = (
|
1100 |
+
get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
|
1101 |
+
if device_map is None
|
1102 |
+
else device_map
|
1103 |
+
)
|
1104 |
+
assert_device_map(self.device_map, len(self.transformer.h))
|
1105 |
+
self.transformer.parallelize(self.device_map)
|
1106 |
+
self.lm_head = self.lm_head.to(self.transformer.first_device)
|
1107 |
+
self.model_parallel = True
|
1108 |
+
|
1109 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
1110 |
+
def deparallelize(self):
|
1111 |
+
warnings.warn(
|
1112 |
+
"Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
|
1113 |
+
FutureWarning,
|
1114 |
+
)
|
1115 |
+
self.transformer.deparallelize()
|
1116 |
+
self.transformer = self.transformer.to("cpu")
|
1117 |
+
self.lm_head = self.lm_head.to("cpu")
|
1118 |
+
self.model_parallel = False
|
1119 |
+
torch.cuda.empty_cache()
|
1120 |
+
|
1121 |
+
def get_output_embeddings(self):
|
1122 |
+
return self.lm_head
|
1123 |
+
|
1124 |
+
def set_output_embeddings(self, new_embeddings):
|
1125 |
+
self.lm_head = new_embeddings
|
1126 |
+
|
1127 |
+
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
|
1128 |
+
token_type_ids = kwargs.get("token_type_ids", None)
|
1129 |
+
# Omit tokens covered by past_key_values
|
1130 |
+
if past_key_values:
|
1131 |
+
past_length = past_key_values[0][0].shape[2]
|
1132 |
+
|
1133 |
+
# Some generation methods already pass only the last input ID
|
1134 |
+
if input_ids.shape[1] > past_length:
|
1135 |
+
remove_prefix_length = past_length
|
1136 |
+
else:
|
1137 |
+
# Default to old behavior: keep only final ID
|
1138 |
+
remove_prefix_length = input_ids.shape[1] - 1
|
1139 |
+
|
1140 |
+
input_ids = input_ids[:, remove_prefix_length:]
|
1141 |
+
if token_type_ids is not None:
|
1142 |
+
token_type_ids = token_type_ids[:, -input_ids.shape[1] :]
|
1143 |
+
|
1144 |
+
attention_mask = kwargs.get("attention_mask", None)
|
1145 |
+
position_ids = kwargs.get("position_ids", None)
|
1146 |
+
|
1147 |
+
if attention_mask is not None and position_ids is None:
|
1148 |
+
# create position_ids on the fly for batch generation
|
1149 |
+
position_ids = attention_mask.long().cumsum(-1) - 1
|
1150 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
1151 |
+
if past_key_values:
|
1152 |
+
position_ids = position_ids[:, -input_ids.shape[1] :]
|
1153 |
+
else:
|
1154 |
+
position_ids = None
|
1155 |
+
|
1156 |
+
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
|
1157 |
+
if inputs_embeds is not None and past_key_values is None:
|
1158 |
+
model_inputs = {"inputs_embeds": inputs_embeds}
|
1159 |
+
else:
|
1160 |
+
model_inputs = {"input_ids": input_ids}
|
1161 |
+
|
1162 |
+
model_inputs.update(
|
1163 |
+
{
|
1164 |
+
"past_key_values": past_key_values,
|
1165 |
+
"use_cache": kwargs.get("use_cache"),
|
1166 |
+
"position_ids": position_ids,
|
1167 |
+
"attention_mask": attention_mask,
|
1168 |
+
"token_type_ids": token_type_ids,
|
1169 |
+
}
|
1170 |
+
)
|
1171 |
+
|
1172 |
+
return model_inputs
|
1173 |
+
|
1174 |
+
@add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
|
1175 |
+
@add_code_sample_docstrings(
|
1176 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1177 |
+
output_type=CausalLMOutputWithCrossAttentions,
|
1178 |
+
config_class=_CONFIG_FOR_DOC,
|
1179 |
+
)
|
1180 |
+
def forward(
|
1181 |
+
self,
|
1182 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1183 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1184 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1185 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
1186 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1187 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1188 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1189 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
1190 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
1191 |
+
labels: Optional[torch.LongTensor] = None,
|
1192 |
+
use_cache: Optional[bool] = None,
|
1193 |
+
output_attentions: Optional[bool] = None,
|
1194 |
+
output_hidden_states: Optional[bool] = None,
|
1195 |
+
return_dict: Optional[bool] = None,
|
1196 |
+
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
1197 |
+
r"""
|
1198 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1199 |
+
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
|
1200 |
+
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
|
1201 |
+
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
|
1202 |
+
"""
|
1203 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1204 |
+
|
1205 |
+
transformer_outputs = self.transformer(
|
1206 |
+
input_ids,
|
1207 |
+
past_key_values=past_key_values,
|
1208 |
+
attention_mask=attention_mask,
|
1209 |
+
token_type_ids=token_type_ids,
|
1210 |
+
position_ids=position_ids,
|
1211 |
+
head_mask=head_mask,
|
1212 |
+
inputs_embeds=inputs_embeds,
|
1213 |
+
encoder_hidden_states=encoder_hidden_states,
|
1214 |
+
encoder_attention_mask=encoder_attention_mask,
|
1215 |
+
use_cache=use_cache,
|
1216 |
+
output_attentions=output_attentions,
|
1217 |
+
output_hidden_states=output_hidden_states,
|
1218 |
+
return_dict=return_dict,
|
1219 |
+
)
|
1220 |
+
hidden_states = transformer_outputs[0]
|
1221 |
+
|
1222 |
+
# Set device for model parallelism
|
1223 |
+
if self.model_parallel:
|
1224 |
+
torch.cuda.set_device(self.transformer.first_device)
|
1225 |
+
hidden_states = hidden_states.to(self.lm_head.weight.device)
|
1226 |
+
|
1227 |
+
lm_logits = self.lm_head(hidden_states)
|
1228 |
+
|
1229 |
+
loss = None
|
1230 |
+
if labels is not None:
|
1231 |
+
# move labels to correct device to enable model parallelism
|
1232 |
+
labels = labels.to(lm_logits.device)
|
1233 |
+
# Shift so that tokens < n predict n
|
1234 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
1235 |
+
shift_labels = labels[..., 1:].contiguous()
|
1236 |
+
# Flatten the tokens
|
1237 |
+
loss_fct = CrossEntropyLoss()
|
1238 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
1239 |
+
|
1240 |
+
if not return_dict:
|
1241 |
+
output = (lm_logits,) + transformer_outputs[1:]
|
1242 |
+
return ((loss,) + output) if loss is not None else output
|
1243 |
+
|
1244 |
+
return CausalLMOutputWithCrossAttentions(
|
1245 |
+
loss=loss,
|
1246 |
+
logits=lm_logits,
|
1247 |
+
past_key_values=transformer_outputs.past_key_values,
|
1248 |
+
hidden_states=transformer_outputs.hidden_states,
|
1249 |
+
attentions=transformer_outputs.attentions,
|
1250 |
+
cross_attentions=transformer_outputs.cross_attentions,
|
1251 |
+
)
|
1252 |
+
|
1253 |
+
@staticmethod
|
1254 |
+
def _reorder_cache(
|
1255 |
+
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
1256 |
+
) -> Tuple[Tuple[torch.Tensor]]:
|
1257 |
+
"""
|
1258 |
+
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
1259 |
+
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
1260 |
+
beam_idx at every generation step.
|
1261 |
+
"""
|
1262 |
+
return tuple(
|
1263 |
+
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
1264 |
+
for layer_past in past_key_values
|
1265 |
+
)
|
1266 |
+
|
1267 |
+
|
1268 |
+
@add_start_docstrings(
|
1269 |
+
"""
|
1270 |
+
The MoT Model transformer with a language modeling and a multiple-choice classification head on top e.g. for
|
1271 |
+
RocStories/SWAG tasks. The two heads are two linear layers. The language modeling head has its weights tied to the
|
1272 |
+
input embeddings, the classification head takes as input the input of a specified classification token index in the
|
1273 |
+
input sequence).
|
1274 |
+
""",
|
1275 |
+
MOT_START_DOCSTRING,
|
1276 |
+
)
|
1277 |
+
@add_start_docstrings(
|
1278 |
+
"""
|
1279 |
+
The MOT Model transformer with a sequence classification head on top (linear layer).
|
1280 |
+
|
1281 |
+
[`MoTForSequenceClassification`] uses the last token in order to do the classification, as other causal models
|
1282 |
+
(e.g. GPT-1) do.
|
1283 |
+
|
1284 |
+
Since it does classification on the last token, it requires to know the position of the last token. If a
|
1285 |
+
`pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
|
1286 |
+
no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
|
1287 |
+
padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
|
1288 |
+
each row of the batch).
|
1289 |
+
""",
|
1290 |
+
MOT_START_DOCSTRING,
|
1291 |
+
)
|
1292 |
+
class MoTForSequenceClassification(MoTPreTrainedModel):
|
1293 |
+
def __init__(self, config):
|
1294 |
+
super().__init__(config)
|
1295 |
+
self.num_labels = config.num_labels
|
1296 |
+
self.transformer = MoTModel(config)
|
1297 |
+
self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)
|
1298 |
+
|
1299 |
+
# Model parallel
|
1300 |
+
self.model_parallel = False
|
1301 |
+
self.device_map = None
|
1302 |
+
|
1303 |
+
# Initialize weights and apply final processing
|
1304 |
+
self.post_init()
|
1305 |
+
|
1306 |
+
@add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
|
1307 |
+
@add_code_sample_docstrings(
|
1308 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1309 |
+
output_type=SequenceClassifierOutputWithPast,
|
1310 |
+
config_class=_CONFIG_FOR_DOC,
|
1311 |
+
)
|
1312 |
+
def forward(
|
1313 |
+
self,
|
1314 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1315 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1316 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1317 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
1318 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1319 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1320 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1321 |
+
labels: Optional[torch.LongTensor] = None,
|
1322 |
+
use_cache: Optional[bool] = None,
|
1323 |
+
output_attentions: Optional[bool] = None,
|
1324 |
+
output_hidden_states: Optional[bool] = None,
|
1325 |
+
return_dict: Optional[bool] = None,
|
1326 |
+
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
|
1327 |
+
r"""
|
1328 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1329 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1330 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1331 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1332 |
+
"""
|
1333 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1334 |
+
|
1335 |
+
transformer_outputs = self.transformer(
|
1336 |
+
input_ids,
|
1337 |
+
past_key_values=past_key_values,
|
1338 |
+
attention_mask=attention_mask,
|
1339 |
+
token_type_ids=token_type_ids,
|
1340 |
+
position_ids=position_ids,
|
1341 |
+
head_mask=head_mask,
|
1342 |
+
inputs_embeds=inputs_embeds,
|
1343 |
+
use_cache=use_cache,
|
1344 |
+
output_attentions=output_attentions,
|
1345 |
+
output_hidden_states=output_hidden_states,
|
1346 |
+
return_dict=return_dict,
|
1347 |
+
)
|
1348 |
+
hidden_states = transformer_outputs[0]
|
1349 |
+
logits = self.score(hidden_states)
|
1350 |
+
|
1351 |
+
if input_ids is not None:
|
1352 |
+
batch_size, sequence_length = input_ids.shape[:2]
|
1353 |
+
else:
|
1354 |
+
batch_size, sequence_length = inputs_embeds.shape[:2]
|
1355 |
+
|
1356 |
+
assert (
|
1357 |
+
self.config.pad_token_id is not None or batch_size == 1
|
1358 |
+
), "Cannot handle batch sizes > 1 if no padding token is defined."
|
1359 |
+
if self.config.pad_token_id is None:
|
1360 |
+
sequence_lengths = -1
|
1361 |
+
else:
|
1362 |
+
if input_ids is not None:
|
1363 |
+
# if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
|
1364 |
+
sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
|
1365 |
+
sequence_lengths = sequence_lengths % input_ids.shape[-1]
|
1366 |
+
sequence_lengths = sequence_lengths.to(logits.device)
|
1367 |
+
else:
|
1368 |
+
sequence_lengths = -1
|
1369 |
+
logger.warning(
|
1370 |
+
f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
|
1371 |
+
"unexpected if using padding tokens in conjunction with `inputs_embeds.`"
|
1372 |
+
)
|
1373 |
+
|
1374 |
+
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
|
1375 |
+
|
1376 |
+
loss = None
|
1377 |
+
if labels is not None:
|
1378 |
+
if self.config.problem_type is None:
|
1379 |
+
if self.num_labels == 1:
|
1380 |
+
self.config.problem_type = "regression"
|
1381 |
+
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
|
1382 |
+
self.config.problem_type = "single_label_classification"
|
1383 |
+
else:
|
1384 |
+
self.config.problem_type = "multi_label_classification"
|
1385 |
+
|
1386 |
+
if self.config.problem_type == "regression":
|
1387 |
+
loss_fct = MSELoss()
|
1388 |
+
if self.num_labels == 1:
|
1389 |
+
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
|
1390 |
+
else:
|
1391 |
+
loss = loss_fct(pooled_logits, labels)
|
1392 |
+
elif self.config.problem_type == "single_label_classification":
|
1393 |
+
loss_fct = CrossEntropyLoss()
|
1394 |
+
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
|
1395 |
+
elif self.config.problem_type == "multi_label_classification":
|
1396 |
+
loss_fct = BCEWithLogitsLoss()
|
1397 |
+
loss = loss_fct(pooled_logits, labels)
|
1398 |
+
if not return_dict:
|
1399 |
+
output = (pooled_logits,) + transformer_outputs[1:]
|
1400 |
+
return ((loss,) + output) if loss is not None else output
|
1401 |
+
|
1402 |
+
return SequenceClassifierOutputWithPast(
|
1403 |
+
loss=loss,
|
1404 |
+
logits=pooled_logits,
|
1405 |
+
past_key_values=transformer_outputs.past_key_values,
|
1406 |
+
hidden_states=transformer_outputs.hidden_states,
|
1407 |
+
attentions=transformer_outputs.attentions,
|
1408 |
+
)
|
1409 |
+
|
1410 |
+
|
1411 |
+
@add_start_docstrings(
|
1412 |
+
"""
|
1413 |
+
MOT Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
1414 |
+
Named-Entity-Recognition (NER) tasks.
|
1415 |
+
""",
|
1416 |
+
MOT_START_DOCSTRING,
|
1417 |
+
)
|
1418 |
+
class MoTForTokenClassification(MoTPreTrainedModel):
|
1419 |
+
def __init__(self, config):
|
1420 |
+
super().__init__(config)
|
1421 |
+
self.num_labels = config.num_labels
|
1422 |
+
|
1423 |
+
self.transformer = MoTModel(config)
|
1424 |
+
if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
|
1425 |
+
classifier_dropout = config.classifier_dropout
|
1426 |
+
elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
|
1427 |
+
classifier_dropout = config.hidden_dropout
|
1428 |
+
else:
|
1429 |
+
classifier_dropout = 0.1
|
1430 |
+
self.dropout = nn.Dropout(classifier_dropout)
|
1431 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1432 |
+
|
1433 |
+
# Model parallel
|
1434 |
+
self.model_parallel = False
|
1435 |
+
self.device_map = None
|
1436 |
+
|
1437 |
+
# Initialize weights and apply final processing
|
1438 |
+
self.post_init()
|
1439 |
+
|
1440 |
+
@add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING)
|
1441 |
+
# fmt: off
|
1442 |
+
@add_code_sample_docstrings(
|
1443 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1444 |
+
output_type=TokenClassifierOutput,
|
1445 |
+
config_class=_CONFIG_FOR_DOC,
|
1446 |
+
expected_loss=0.25,
|
1447 |
+
expected_output=[
|
1448 |
+
"Lead",
|
1449 |
+
"Lead",
|
1450 |
+
"Lead",
|
1451 |
+
"Position",
|
1452 |
+
"Lead",
|
1453 |
+
"Lead",
|
1454 |
+
"Lead",
|
1455 |
+
"Lead",
|
1456 |
+
"Lead",
|
1457 |
+
"Lead",
|
1458 |
+
"Lead",
|
1459 |
+
"Lead",
|
1460 |
+
],
|
1461 |
+
)
|
1462 |
+
# fmt: on
|
1463 |
+
def forward(
|
1464 |
+
self,
|
1465 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1466 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
1467 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1468 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
1469 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1470 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1471 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1472 |
+
labels: Optional[torch.LongTensor] = None,
|
1473 |
+
use_cache: Optional[bool] = None,
|
1474 |
+
output_attentions: Optional[bool] = None,
|
1475 |
+
output_hidden_states: Optional[bool] = None,
|
1476 |
+
return_dict: Optional[bool] = None,
|
1477 |
+
) -> Union[Tuple, TokenClassifierOutput]:
|
1478 |
+
r"""
|
1479 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
1480 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
|
1481 |
+
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
|
1482 |
+
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1483 |
+
"""
|
1484 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1485 |
+
|
1486 |
+
transformer_outputs = self.transformer(
|
1487 |
+
input_ids,
|
1488 |
+
past_key_values=past_key_values,
|
1489 |
+
attention_mask=attention_mask,
|
1490 |
+
token_type_ids=token_type_ids,
|
1491 |
+
position_ids=position_ids,
|
1492 |
+
head_mask=head_mask,
|
1493 |
+
inputs_embeds=inputs_embeds,
|
1494 |
+
use_cache=use_cache,
|
1495 |
+
output_attentions=output_attentions,
|
1496 |
+
output_hidden_states=output_hidden_states,
|
1497 |
+
return_dict=return_dict,
|
1498 |
+
)
|
1499 |
+
|
1500 |
+
hidden_states = transformer_outputs[0]
|
1501 |
+
hidden_states = self.dropout(hidden_states)
|
1502 |
+
logits = self.classifier(hidden_states)
|
1503 |
+
|
1504 |
+
loss = None
|
1505 |
+
if labels is not None:
|
1506 |
+
labels = labels.to(logits.device)
|
1507 |
+
loss_fct = CrossEntropyLoss()
|
1508 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1509 |
+
|
1510 |
+
if not return_dict:
|
1511 |
+
output = (logits,) + transformer_outputs[2:]
|
1512 |
+
return ((loss,) + output) if loss is not None else output
|
1513 |
+
|
1514 |
+
return TokenClassifierOutput(
|
1515 |
+
loss=loss,
|
1516 |
+
logits=logits,
|
1517 |
+
hidden_states=transformer_outputs.hidden_states,
|
1518 |
+
attentions=transformer_outputs.attentions,
|
1519 |
+
)
|
1520 |
+
|
1521 |
+
|
1522 |
+
@add_start_docstrings(
|
1523 |
+
"""
|
1524 |
+
The MixtureOfTokens transformer with a span classification head on top for extractive question-answering tasks like
|
1525 |
+
SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
1526 |
+
""",
|
1527 |
+
MOT_START_DOCSTRING,
|
1528 |
+
)
|
1529 |
+
class MoTForQuestionAnswering(MoTPreTrainedModel):
|
1530 |
+
def __init__(self, config):
|
1531 |
+
super().__init__(config)
|
1532 |
+
self.num_labels = config.num_labels
|
1533 |
+
self.transformer = MoTModel(config)
|
1534 |
+
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
1535 |
+
|
1536 |
+
# Model parallel
|
1537 |
+
self.model_parallel = False
|
1538 |
+
self.device_map = None
|
1539 |
+
|
1540 |
+
# Initialize weights and apply final processing
|
1541 |
+
self.post_init()
|
1542 |
+
|
1543 |
+
@add_start_docstrings_to_model_forward(MOT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1544 |
+
@add_code_sample_docstrings(
|
1545 |
+
checkpoint=_CHECKPOINT_FOR_DOC,
|
1546 |
+
output_type=QuestionAnsweringModelOutput,
|
1547 |
+
config_class=_CONFIG_FOR_DOC,
|
1548 |
+
)
|
1549 |
+
def forward(
|
1550 |
+
self,
|
1551 |
+
input_ids: Optional[torch.LongTensor] = None,
|
1552 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
1553 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
1554 |
+
position_ids: Optional[torch.LongTensor] = None,
|
1555 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
1556 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
1557 |
+
start_positions: Optional[torch.LongTensor] = None,
|
1558 |
+
end_positions: Optional[torch.LongTensor] = None,
|
1559 |
+
output_attentions: Optional[bool] = None,
|
1560 |
+
output_hidden_states: Optional[bool] = None,
|
1561 |
+
return_dict: Optional[bool] = None,
|
1562 |
+
) -> Union[Tuple, QuestionAnsweringModelOutput]:
|
1563 |
+
r"""
|
1564 |
+
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1565 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1566 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1567 |
+
are not taken into account for computing the loss.
|
1568 |
+
end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
1569 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1570 |
+
Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
|
1571 |
+
are not taken into account for computing the loss.
|
1572 |
+
"""
|
1573 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1574 |
+
|
1575 |
+
outputs = self.transformer(
|
1576 |
+
input_ids,
|
1577 |
+
attention_mask=attention_mask,
|
1578 |
+
token_type_ids=token_type_ids,
|
1579 |
+
position_ids=position_ids,
|
1580 |
+
head_mask=head_mask,
|
1581 |
+
inputs_embeds=inputs_embeds,
|
1582 |
+
output_attentions=output_attentions,
|
1583 |
+
output_hidden_states=output_hidden_states,
|
1584 |
+
return_dict=return_dict,
|
1585 |
+
)
|
1586 |
+
|
1587 |
+
sequence_output = outputs[0]
|
1588 |
+
|
1589 |
+
logits = self.qa_outputs(sequence_output)
|
1590 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1591 |
+
start_logits = start_logits.squeeze(-1).contiguous()
|
1592 |
+
end_logits = end_logits.squeeze(-1).contiguous()
|
1593 |
+
|
1594 |
+
total_loss = None
|
1595 |
+
if start_positions is not None and end_positions is not None:
|
1596 |
+
# If we are on multi-GPU, split add a dimension
|
1597 |
+
if len(start_positions.size()) > 1:
|
1598 |
+
start_positions = start_positions.squeeze(-1).to(start_logits.device)
|
1599 |
+
if len(end_positions.size()) > 1:
|
1600 |
+
end_positions = end_positions.squeeze(-1).to(end_logits.device)
|
1601 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1602 |
+
ignored_index = start_logits.size(1)
|
1603 |
+
start_positions = start_positions.clamp(0, ignored_index)
|
1604 |
+
end_positions = end_positions.clamp(0, ignored_index)
|
1605 |
+
|
1606 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1607 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1608 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1609 |
+
total_loss = (start_loss + end_loss) / 2
|
1610 |
+
|
1611 |
+
if not return_dict:
|
1612 |
+
output = (start_logits, end_logits) + outputs[2:]
|
1613 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1614 |
+
|
1615 |
+
return QuestionAnsweringModelOutput(
|
1616 |
+
loss=total_loss,
|
1617 |
+
start_logits=start_logits,
|
1618 |
+
end_logits=end_logits,
|
1619 |
+
hidden_states=outputs.hidden_states,
|
1620 |
+
attentions=outputs.attentions,
|
1621 |
+
)
|