dwzhu commited on
Commit
928713e
·
verified ·
1 Parent(s): cd87170

Upload 2 files

Browse files
Files changed (2) hide show
  1. configuration_e5rope.py +141 -0
  2. modeling_e5rope.py +1306 -0
configuration_e5rope.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # This file has been modified from the configuration_roformer.py file in the transformers library.
3
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ E5Rope model configuration"""
17
+
18
+ from collections import OrderedDict
19
+ from typing import Mapping
20
+
21
+ from transformers.configuration_utils import PretrainedConfig
22
+ from transformers.onnx import OnnxConfig
23
+ from transformers.utils import logging
24
+
25
+
26
+ logger = logging.get_logger(__name__)
27
+
28
+
29
+
30
+ class E5RopeConfig(PretrainedConfig):
31
+ r"""
32
+
33
+ Args:
34
+ vocab_size (`int`, *optional*, defaults to 50000):
35
+ Vocabulary size of the E5Rope model. Defines the number of different tokens that can be represented by
36
+ the `inputs_ids` passed when calling [`E5RopeModel`] or [`TFE5RopeModel`].
37
+ embedding_size (`int`, *optional*, defaults to None):
38
+ Dimensionality of the encoder layers and the pooler layer. Defaults to the `hidden_size` if not provided.
39
+ hidden_size (`int`, *optional*, defaults to 768):
40
+ Dimension of the encoder layers and the pooler layer.
41
+ num_hidden_layers (`int`, *optional*, defaults to 12):
42
+ Number of hidden layers in the Transformer encoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 12):
44
+ Number of attention heads for each attention layer in the Transformer encoder.
45
+ intermediate_size (`int`, *optional*, defaults to 3072):
46
+ Dimension of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
47
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
48
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
49
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
50
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
51
+ The dropout probabilitiy for all fully connected layers in the embeddings, encoder, and pooler.
52
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
53
+ The dropout ratio for the attention probabilities.
54
+ max_position_embeddings (`int`, *optional*, defaults to 1536):
55
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
56
+ just in case (e.g., 512 or 1024 or 1536).
57
+ type_vocab_size (`int`, *optional*, defaults to 2):
58
+ The vocabulary size of the `token_type_ids` passed when calling [`E5RopeModel`] or [`TFE5RopeModel`].
59
+ initializer_range (`float`, *optional*, defaults to 0.02):
60
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
61
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
62
+ The epsilon used by the layer normalization layers.
63
+ is_decoder (`bool`, *optional*, defaults to `False`):
64
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
65
+ use_cache (`bool`, *optional*, defaults to `True`):
66
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
67
+ relevant if `config.is_decoder=True`.
68
+ rotary_value (`bool`, *optional*, defaults to `False`):
69
+ Whether or not apply rotary position embeddings on value layer.
70
+
71
+ rope_theta (`float`, *optional*, defaults to 10000):
72
+ Frequency base for RoPE.
73
+ use_pose (`bool`, *optional*, defaults to `False`):
74
+ Whether or not to use positional skip-wise training for long context. https://arxiv.org/abs/2309.10400
75
+ pose_target_len (`int`, *optional*, defaults to None):
76
+ target context length if use_pose is True
77
+
78
+ """
79
+
80
+ model_type = "e5rope"
81
+
82
+ def __init__(
83
+ self,
84
+ vocab_size=50000,
85
+ embedding_size=None,
86
+ hidden_size=768,
87
+ num_hidden_layers=12,
88
+ num_attention_heads=12,
89
+ intermediate_size=3072,
90
+ hidden_act="gelu",
91
+ hidden_dropout_prob=0.1,
92
+ attention_probs_dropout_prob=0.1,
93
+ max_position_embeddings=1536,
94
+ type_vocab_size=2,
95
+ initializer_range=0.02,
96
+ layer_norm_eps=1e-12,
97
+ pad_token_id=0,
98
+ rotary_value=False,
99
+ use_cache=True,
100
+ rope_theta=10000,
101
+ use_pose=False,
102
+ pose_target_len=None,
103
+ **kwargs,
104
+ ):
105
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
106
+
107
+ self.vocab_size = vocab_size
108
+ self.embedding_size = hidden_size if embedding_size is None else embedding_size
109
+ self.hidden_size = hidden_size
110
+ self.num_hidden_layers = num_hidden_layers
111
+ self.num_attention_heads = num_attention_heads
112
+ self.hidden_act = hidden_act
113
+ self.intermediate_size = intermediate_size
114
+ self.hidden_dropout_prob = hidden_dropout_prob
115
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
116
+ self.max_position_embeddings = max_position_embeddings
117
+ self.type_vocab_size = type_vocab_size
118
+ self.initializer_range = initializer_range
119
+ self.layer_norm_eps = layer_norm_eps
120
+ self.rotary_value = rotary_value
121
+ self.use_cache = use_cache
122
+ self.rope_theta = rope_theta
123
+ self.use_pose = use_pose
124
+ self.pose_target_len = pose_target_len
125
+
126
+
127
+ class E5RopeOnnxConfig(OnnxConfig):
128
+ @property
129
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
130
+ if self.task == "multiple-choice":
131
+ dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
132
+ else:
133
+ dynamic_axis = {0: "batch", 1: "sequence"}
134
+ dynamic_axis = {0: "batch", 1: "sequence"}
135
+ return OrderedDict(
136
+ [
137
+ ("input_ids", dynamic_axis),
138
+ ("attention_mask", dynamic_axis),
139
+ ("token_type_ids", dynamic_axis),
140
+ ]
141
+ )
modeling_e5rope.py ADDED
@@ -0,0 +1,1306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # This file has been modified from the modeling_roformer.py file in the transformers library. The original RoPE implementation has been replaced with the LLaMA style RoPE implementation.
3
+ # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """ PyTorch E5Rope model."""
17
+
18
+
19
+ import math
20
+ import random
21
+ import os
22
+ from typing import Optional, Tuple, Union
23
+
24
+ import numpy as np
25
+ import torch
26
+ import torch.utils.checkpoint
27
+ import xformers.ops as xops
28
+
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.modeling_outputs import (
34
+ BaseModelOutputWithPastAndCrossAttentions,
35
+ CausalLMOutputWithCrossAttentions,
36
+ MaskedLMOutput,
37
+ MultipleChoiceModelOutput,
38
+ QuestionAnsweringModelOutput,
39
+ SequenceClassifierOutput,
40
+ TokenClassifierOutput,
41
+ )
42
+ from transformers.modeling_utils import PreTrainedModel, SequenceSummary
43
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
44
+ from transformers.utils import (
45
+ add_code_sample_docstrings,
46
+ add_start_docstrings,
47
+ add_start_docstrings_to_model_forward,
48
+ logging,
49
+ replace_return_docstrings,
50
+ )
51
+ from .configuration_e5rope import E5RopeConfig
52
+
53
+
54
+ logger = logging.get_logger(__name__)
55
+
56
+
57
+
58
+ class E5RopeRotaryEmbedding(torch.nn.Module):
59
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
60
+ super().__init__()
61
+
62
+ self.dim = dim
63
+ self.max_position_embeddings = max_position_embeddings
64
+ self.base = base
65
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
66
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
67
+
68
+ # Build here to make `torch.jit.trace` work.
69
+ self._set_cos_sin_cache(
70
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
71
+ )
72
+
73
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
74
+ self.max_seq_len_cached = seq_len
75
+ # t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
76
+ t = np.arange(self.max_seq_len_cached, dtype=np.float64)
77
+ t = torch.tensor(t, device=self.inv_freq.device, dtype=torch.float64)
78
+
79
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
80
+ freqs = torch.outer(t, self.inv_freq.to(device=t.device).to(t.dtype))
81
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
82
+ emb = torch.cat((freqs, freqs), dim=-1)
83
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
84
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
85
+
86
+ def forward(self, x, seq_len=None):
87
+ # x: [bs, num_attention_heads, seq_len, head_size]
88
+ if seq_len > self.max_seq_len_cached:
89
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
90
+
91
+ return (
92
+ self.cos_cached[:, :, :, ...].to(dtype=x.dtype),
93
+ self.sin_cached[:, :, :, ...].to(dtype=x.dtype),
94
+ )
95
+
96
+
97
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
98
+ """
99
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
100
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
101
+ """
102
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
103
+ if n_rep == 1:
104
+ return hidden_states
105
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
106
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
107
+
108
+ def rotate_half(x):
109
+ """Rotates half the hidden dims of the input."""
110
+ x1 = x[..., : x.shape[-1] // 2]
111
+ x2 = x[..., x.shape[-1] // 2 :]
112
+ return torch.cat((-x2, x1), dim=-1)
113
+
114
+
115
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
116
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
117
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
118
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
119
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
120
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
121
+ q_embed = (q * cos) + (rotate_half(q) * sin)
122
+ k_embed = (k * cos) + (rotate_half(k) * sin)
123
+ return q_embed, k_embed
124
+
125
+
126
+ def load_tf_weights_in_e5rope(model, config, tf_checkpoint_path):
127
+ """Load tf checkpoints in a pytorch model."""
128
+ try:
129
+ import re
130
+
131
+ import numpy as np
132
+ import tensorflow as tf
133
+ except ImportError:
134
+ logger.error(
135
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
136
+ "https://www.tensorflow.org/install/ for installation instructions."
137
+ )
138
+ raise
139
+ tf_path = os.path.abspath(tf_checkpoint_path)
140
+ logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
141
+ # Load weights from TF model
142
+ init_vars = tf.train.list_variables(tf_path)
143
+ names = []
144
+ arrays = []
145
+ for name, shape in init_vars:
146
+ logger.info(f"Loading TF weight {name} with shape {shape}")
147
+ array = tf.train.load_variable(tf_path, name)
148
+ names.append(name.replace("bert", "e5rope"))
149
+ arrays.append(array)
150
+
151
+ for name, array in zip(names, arrays):
152
+ name = name.split("/")
153
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
154
+ # which are not required for using pretrained model
155
+ if any(
156
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
157
+ for n in name
158
+ ):
159
+ logger.info(f"Skipping {'/'.join(name)}")
160
+ continue
161
+ pointer = model
162
+ for m_name in name:
163
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
164
+ scope_names = re.split(r"_(\d+)", m_name)
165
+ else:
166
+ scope_names = [m_name]
167
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
168
+ pointer = getattr(pointer, "weight")
169
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
170
+ pointer = getattr(pointer, "bias")
171
+ elif scope_names[0] == "output_weights":
172
+ pointer = getattr(pointer, "weight")
173
+ elif scope_names[0] == "squad":
174
+ pointer = getattr(pointer, "classifier")
175
+ else:
176
+ try:
177
+ pointer = getattr(pointer, scope_names[0])
178
+ except AttributeError:
179
+ logger.info(f"Skipping {'/'.join(name)}")
180
+ continue
181
+ if len(scope_names) >= 2:
182
+ num = int(scope_names[1])
183
+ pointer = pointer[num]
184
+ if m_name[-11:] == "_embeddings":
185
+ pointer = getattr(pointer, "weight")
186
+ elif m_name == "kernel":
187
+ array = np.transpose(array)
188
+ try:
189
+ if not pointer.shape == array.shape:
190
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
191
+ except AssertionError as e:
192
+ e.args += (pointer.shape, array.shape)
193
+ raise
194
+ logger.info(f"Initialize PyTorch weight {name}")
195
+ pointer.data = torch.from_numpy(array)
196
+ return model
197
+
198
+
199
+ class E5RopeEmbeddings(nn.Module):
200
+ """Construct the embeddings from word and token_type embeddings."""
201
+
202
+ def __init__(self, config):
203
+ super().__init__()
204
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
205
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
206
+
207
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
208
+ # any TensorFlow checkpoint file
209
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
210
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
211
+
212
+ def forward(self, input_ids=None, token_type_ids=None, inputs_embeds=None):
213
+ if input_ids is not None:
214
+ input_shape = input_ids.size()
215
+ else:
216
+ input_shape = inputs_embeds.size()[:-1]
217
+
218
+ if inputs_embeds is None:
219
+ inputs_embeds = self.word_embeddings(input_ids)
220
+
221
+ if token_type_ids is None:
222
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=inputs_embeds.device)
223
+
224
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
225
+
226
+ embeddings = inputs_embeds + token_type_embeddings
227
+
228
+ embeddings = self.LayerNorm(embeddings)
229
+ embeddings = self.dropout(embeddings)
230
+ return embeddings
231
+
232
+
233
+ class E5RopeSelfAttention(nn.Module):
234
+ def __init__(self, config):
235
+ super().__init__()
236
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
237
+ raise ValueError(
238
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
239
+ f"heads ({config.num_attention_heads})"
240
+ )
241
+
242
+ self.num_attention_heads = config.num_attention_heads
243
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
244
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
245
+
246
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
247
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
248
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
249
+
250
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
251
+ self.is_decoder = config.is_decoder
252
+
253
+ self.config = config
254
+ self.max_position_embeddings = config.max_position_embeddings
255
+ self.rope_theta = config.rope_theta
256
+
257
+ self.rotary_emb = E5RopeRotaryEmbedding(
258
+ self.attention_head_size,
259
+ max_position_embeddings=self.max_position_embeddings,
260
+ base=self.rope_theta,
261
+ )
262
+ # self.forward = self.normal_forward
263
+
264
+
265
+ def transpose_for_scores(self, x):
266
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
267
+ x = x.view(*new_x_shape)
268
+ return x.permute(0, 2, 1, 3)
269
+
270
+
271
+ def forward(
272
+ self,
273
+ hidden_states,
274
+ attention_mask=None,
275
+ position_ids=None,
276
+ head_mask=None,
277
+ encoder_hidden_states=None,
278
+ encoder_attention_mask=None,
279
+ past_key_value=None,
280
+ output_attentions=False,
281
+ ):
282
+ mixed_query_layer = self.query(hidden_states)
283
+ query_layer = self.transpose_for_scores(mixed_query_layer)
284
+ # If this is instantiated as a cross-attention module, the keys
285
+ # and values come from an encoder; the attention mask needs to be
286
+ # such that the encoder's padding tokens are not attended to.
287
+ is_cross_attention = encoder_hidden_states is not None
288
+
289
+ if is_cross_attention and past_key_value is not None:
290
+ # reuse k,v, cross_attentions
291
+ key_layer = past_key_value[0]
292
+ value_layer = past_key_value[1]
293
+ attention_mask = encoder_attention_mask
294
+ elif is_cross_attention:
295
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
296
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
297
+ attention_mask = encoder_attention_mask
298
+ else:
299
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
300
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
301
+
302
+ kv_seq_len = key_layer.shape[-2]
303
+ if past_key_value is not None:
304
+ kv_seq_len += past_key_value[0].shape[-2]
305
+
306
+ cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
307
+ query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
308
+
309
+ if past_key_value is not None:
310
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
311
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
312
+
313
+ if self.is_decoder:
314
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
315
+ # Further calls to cross_attention layer can then reuse all cross-attention
316
+ # key/value_states (first "if" case)
317
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
318
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
319
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
320
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
321
+ past_key_value = (key_layer, value_layer)
322
+
323
+ bsz, n_heads, seq_len, head_dim = query_layer.shape
324
+
325
+ # get each seq len
326
+ tmp_attention_mask = attention_mask.squeeze()
327
+ if tmp_attention_mask.dim() == 1:
328
+ tmp_attention_mask = tmp_attention_mask.unsqueeze(0)
329
+ each_seq_len = torch.sum(tmp_attention_mask == 0, dim=-1)
330
+ original_len = torch.tensor(512)
331
+
332
+ ### attention scaling for better length extrapolation ###
333
+ ### https://arxiv.org/abs/2202.12172 ; https://kexue.fm/archives/8823 ###
334
+ attn_factors = torch.log(each_seq_len) / torch.log(original_len)
335
+ attn_factors = torch.clamp(attn_factors, min=1.0) # Ensure a minimum value of 1
336
+ attn_factors = attn_factors.view(-1, 1, 1, 1)
337
+ query_layer *= attn_factors
338
+
339
+ attention_mask = attention_mask.expand(bsz, n_heads, seq_len, seq_len).to(dtype=query_layer.dtype)
340
+ attn_output = xops.memory_efficient_attention(
341
+ query_layer.transpose(1, 2), key_layer.transpose(1, 2), value_layer.transpose(1, 2),
342
+ attn_bias=attention_mask, p=(self.dropout.p if self.training else 0)
343
+ ).reshape(bsz, seq_len, n_heads * head_dim)
344
+
345
+ if output_attentions is True:
346
+ raise NotImplementedError('output_attentions is not supported for xformers attention')
347
+
348
+ return (attn_output,)
349
+
350
+ def normal_forward(
351
+ self,
352
+ hidden_states,
353
+ attention_mask=None,
354
+ position_ids=None,
355
+ head_mask=None,
356
+ encoder_hidden_states=None,
357
+ encoder_attention_mask=None,
358
+ past_key_value=None,
359
+ output_attentions=False,
360
+ ):
361
+ mixed_query_layer = self.query(hidden_states)
362
+ query_layer = self.transpose_for_scores(mixed_query_layer)
363
+ # If this is instantiated as a cross-attention module, the keys
364
+ # and values come from an encoder; the attention mask needs to be
365
+ # such that the encoder's padding tokens are not attended to.
366
+ is_cross_attention = encoder_hidden_states is not None
367
+
368
+ if is_cross_attention and past_key_value is not None:
369
+ # reuse k,v, cross_attentions
370
+ key_layer = past_key_value[0]
371
+ value_layer = past_key_value[1]
372
+ attention_mask = encoder_attention_mask
373
+ elif is_cross_attention:
374
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
375
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
376
+ attention_mask = encoder_attention_mask
377
+ else:
378
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
379
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
380
+
381
+ kv_seq_len = key_layer.shape[-2]
382
+ if past_key_value is not None:
383
+ kv_seq_len += past_key_value[0].shape[-2]
384
+
385
+ cos, sin = self.rotary_emb(value_layer, seq_len=kv_seq_len)
386
+ query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids)
387
+
388
+ if past_key_value is not None:
389
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
390
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
391
+
392
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
393
+
394
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
395
+ if attention_mask is not None:
396
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
397
+ attention_scores = attention_scores + attention_mask
398
+
399
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
400
+
401
+ # This is actually dropping out entire tokens to attend to, which might
402
+ # seem a bit unusual, but is taken from the original Transformer paper.
403
+ attention_probs = self.dropout(attention_probs)
404
+
405
+ context_layer = torch.matmul(attention_probs, value_layer)
406
+
407
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
408
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
409
+ context_layer = context_layer.view(*new_context_layer_shape)
410
+
411
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
412
+
413
+ if self.is_decoder:
414
+ outputs = outputs + (past_key_value,)
415
+ return outputs
416
+
417
+
418
+ # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->E5Rope
419
+ class E5RopeSelfOutput(nn.Module):
420
+ def __init__(self, config):
421
+ super().__init__()
422
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
423
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
424
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
425
+
426
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
427
+ hidden_states = self.dense(hidden_states)
428
+ hidden_states = self.dropout(hidden_states)
429
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
430
+ return hidden_states
431
+
432
+
433
+ class E5RopeAttention(nn.Module):
434
+ def __init__(self, config):
435
+ super().__init__()
436
+ self.self = E5RopeSelfAttention(config)
437
+ self.output = E5RopeSelfOutput(config)
438
+ self.pruned_heads = set()
439
+
440
+ # Copied from transformers.models.bert.modeling_bert.BertAttention.prune_heads
441
+ def prune_heads(self, heads):
442
+ if len(heads) == 0:
443
+ return
444
+ heads, index = find_pruneable_heads_and_indices(
445
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
446
+ )
447
+
448
+ # Prune linear layers
449
+ self.self.query = prune_linear_layer(self.self.query, index)
450
+ self.self.key = prune_linear_layer(self.self.key, index)
451
+ self.self.value = prune_linear_layer(self.self.value, index)
452
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
453
+
454
+ # Update hyper params and store pruned heads
455
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
456
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
457
+ self.pruned_heads = self.pruned_heads.union(heads)
458
+
459
+ # End Copy
460
+ def forward(
461
+ self,
462
+ hidden_states,
463
+ attention_mask=None,
464
+ position_ids=None,
465
+ head_mask=None,
466
+ encoder_hidden_states=None,
467
+ encoder_attention_mask=None,
468
+ past_key_value=None,
469
+ output_attentions=False,
470
+ ):
471
+ self_outputs = self.self(
472
+ hidden_states,
473
+ attention_mask,
474
+ position_ids,
475
+ head_mask,
476
+ encoder_hidden_states,
477
+ encoder_attention_mask,
478
+ past_key_value,
479
+ output_attentions,
480
+ )
481
+ attention_output = self.output(self_outputs[0], hidden_states)
482
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
483
+ return outputs
484
+
485
+
486
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->E5Rope
487
+ class E5RopeIntermediate(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
491
+ if isinstance(config.hidden_act, str):
492
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
493
+ else:
494
+ self.intermediate_act_fn = config.hidden_act
495
+
496
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
497
+ hidden_states = self.dense(hidden_states)
498
+ hidden_states = self.intermediate_act_fn(hidden_states)
499
+ return hidden_states
500
+
501
+
502
+ # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->E5Rope
503
+ class E5RopeOutput(nn.Module):
504
+ def __init__(self, config):
505
+ super().__init__()
506
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
507
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
508
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
509
+
510
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
511
+ hidden_states = self.dense(hidden_states)
512
+ hidden_states = self.dropout(hidden_states)
513
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
514
+ return hidden_states
515
+
516
+
517
+ class E5RopeLayer(nn.Module):
518
+ def __init__(self, config):
519
+ super().__init__()
520
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
521
+ self.seq_len_dim = 1
522
+ self.attention = E5RopeAttention(config)
523
+ self.is_decoder = config.is_decoder
524
+ self.add_cross_attention = config.add_cross_attention
525
+ if self.add_cross_attention:
526
+ if not self.is_decoder:
527
+ raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
528
+ self.crossattention = E5RopeAttention(config)
529
+ self.intermediate = E5RopeIntermediate(config)
530
+ self.output = E5RopeOutput(config)
531
+
532
+ def forward(
533
+ self,
534
+ hidden_states,
535
+ attention_mask=None,
536
+ position_ids=None,
537
+ head_mask=None,
538
+ encoder_hidden_states=None,
539
+ encoder_attention_mask=None,
540
+ past_key_value=None,
541
+ output_attentions=False,
542
+ ):
543
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
544
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
545
+ self_attention_outputs = self.attention(
546
+ hidden_states,
547
+ attention_mask,
548
+ position_ids,
549
+ head_mask,
550
+ output_attentions=output_attentions,
551
+ past_key_value=self_attn_past_key_value,
552
+ )
553
+ attention_output = self_attention_outputs[0]
554
+
555
+ # if decoder, the last output is tuple of self-attn cache
556
+ if self.is_decoder:
557
+ outputs = self_attention_outputs[1:-1]
558
+ present_key_value = self_attention_outputs[-1]
559
+ else:
560
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
561
+
562
+ cross_attn_present_key_value = None
563
+ if self.is_decoder and encoder_hidden_states is not None:
564
+ if not hasattr(self, "crossattention"):
565
+ raise ValueError(
566
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention "
567
+ "layers by setting `config.add_cross_attention=True`"
568
+ )
569
+
570
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
571
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
572
+ cross_attention_outputs = self.crossattention(
573
+ attention_output,
574
+ attention_mask,
575
+ position_ids,
576
+ head_mask,
577
+ encoder_hidden_states,
578
+ encoder_attention_mask,
579
+ cross_attn_past_key_value,
580
+ output_attentions,
581
+ )
582
+ attention_output = cross_attention_outputs[0]
583
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
584
+
585
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
586
+ cross_attn_present_key_value = cross_attention_outputs[-1]
587
+ present_key_value = present_key_value + cross_attn_present_key_value
588
+
589
+ layer_output = apply_chunking_to_forward(
590
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
591
+ )
592
+ outputs = (layer_output,) + outputs
593
+
594
+ # if decoder, return the attn key/values as the last output
595
+ if self.is_decoder:
596
+ outputs = outputs + (present_key_value,)
597
+
598
+ return outputs
599
+
600
+ def feed_forward_chunk(self, attention_output):
601
+ intermediate_output = self.intermediate(attention_output)
602
+ layer_output = self.output(intermediate_output, attention_output)
603
+ return layer_output
604
+
605
+
606
+ class E5RopeEncoder(nn.Module):
607
+ def __init__(self, config):
608
+ super().__init__()
609
+ self.config = config
610
+ self.layer = nn.ModuleList([E5RopeLayer(config) for _ in range(config.num_hidden_layers)])
611
+ self.gradient_checkpointing = False
612
+
613
+ def forward(
614
+ self,
615
+ hidden_states,
616
+ attention_mask=None,
617
+ position_ids=None,
618
+ head_mask=None,
619
+ encoder_hidden_states=None,
620
+ encoder_attention_mask=None,
621
+ past_key_values=None,
622
+ use_cache=None,
623
+ output_attentions=False,
624
+ output_hidden_states=False,
625
+ return_dict=True,
626
+ ):
627
+ if self.gradient_checkpointing and self.training:
628
+ if use_cache:
629
+ logger.warning_once(
630
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
631
+ )
632
+ use_cache = False
633
+ all_hidden_states = () if output_hidden_states else None
634
+ all_self_attentions = () if output_attentions else None
635
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
636
+
637
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
638
+
639
+ # [sequence_length, embed_size_per_head] -> [batch_size, num_heads, sequence_length, embed_size_per_head]
640
+ # sinusoidal_pos = self.embed_positions(hidden_states.shape[:-1], past_key_values_length)[None, None, :, :]
641
+
642
+ next_decoder_cache = () if use_cache else None
643
+ for i, layer_module in enumerate(self.layer):
644
+ if output_hidden_states:
645
+ all_hidden_states = all_hidden_states + (hidden_states,)
646
+
647
+ layer_head_mask = head_mask[i] if head_mask is not None else None
648
+ past_key_value = past_key_values[i] if past_key_values is not None else None
649
+
650
+ if self.gradient_checkpointing and self.training:
651
+
652
+ def create_custom_forward(module):
653
+ def custom_forward(*inputs):
654
+ return module(*inputs, past_key_value, output_attentions)
655
+
656
+ return custom_forward
657
+
658
+ layer_outputs = torch.utils.checkpoint.checkpoint(
659
+ create_custom_forward(layer_module),
660
+ hidden_states,
661
+ attention_mask,
662
+ position_ids,
663
+ layer_head_mask,
664
+ encoder_hidden_states,
665
+ encoder_attention_mask,
666
+ )
667
+ else:
668
+ layer_outputs = layer_module(
669
+ hidden_states,
670
+ attention_mask,
671
+ position_ids,
672
+ layer_head_mask,
673
+ encoder_hidden_states,
674
+ encoder_attention_mask,
675
+ past_key_value,
676
+ output_attentions,
677
+ )
678
+
679
+ hidden_states = layer_outputs[0]
680
+ if use_cache:
681
+ next_decoder_cache += (layer_outputs[-1],)
682
+ if output_attentions:
683
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
684
+ if self.config.add_cross_attention:
685
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
686
+
687
+ if output_hidden_states:
688
+ all_hidden_states = all_hidden_states + (hidden_states,)
689
+
690
+ if not return_dict:
691
+ return tuple(
692
+ v
693
+ for v in [
694
+ hidden_states,
695
+ next_decoder_cache,
696
+ all_hidden_states,
697
+ all_self_attentions,
698
+ all_cross_attentions,
699
+ ]
700
+ if v is not None
701
+ )
702
+ return BaseModelOutputWithPastAndCrossAttentions(
703
+ last_hidden_state=hidden_states,
704
+ past_key_values=next_decoder_cache,
705
+ hidden_states=all_hidden_states,
706
+ attentions=all_self_attentions,
707
+ cross_attentions=all_cross_attentions,
708
+ )
709
+
710
+
711
+ class E5RopePredictionHeadTransform(nn.Module):
712
+ def __init__(self, config):
713
+ super().__init__()
714
+ self.dense = nn.Linear(config.hidden_size, config.embedding_size)
715
+ if isinstance(config.hidden_act, str):
716
+ self.transform_act_fn = ACT2FN[config.hidden_act]
717
+ else:
718
+ self.transform_act_fn = config.hidden_act
719
+ self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
720
+
721
+ def forward(self, hidden_states):
722
+ hidden_states = self.dense(hidden_states)
723
+ hidden_states = self.transform_act_fn(hidden_states)
724
+ hidden_states = self.LayerNorm(hidden_states)
725
+ return hidden_states
726
+
727
+
728
+ class E5RopeLMPredictionHead(nn.Module):
729
+ def __init__(self, config):
730
+ super().__init__()
731
+ self.transform = E5RopePredictionHeadTransform(config)
732
+
733
+ # The output weights are the same as the input embeddings, but there is
734
+ # an output-only bias for each token.
735
+ self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
736
+
737
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
738
+
739
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
740
+ self.decoder.bias = self.bias
741
+
742
+ def forward(self, hidden_states):
743
+ hidden_states = self.transform(hidden_states)
744
+ hidden_states = self.decoder(hidden_states)
745
+ return hidden_states
746
+
747
+
748
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->E5Rope
749
+ class E5RopeOnlyMLMHead(nn.Module):
750
+ def __init__(self, config):
751
+ super().__init__()
752
+ self.predictions = E5RopeLMPredictionHead(config)
753
+
754
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
755
+ prediction_scores = self.predictions(sequence_output)
756
+ return prediction_scores
757
+
758
+
759
+ class E5RopePreTrainedModel(PreTrainedModel):
760
+ """
761
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
762
+ models.
763
+ """
764
+
765
+ config_class = E5RopeConfig
766
+ load_tf_weights = load_tf_weights_in_e5rope
767
+ base_model_prefix = "e5rope"
768
+ supports_gradient_checkpointing = True
769
+
770
+ def _init_weights(self, module):
771
+ """Initialize the weights"""
772
+ if isinstance(module, nn.Linear):
773
+ # Slightly different from the TF version which uses truncated_normal for initialization
774
+ # cf https://github.com/pytorch/pytorch/pull/5617
775
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
776
+ if module.bias is not None:
777
+ module.bias.data.zero_()
778
+ elif isinstance(module, E5RopeRotaryEmbedding):
779
+ pass
780
+ elif isinstance(module, nn.Embedding):
781
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
782
+ if module.padding_idx is not None:
783
+ module.weight.data[module.padding_idx].zero_()
784
+ elif isinstance(module, nn.LayerNorm):
785
+ module.bias.data.zero_()
786
+ module.weight.data.fill_(1.0)
787
+
788
+ def _set_gradient_checkpointing(self, module, value=False):
789
+ if isinstance(module, E5RopeEncoder):
790
+ module.gradient_checkpointing = value
791
+
792
+
793
+ E5ROPE_START_DOCSTRING = r"""
794
+ This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class. Use
795
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
796
+ behavior.
797
+
798
+ Parameters:
799
+ config ([`E5RopeConfig`]): Model configuration class with all the parameters of the model.
800
+ Initializing with a config file does not load the weights associated with the model, only the
801
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
802
+ """
803
+
804
+ E5ROPE_INPUTS_DOCSTRING = r"""
805
+ Args:
806
+ input_ids (`torch.LongTensor` of shape `({0})`):
807
+ Indices of input sequence tokens in the vocabulary.
808
+
809
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
810
+ [`PreTrainedTokenizer.__call__`] for details.
811
+
812
+ [What are input IDs?](../glossary#input-ids)
813
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
814
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
815
+
816
+ - 1 for tokens that are **not masked**,
817
+ - 0 for tokens that are **masked**.
818
+
819
+ [What are attention masks?](../glossary#attention-mask)
820
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
821
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
822
+ 1]`:
823
+
824
+ - 0 corresponds to a *sentence A* token,
825
+ - 1 corresponds to a *sentence B* token.
826
+
827
+ [What are token type IDs?](../glossary#token-type-ids)
828
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
829
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
830
+
831
+ - 1 indicates the head is **not masked**,
832
+ - 0 indicates the head is **masked**.
833
+
834
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
835
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
836
+ is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
837
+ model's internal embedding lookup matrix.
838
+ output_attentions (`bool`, *optional*):
839
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
840
+ tensors for more detail.
841
+ output_hidden_states (`bool`, *optional*):
842
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
843
+ more detail.
844
+ return_dict (`bool`, *optional*):
845
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
846
+ """
847
+
848
+
849
+ @add_start_docstrings(
850
+ "The bare E5Rope Model transformer outputting raw hidden-states without any specific head on top.",
851
+ E5ROPE_START_DOCSTRING,
852
+ )
853
+ class E5RopeModel(E5RopePreTrainedModel):
854
+ """
855
+
856
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
857
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
858
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
859
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
860
+
861
+ To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
862
+ to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
863
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
864
+ """
865
+
866
+ def __init__(self, config):
867
+ super().__init__(config)
868
+ self.config = config
869
+ self.embeddings = E5RopeEmbeddings(config)
870
+
871
+ if config.embedding_size != config.hidden_size:
872
+ self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)
873
+
874
+ self.encoder = E5RopeEncoder(config)
875
+
876
+ # Initialize weights and apply final processing
877
+ self.post_init()
878
+
879
+ def get_input_embeddings(self):
880
+ return self.embeddings.word_embeddings
881
+
882
+ def set_input_embeddings(self, value):
883
+ self.embeddings.word_embeddings = value
884
+
885
+ def _prune_heads(self, heads_to_prune):
886
+ """
887
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
888
+ class PreTrainedModel
889
+ """
890
+ for layer, heads in heads_to_prune.items():
891
+ self.encoder.layer[layer].attention.prune_heads(heads)
892
+
893
+ @add_start_docstrings_to_model_forward(E5ROPE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
894
+ def forward(
895
+ self,
896
+ input_ids: Optional[torch.LongTensor] = None,
897
+ attention_mask: Optional[torch.FloatTensor] = None,
898
+ position_ids: Optional[torch.LongTensor] = None,
899
+ token_type_ids: Optional[torch.LongTensor] = None,
900
+ head_mask: Optional[torch.FloatTensor] = None,
901
+ inputs_embeds: Optional[torch.FloatTensor] = None,
902
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
903
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
904
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
905
+ use_cache: Optional[bool] = None,
906
+ output_attentions: Optional[bool] = None,
907
+ output_hidden_states: Optional[bool] = None,
908
+ return_dict: Optional[bool] = None,
909
+ ) -> Union[BaseModelOutputWithPastAndCrossAttentions, Tuple[torch.Tensor]]:
910
+ r"""
911
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
912
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
913
+ the model is configured as a decoder.
914
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
915
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
916
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
917
+
918
+ - 1 for tokens that are **not masked**,
919
+ - 0 for tokens that are **masked**.
920
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
921
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
922
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
923
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
924
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
925
+ use_cache (`bool`, *optional*):
926
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
927
+ `past_key_values`).
928
+ """
929
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
930
+ output_hidden_states = (
931
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
932
+ )
933
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
934
+
935
+ if self.config.is_decoder:
936
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
937
+ else:
938
+ use_cache = False
939
+
940
+ if input_ids is not None and inputs_embeds is not None:
941
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
942
+ elif input_ids is not None:
943
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
944
+ input_shape = input_ids.size()
945
+ elif inputs_embeds is not None:
946
+ input_shape = inputs_embeds.size()[:-1]
947
+ else:
948
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
949
+
950
+ batch_size, seq_length = input_shape
951
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
952
+
953
+ # past_key_values_length
954
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
955
+
956
+ if attention_mask is None:
957
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
958
+ if token_type_ids is None:
959
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
960
+
961
+ if position_ids is None:
962
+ position_ids = torch.arange(
963
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
964
+ )
965
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
966
+
967
+ ### inserted code for positional skip-wise training ###
968
+ ### https://arxiv.org/abs/2309.10400 ###
969
+ if self.config.use_pose == True and self.training:
970
+ pos_list = []
971
+ for i in range(batch_size):
972
+ bias = random.randint(-seq_length, self.config.pose_target_len)
973
+ bias = min(bias, self.config.pose_target_len - seq_length)
974
+ bias = max(bias, 0)
975
+ pos = torch.arange(
976
+ past_key_values_length, past_key_values_length + seq_length, dtype=torch.long, device=device
977
+ )
978
+ bias_st_ids = random.randint(min(64, seq_length-1), seq_length - 1) # do not skip very short sequences
979
+ pos[bias_st_ids:] += bias
980
+ pos_list.append(pos)
981
+ position_ids = torch.stack(pos_list, dim=0)
982
+
983
+ #######################################################
984
+
985
+ else:
986
+ position_ids = position_ids.view(-1, seq_length).long()
987
+
988
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
989
+ # ourselves in which case we just need to make it broadcastable to all heads.
990
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
991
+
992
+ # If a 2D or 3D attention mask is provided for the cross-attention
993
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
994
+ if self.config.is_decoder and encoder_hidden_states is not None:
995
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
996
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
997
+ if encoder_attention_mask is None:
998
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
999
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1000
+ else:
1001
+ encoder_extended_attention_mask = None
1002
+
1003
+ # Prepare head mask if needed
1004
+ # 1.0 in head_mask indicate we keep the head
1005
+ # attention_probs has shape bsz x n_heads x N x N
1006
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1007
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1008
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1009
+
1010
+ embedding_output = self.embeddings(
1011
+ input_ids=input_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
1012
+ )
1013
+ if hasattr(self, "embeddings_project"):
1014
+ embedding_output = self.embeddings_project(embedding_output)
1015
+
1016
+ encoder_outputs = self.encoder(
1017
+ embedding_output,
1018
+ attention_mask=extended_attention_mask,
1019
+ position_ids=position_ids,
1020
+ head_mask=head_mask,
1021
+ encoder_hidden_states=encoder_hidden_states,
1022
+ encoder_attention_mask=encoder_extended_attention_mask,
1023
+ past_key_values=past_key_values,
1024
+ use_cache=use_cache,
1025
+ output_attentions=output_attentions,
1026
+ output_hidden_states=output_hidden_states,
1027
+ return_dict=return_dict,
1028
+ )
1029
+ sequence_output = encoder_outputs[0]
1030
+
1031
+ if not return_dict:
1032
+ return (sequence_output,) + encoder_outputs[1:]
1033
+
1034
+ return BaseModelOutputWithPastAndCrossAttentions(
1035
+ last_hidden_state=sequence_output,
1036
+ past_key_values=encoder_outputs.past_key_values,
1037
+ hidden_states=encoder_outputs.hidden_states,
1038
+ attentions=encoder_outputs.attentions,
1039
+ cross_attentions=encoder_outputs.cross_attentions,
1040
+ )
1041
+
1042
+
1043
+ @add_start_docstrings("""E5Rope Model with a `language modeling` head on top.""", E5ROPE_START_DOCSTRING)
1044
+ class E5RopeForMaskedLM(E5RopePreTrainedModel):
1045
+ _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
1046
+
1047
+ def __init__(self, config):
1048
+ super().__init__(config)
1049
+
1050
+ if config.is_decoder:
1051
+ logger.warning(
1052
+ "If you want to use `E5RopeForMaskedLM` make sure `config.is_decoder=False` for "
1053
+ "bi-directional self-attention."
1054
+ )
1055
+
1056
+ self.e5rope = E5RopeModel(config)
1057
+ self.cls = E5RopeOnlyMLMHead(config)
1058
+
1059
+ # Initialize weights and apply final processing
1060
+ self.post_init()
1061
+
1062
+ def get_output_embeddings(self):
1063
+ return self.cls.predictions.decoder
1064
+
1065
+ def set_output_embeddings(self, new_embeddings):
1066
+ self.cls.predictions.decoder = new_embeddings
1067
+
1068
+ @add_start_docstrings_to_model_forward(E5ROPE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1069
+ def forward(
1070
+ self,
1071
+ input_ids: Optional[torch.LongTensor] = None,
1072
+ attention_mask: Optional[torch.FloatTensor] = None,
1073
+ token_type_ids: Optional[torch.LongTensor] = None,
1074
+ head_mask: Optional[torch.FloatTensor] = None,
1075
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1076
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1077
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1078
+ labels: Optional[torch.LongTensor] = None,
1079
+ output_attentions: Optional[bool] = None,
1080
+ output_hidden_states: Optional[bool] = None,
1081
+ return_dict: Optional[bool] = None,
1082
+ ) -> Union[MaskedLMOutput, Tuple[torch.Tensor]]:
1083
+ r"""
1084
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1085
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1086
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1087
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1088
+ """
1089
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1090
+
1091
+ outputs = self.e5rope(
1092
+ input_ids,
1093
+ attention_mask=attention_mask,
1094
+ token_type_ids=token_type_ids,
1095
+ head_mask=head_mask,
1096
+ inputs_embeds=inputs_embeds,
1097
+ encoder_hidden_states=encoder_hidden_states,
1098
+ encoder_attention_mask=encoder_attention_mask,
1099
+ output_attentions=output_attentions,
1100
+ output_hidden_states=output_hidden_states,
1101
+ return_dict=return_dict,
1102
+ )
1103
+
1104
+ sequence_output = outputs[0]
1105
+ prediction_scores = self.cls(sequence_output)
1106
+
1107
+ masked_lm_loss = None
1108
+ if labels is not None:
1109
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1110
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1111
+
1112
+ if not return_dict:
1113
+ output = (prediction_scores,) + outputs[1:]
1114
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1115
+
1116
+ return MaskedLMOutput(
1117
+ loss=masked_lm_loss,
1118
+ logits=prediction_scores,
1119
+ hidden_states=outputs.hidden_states,
1120
+ attentions=outputs.attentions,
1121
+ )
1122
+
1123
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1124
+ input_shape = input_ids.shape
1125
+ effective_batch_size = input_shape[0]
1126
+
1127
+ # add a dummy token
1128
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1129
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1130
+ dummy_token = torch.full(
1131
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1132
+ )
1133
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1134
+
1135
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1136
+
1137
+
1138
+ @add_start_docstrings(
1139
+ """E5Rope Model with a `language modeling` head on top for CLM fine-tuning.""", E5ROPE_START_DOCSTRING
1140
+ )
1141
+ class E5RopeForCausalLM(E5RopePreTrainedModel):
1142
+ _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
1143
+
1144
+ def __init__(self, config):
1145
+ super().__init__(config)
1146
+
1147
+ if not config.is_decoder:
1148
+ logger.warning("If you want to use `E5RopeForCausalLM` as a standalone, add `is_decoder=True.`")
1149
+
1150
+ self.e5rope = E5RopeModel(config)
1151
+ self.cls = E5RopeOnlyMLMHead(config)
1152
+
1153
+ # Initialize weights and apply final processing
1154
+ self.post_init()
1155
+
1156
+ def get_output_embeddings(self):
1157
+ return self.cls.predictions.decoder
1158
+
1159
+ def set_output_embeddings(self, new_embeddings):
1160
+ self.cls.predictions.decoder = new_embeddings
1161
+
1162
+ @add_start_docstrings_to_model_forward(E5ROPE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1163
+
1164
+ def forward(
1165
+ self,
1166
+ input_ids: Optional[torch.LongTensor] = None,
1167
+ attention_mask: Optional[torch.FloatTensor] = None,
1168
+ token_type_ids: Optional[torch.LongTensor] = None,
1169
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1170
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1171
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1172
+ head_mask: Optional[torch.FloatTensor] = None,
1173
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1174
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1175
+ labels: Optional[torch.LongTensor] = None,
1176
+ use_cache: Optional[bool] = None,
1177
+ output_attentions: Optional[bool] = None,
1178
+ output_hidden_states: Optional[bool] = None,
1179
+ return_dict: Optional[bool] = None,
1180
+ ) -> Union[CausalLMOutputWithCrossAttentions, Tuple[torch.Tensor]]:
1181
+ r"""
1182
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1183
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1184
+ the model is configured as a decoder.
1185
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1186
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1187
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1188
+
1189
+ - 1 for tokens that are **not masked**,
1190
+ - 0 for tokens that are **masked**.
1191
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1192
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1193
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1194
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1195
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1196
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1197
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1198
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1199
+ ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`.
1200
+ use_cache (`bool`, *optional*):
1201
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1202
+ `past_key_values`).
1203
+
1204
+ Returns:
1205
+
1206
+ Example:
1207
+
1208
+ ```python
1209
+ >>> from transformers import AutoTokenizer, E5RopeForCausalLM, E5RopeConfig
1210
+ >>> import torch
1211
+
1212
+ >>> tokenizer = AutoTokenizer.from_pretrained("junnyu/e5rope_chinese_base")
1213
+ >>> config = E5RopeConfig.from_pretrained("junnyu/e5rope_chinese_base")
1214
+ >>> config.is_decoder = True
1215
+ >>> model = E5RopeForCausalLM.from_pretrained("junnyu/e5rope_chinese_base", config=config)
1216
+
1217
+ >>> inputs = tokenizer("今天天气非常好。", return_tensors="pt")
1218
+ >>> outputs = model(**inputs)
1219
+
1220
+ >>> prediction_logits = outputs.logits
1221
+ ```"""
1222
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1223
+
1224
+ outputs = self.e5rope(
1225
+ input_ids,
1226
+ attention_mask=attention_mask,
1227
+ token_type_ids=token_type_ids,
1228
+ head_mask=head_mask,
1229
+ inputs_embeds=inputs_embeds,
1230
+ encoder_hidden_states=encoder_hidden_states,
1231
+ encoder_attention_mask=encoder_attention_mask,
1232
+ past_key_values=past_key_values,
1233
+ use_cache=use_cache,
1234
+ output_attentions=output_attentions,
1235
+ output_hidden_states=output_hidden_states,
1236
+ return_dict=return_dict,
1237
+ )
1238
+
1239
+ sequence_output = outputs[0]
1240
+ prediction_scores = self.cls(sequence_output)
1241
+
1242
+ lm_loss = None
1243
+ if labels is not None:
1244
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1245
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1246
+ labels = labels[:, 1:].contiguous()
1247
+ loss_fct = CrossEntropyLoss()
1248
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1249
+
1250
+ if not return_dict:
1251
+ output = (prediction_scores,) + outputs[1:]
1252
+ return ((lm_loss,) + output) if lm_loss is not None else output
1253
+
1254
+ return CausalLMOutputWithCrossAttentions(
1255
+ loss=lm_loss,
1256
+ logits=prediction_scores,
1257
+ past_key_values=outputs.past_key_values,
1258
+ hidden_states=outputs.hidden_states,
1259
+ attentions=outputs.attentions,
1260
+ cross_attentions=outputs.cross_attentions,
1261
+ )
1262
+
1263
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
1264
+ input_shape = input_ids.shape
1265
+
1266
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1267
+ if attention_mask is None:
1268
+ attention_mask = input_ids.new_ones(input_shape)
1269
+
1270
+ # cut decoder_input_ids if past is used
1271
+ if past_key_values is not None:
1272
+ input_ids = input_ids[:, -1:]
1273
+
1274
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
1275
+
1276
+ def _reorder_cache(self, past_key_values, beam_idx):
1277
+ reordered_past = ()
1278
+ for layer_past in past_key_values:
1279
+ reordered_past += (
1280
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
1281
+ + layer_past[2:],
1282
+ )
1283
+ return reordered_past
1284
+
1285
+
1286
+ class E5RopeClassificationHead(nn.Module):
1287
+ """Head for sentence-level classification tasks."""
1288
+
1289
+ def __init__(self, config):
1290
+ super().__init__()
1291
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1292
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1293
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
1294
+
1295
+ self.config = config
1296
+
1297
+ def forward(self, features, **kwargs):
1298
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
1299
+ x = self.dropout(x)
1300
+ x = self.dense(x)
1301
+ x = ACT2FN[self.config.hidden_act](x)
1302
+ x = self.dropout(x)
1303
+ x = self.out_proj(x)
1304
+ return x
1305
+
1306
+