zifei9 committed
Commit 02a2c43 · verified · 1 Parent(s): 99ed146

Upload modeling_mistral.py


Changes the shape handling in `MistralAttention.forward` from
```python
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
```
to
```python
input_shape0, input_shape1 = hidden_states.shape[:-1]
hidden_shape = (input_shape0, input_shape1, -1, self.head_dim)
```
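
For context, here is a minimal sketch of what the two forms compute, assuming `hidden_states` is the usual 3-D `(batch, seq_len, hidden)` tensor (the concrete sizes below are illustrative, not taken from the model):

```python
import torch

# Illustrative sizes only; the real model uses config-driven dimensions.
batch, seq_len, hidden, head_dim = 2, 5, 32, 8
hidden_states = torch.randn(batch, seq_len, hidden)

# Original form: star-unpack all leading dims (rank-agnostic).
input_shape = hidden_states.shape[:-1]
hidden_shape_star = (*input_shape, -1, head_dim)

# New form: unpack batch and sequence explicitly, which assumes
# hidden_states is exactly 3-D.
input_shape0, input_shape1 = hidden_states.shape[:-1]
hidden_shape_explicit = (input_shape0, input_shape1, -1, head_dim)

assert hidden_shape_star == hidden_shape_explicit == (batch, seq_len, -1, head_dim)
```

Both forms yield the same `hidden_shape` for 3-D inputs, and the uploaded file applies the same explicit unpacking to the later `attn_output.reshape(...)` call. Fixed-arity unpacking like this is often preferred when tracing or exporting a model, since some backends handle it better than star-unpacking a `torch.Size`; the commit does not state its motivation, so treat that reading as an assumption.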

Files changed (1)
  1. modeling_mistral.py +1123 -0
modeling_mistral.py ADDED
@@ -0,0 +1,1123 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/mistral/modular_mistral.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_mistral.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ from typing import Callable, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ from torch import nn
11
+
12
+ from transformers.activations import ACT2FN
13
+ from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
14
+ from transformers.generation import GenerationMixin
15
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
16
+ from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
17
+ from transformers.modeling_outputs import (
18
+ BaseModelOutputWithPast,
19
+ CausalLMOutputWithPast,
20
+ QuestionAnsweringModelOutput,
21
+ SequenceClassifierOutputWithPast,
22
+ TokenClassifierOutput,
23
+ )
24
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
25
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
26
+ from transformers.processing_utils import Unpack
27
+ from transformers.utils import (
28
+ LossKwargs,
29
+ add_code_sample_docstrings,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ logging,
33
+ replace_return_docstrings,
34
+ )
35
+ from .configuration_mistral import MistralConfig
36
+ from transformers.models.mistral.modeling_mistral import MistralRMSNorm
37
+
38
+
39
+ logger = logging.get_logger(__name__)
40
+
41
+ _CHECKPOINT_FOR_DOC = "mistralai/Mistral-7B-v0.1"
42
+ _CONFIG_FOR_DOC = "MistralConfig"
43
+
44
+
45
+ class MistralMLP(nn.Module):
46
+ def __init__(self, config):
47
+ super().__init__()
48
+ self.config = config
49
+ self.hidden_size = config.hidden_size
50
+ self.intermediate_size = config.intermediate_size
51
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
52
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
53
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
54
+ self.act_fn = ACT2FN[config.hidden_act]
55
+
56
+ def forward(self, x):
57
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
58
+ return down_proj
59
+
60
+
61
+ def rotate_half(x):
62
+ """Rotates half the hidden dims of the input."""
63
+ x1 = x[..., : x.shape[-1] // 2]
64
+ x2 = x[..., x.shape[-1] // 2 :]
65
+ return torch.cat((-x2, x1), dim=-1)
66
+
67
+
68
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
69
+ """Applies Rotary Position Embedding to the query and key tensors.
70
+
71
+ Args:
72
+ q (`torch.Tensor`): The query tensor.
73
+ k (`torch.Tensor`): The key tensor.
74
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
75
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
76
+ position_ids (`torch.Tensor`, *optional*):
77
+ Deprecated and unused.
78
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
79
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
80
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
81
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
82
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
83
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
84
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
85
+ Returns:
86
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
87
+ """
88
+ cos = cos.unsqueeze(unsqueeze_dim)
89
+ sin = sin.unsqueeze(unsqueeze_dim)
90
+ q_embed = (q * cos) + (rotate_half(q) * sin)
91
+ k_embed = (k * cos) + (rotate_half(k) * sin)
92
+ return q_embed, k_embed
93
+
94
+
95
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
96
+ """
97
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
98
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
99
+ """
100
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
101
+ if n_rep == 1:
102
+ return hidden_states
103
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
104
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
105
+
106
+
107
+ def eager_attention_forward(
108
+ module: nn.Module,
109
+ query: torch.Tensor,
110
+ key: torch.Tensor,
111
+ value: torch.Tensor,
112
+ attention_mask: Optional[torch.Tensor],
113
+ scaling: float,
114
+ dropout: float = 0.0,
115
+ **kwargs,
116
+ ):
117
+ key_states = repeat_kv(key, module.num_key_value_groups)
118
+ value_states = repeat_kv(value, module.num_key_value_groups)
119
+
120
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
121
+ if attention_mask is not None:
122
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
123
+ attn_weights = attn_weights + causal_mask
124
+
125
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
126
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
127
+ attn_output = torch.matmul(attn_weights, value_states)
128
+ attn_output = attn_output.transpose(1, 2).contiguous()
129
+
130
+ return attn_output, attn_weights
131
+
132
+
133
+ class MistralAttention(nn.Module):
134
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
135
+
136
+ def __init__(self, config: MistralConfig, layer_idx: int):
137
+ super().__init__()
138
+ self.config = config
139
+ self.layer_idx = layer_idx
140
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
141
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
142
+ self.scaling = self.head_dim**-0.5
143
+ self.attention_dropout = config.attention_dropout
144
+ self.is_causal = True
145
+ self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
146
+ self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
147
+ self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False)
148
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
149
+
150
+ def forward(
151
+ self,
152
+ hidden_states: torch.Tensor,
153
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
154
+ attention_mask: Optional[torch.Tensor],
155
+ past_key_value: Optional[Cache] = None,
156
+ cache_position: Optional[torch.LongTensor] = None,
157
+ **kwargs: Unpack[FlashAttentionKwargs],
158
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
159
+ input_shape0, input_shape1 = hidden_states.shape[:-1]
160
+ hidden_shape = (input_shape0, input_shape1, -1, self.head_dim)
161
+
162
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
163
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
164
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
165
+
166
+ cos, sin = position_embeddings
167
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
168
+
169
+ if past_key_value is not None:
170
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
171
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
172
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
173
+
174
+ attention_interface: Callable = eager_attention_forward
175
+ if self.config._attn_implementation != "eager":
176
+ if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
177
+ logger.warning_once(
178
+ "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
179
+ 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
180
+ )
181
+ else:
182
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
183
+
184
+ attn_output, attn_weights = attention_interface(
185
+ self,
186
+ query_states,
187
+ key_states,
188
+ value_states,
189
+ attention_mask,
190
+ dropout=0.0 if not self.training else self.attention_dropout,
191
+ scaling=self.scaling,
192
+ sliding_window=getattr(self.config, "sliding_window", None), # main diff with Llama
193
+ **kwargs,
194
+ )
195
+
196
+ attn_output = attn_output.reshape(input_shape0, input_shape1, -1).contiguous()
197
+ attn_output = self.o_proj(attn_output)
198
+ return attn_output, attn_weights
199
+
200
+
201
+ class MistralDecoderLayer(nn.Module):
202
+ def __init__(self, config: MistralConfig, layer_idx: int):
203
+ super().__init__()
204
+ self.hidden_size = config.hidden_size
205
+ self.self_attn = MistralAttention(config=config, layer_idx=layer_idx)
206
+ self.mlp = MistralMLP(config)
207
+ self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
208
+ self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
209
+
210
+ def forward(
211
+ self,
212
+ hidden_states: torch.Tensor,
213
+ attention_mask: Optional[torch.Tensor] = None,
214
+ position_ids: Optional[torch.LongTensor] = None,
215
+ past_key_value: Optional[Cache] = None,
216
+ output_attentions: Optional[bool] = False,
217
+ use_cache: Optional[bool] = False,
218
+ cache_position: Optional[torch.LongTensor] = None,
219
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
220
+ **kwargs: Unpack[FlashAttentionKwargs],
221
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
222
+ residual = hidden_states
223
+
224
+ hidden_states = self.input_layernorm(hidden_states)
225
+
226
+ # Self Attention
227
+ hidden_states, self_attn_weights = self.self_attn(
228
+ hidden_states=hidden_states,
229
+ attention_mask=attention_mask,
230
+ position_ids=position_ids,
231
+ past_key_value=past_key_value,
232
+ output_attentions=output_attentions,
233
+ use_cache=use_cache,
234
+ cache_position=cache_position,
235
+ position_embeddings=position_embeddings,
236
+ **kwargs,
237
+ )
238
+ hidden_states = residual + hidden_states
239
+
240
+ # Fully Connected
241
+ residual = hidden_states
242
+ hidden_states = self.post_attention_layernorm(hidden_states)
243
+ hidden_states = self.mlp(hidden_states)
244
+ hidden_states = residual + hidden_states
245
+
246
+ outputs = (hidden_states,)
247
+ if output_attentions:
248
+ outputs += (self_attn_weights,)
249
+
250
+ return outputs
251
+
252
+
253
+ class MistralRotaryEmbedding(nn.Module):
254
+ def __init__(self, config: MistralConfig, device=None):
255
+ super().__init__()
256
+ # BC: "rope_type" was originally "type"
257
+ if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
258
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
259
+ else:
260
+ self.rope_type = "default"
261
+ self.max_seq_len_cached = config.max_position_embeddings
262
+ self.original_max_seq_len = config.max_position_embeddings
263
+
264
+ self.config = config
265
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
266
+
267
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
268
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
269
+ self.original_inv_freq = self.inv_freq
270
+
271
+ def _dynamic_frequency_update(self, position_ids, device):
272
+ """
273
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
274
+ 1 - growing beyond the cached sequence length (allow scaling)
275
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
276
+ """
277
+ seq_len = torch.max(position_ids) + 1
278
+ if seq_len > self.max_seq_len_cached: # growth
279
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
280
+ self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
281
+ self.max_seq_len_cached = seq_len
282
+
283
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
284
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
285
+ self.max_seq_len_cached = self.original_max_seq_len
286
+
287
+ @torch.no_grad()
288
+ def forward(self, x, position_ids):
289
+ if "dynamic" in self.rope_type:
290
+ self._dynamic_frequency_update(position_ids, device=x.device)
291
+
292
+ # Core RoPE block
293
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
294
+ position_ids_expanded = position_ids[:, None, :].float()
295
+ # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
296
+ device_type = x.device.type
297
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
298
+ with torch.autocast(device_type=device_type, enabled=False):
299
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
300
+ emb = torch.cat((freqs, freqs), dim=-1)
301
+ cos = emb.cos()
302
+ sin = emb.sin()
303
+
304
+ # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
305
+ cos = cos * self.attention_scaling
306
+ sin = sin * self.attention_scaling
307
+
308
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
309
+
310
+
311
+ MISTRAL_START_DOCSTRING = r"""
312
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
313
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
314
+ etc.)
315
+
316
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
317
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
318
+ and behavior.
319
+
320
+ Parameters:
321
+ config ([`MistralConfig`]):
322
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
323
+ load the weights associated with the model, only the configuration. Check out the
324
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
325
+ """
326
+
327
+
328
+ @add_start_docstrings(
329
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
330
+ MISTRAL_START_DOCSTRING,
331
+ )
332
+ class MistralPreTrainedModel(PreTrainedModel):
333
+ config_class = MistralConfig
334
+ base_model_prefix = "model"
335
+ supports_gradient_checkpointing = True
336
+ _no_split_modules = ["MistralDecoderLayer"]
337
+ _skip_keys_device_placement = ["past_key_values"]
338
+ _supports_flash_attn_2 = True
339
+ _supports_sdpa = True
340
+ _supports_flex_attn = True
341
+ _supports_cache_class = True
342
+ _supports_quantized_cache = True
343
+ _supports_static_cache = True
344
+
345
+ def _init_weights(self, module):
346
+ std = self.config.initializer_range
347
+ if isinstance(module, nn.Linear):
348
+ module.weight.data.normal_(mean=0.0, std=std)
349
+ if module.bias is not None:
350
+ module.bias.data.zero_()
351
+ elif isinstance(module, nn.Embedding):
352
+ module.weight.data.normal_(mean=0.0, std=std)
353
+ if module.padding_idx is not None:
354
+ module.weight.data[module.padding_idx].zero_()
355
+
356
+
357
+ MISTRAL_INPUTS_DOCSTRING = r"""
358
+ Args:
359
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
360
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
361
+ it.
362
+
363
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
364
+ [`PreTrainedTokenizer.__call__`] for details.
365
+
366
+ [What are input IDs?](../glossary#input-ids)
367
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
368
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
369
+
370
+ - 1 for tokens that are **not masked**,
371
+ - 0 for tokens that are **masked**.
372
+
373
+ [What are attention masks?](../glossary#attention-mask)
374
+
375
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
376
+ [`PreTrainedTokenizer.__call__`] for details.
377
+
378
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
379
+ `past_key_values`).
380
+
381
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
382
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
383
+ information on the default strategy.
384
+
385
+ - 1 indicates the head is **not masked**,
386
+ - 0 indicates the head is **masked**.
387
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
388
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
389
+ config.n_positions - 1]`.
390
+
391
+ [What are position IDs?](../glossary#position-ids)
392
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
393
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
394
+ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
395
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
396
+
397
+ Two formats are allowed:
398
+ - a [`~cache_utils.Cache`] instance, see our
399
+ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
400
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
401
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
402
+ cache format.
403
+
404
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
405
+ legacy cache format will be returned.
406
+
407
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
408
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
409
+ of shape `(batch_size, sequence_length)`.
410
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
411
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
412
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
413
+ model's internal embedding lookup matrix.
414
+ use_cache (`bool`, *optional*):
415
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
416
+ `past_key_values`).
417
+ output_attentions (`bool`, *optional*):
418
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
419
+ tensors for more detail.
420
+ output_hidden_states (`bool`, *optional*):
421
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
422
+ more detail.
423
+ return_dict (`bool`, *optional*):
424
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
425
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
426
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
427
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
428
+ the complete sequence length.
429
+ """
430
+
431
+
432
+ @add_start_docstrings(
433
+ "The bare Mistral Model outputting raw hidden-states without any specific head on top.",
434
+ MISTRAL_START_DOCSTRING,
435
+ )
436
+ class MistralModel(MistralPreTrainedModel):
437
+ """
438
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MistralDecoderLayer`]
439
+
440
+ Args:
441
+ config: MistralConfig
442
+ """
443
+
444
+ def __init__(self, config: MistralConfig):
445
+ super().__init__(config)
446
+ self.padding_idx = config.pad_token_id
447
+ self.vocab_size = config.vocab_size
448
+
449
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
450
+ self.layers = nn.ModuleList(
451
+ [MistralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
452
+ )
453
+ self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
454
+ self.rotary_emb = MistralRotaryEmbedding(config=config)
455
+ self.gradient_checkpointing = False
456
+
457
+ # Initialize weights and apply final processing
458
+ self.post_init()
459
+
460
+ def get_input_embeddings(self):
461
+ return self.embed_tokens
462
+
463
+ def set_input_embeddings(self, value):
464
+ self.embed_tokens = value
465
+
466
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
467
+ def forward(
468
+ self,
469
+ input_ids: torch.LongTensor = None,
470
+ attention_mask: Optional[torch.Tensor] = None,
471
+ position_ids: Optional[torch.LongTensor] = None,
472
+ past_key_values: Optional[Cache] = None,
473
+ inputs_embeds: Optional[torch.FloatTensor] = None,
474
+ use_cache: Optional[bool] = None,
475
+ output_attentions: Optional[bool] = None,
476
+ output_hidden_states: Optional[bool] = None,
477
+ return_dict: Optional[bool] = None,
478
+ cache_position: Optional[torch.LongTensor] = None,
479
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
480
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
481
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
482
+ output_hidden_states = (
483
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
484
+ )
485
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
486
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
487
+
488
+ if (input_ids is None) ^ (inputs_embeds is not None):
489
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
490
+
491
+ if self.gradient_checkpointing and self.training and use_cache:
492
+ logger.warning_once(
493
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
494
+ )
495
+ use_cache = False
496
+
497
+ if inputs_embeds is None:
498
+ inputs_embeds = self.embed_tokens(input_ids)
499
+
500
+ if use_cache and past_key_values is None:
501
+ past_key_values = DynamicCache()
502
+
503
+ if cache_position is None:
504
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
505
+ cache_position = torch.arange(
506
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
507
+ )
508
+
509
+ if position_ids is None:
510
+ position_ids = cache_position.unsqueeze(0)
511
+
512
+ causal_mask = self._update_causal_mask(
513
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
514
+ )
515
+
516
+ hidden_states = inputs_embeds
517
+
518
+ # create position embeddings to be shared across the decoder layers
519
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
520
+
521
+ # decoder layers
522
+ all_hidden_states = () if output_hidden_states else None
523
+ all_self_attns = () if output_attentions else None
524
+
525
+ for decoder_layer in self.layers[: self.config.num_hidden_layers]:
526
+ if output_hidden_states:
527
+ all_hidden_states += (hidden_states,)
528
+
529
+ if self.gradient_checkpointing and self.training:
530
+ layer_outputs = self._gradient_checkpointing_func(
531
+ decoder_layer.__call__,
532
+ hidden_states,
533
+ causal_mask,
534
+ position_ids,
535
+ past_key_values,
536
+ output_attentions,
537
+ use_cache,
538
+ cache_position,
539
+ position_embeddings,
540
+ )
541
+ else:
542
+ layer_outputs = decoder_layer(
543
+ hidden_states,
544
+ attention_mask=causal_mask,
545
+ position_ids=position_ids,
546
+ past_key_value=past_key_values,
547
+ output_attentions=output_attentions,
548
+ use_cache=use_cache,
549
+ cache_position=cache_position,
550
+ position_embeddings=position_embeddings,
551
+ **flash_attn_kwargs,
552
+ )
553
+
554
+ hidden_states = layer_outputs[0]
555
+
556
+ if output_attentions:
557
+ all_self_attns += (layer_outputs[1],)
558
+
559
+ hidden_states = self.norm(hidden_states)
560
+
561
+ # add hidden states from the last decoder layer
562
+ if output_hidden_states:
563
+ all_hidden_states += (hidden_states,)
564
+
565
+ output = BaseModelOutputWithPast(
566
+ last_hidden_state=hidden_states,
567
+ past_key_values=past_key_values if use_cache else None,
568
+ hidden_states=all_hidden_states,
569
+ attentions=all_self_attns,
570
+ )
571
+ return output if return_dict else output.to_tuple()
572
+
573
+ def _update_causal_mask(
574
+ self,
575
+ attention_mask: torch.Tensor,
576
+ input_tensor: torch.Tensor,
577
+ cache_position: torch.Tensor,
578
+ past_key_values: Cache,
579
+ output_attentions: bool,
580
+ ):
581
+ if self.config._attn_implementation == "flash_attention_2":
582
+ if attention_mask is not None and past_key_values is not None:
583
+ is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
584
+ if is_padding_right:
585
+ raise ValueError(
586
+ "You are attempting to perform batched generation with padding_side='right'"
587
+ " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
588
+ " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
589
+ )
590
+ if attention_mask is not None and 0.0 in attention_mask:
591
+ return attention_mask
592
+ return None
593
+
594
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
595
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
596
+ # to infer the attention mask.
597
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
598
+ using_static_cache = isinstance(past_key_values, StaticCache)
599
+ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
600
+
601
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
602
+ if (
603
+ self.config._attn_implementation == "sdpa"
604
+ and not (using_static_cache or using_sliding_window_cache)
605
+ and not output_attentions
606
+ ):
607
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
608
+ attention_mask,
609
+ inputs_embeds=input_tensor,
610
+ past_key_values_length=past_seen_tokens,
611
+ sliding_window=self.config.sliding_window,
612
+ is_training=self.training,
613
+ ):
614
+ return None
615
+
616
+ dtype, device = input_tensor.dtype, input_tensor.device
617
+ min_dtype = torch.finfo(dtype).min
618
+ sequence_length = input_tensor.shape[1]
619
+ # SlidingWindowCache or StaticCache
620
+ if using_sliding_window_cache or using_static_cache:
621
+ target_length = past_key_values.get_max_cache_shape()
622
+ # DynamicCache or no cache
623
+ else:
624
+ target_length = (
625
+ attention_mask.shape[-1]
626
+ if isinstance(attention_mask, torch.Tensor)
627
+ else past_seen_tokens + sequence_length + 1
628
+ )
629
+
630
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
631
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
632
+ attention_mask,
633
+ sequence_length=sequence_length,
634
+ target_length=target_length,
635
+ dtype=dtype,
636
+ device=device,
637
+ cache_position=cache_position,
638
+ batch_size=input_tensor.shape[0],
639
+ config=self.config,
640
+ past_key_values=past_key_values,
641
+ )
642
+
643
+ if (
644
+ self.config._attn_implementation == "sdpa"
645
+ and attention_mask is not None
646
+ and attention_mask.device.type == "cuda"
647
+ and not output_attentions
648
+ ):
649
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
650
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
651
+ # Details: https://github.com/pytorch/pytorch/issues/110213
652
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
653
+
654
+ return causal_mask
655
+
656
+ @staticmethod
657
+ def _prepare_4d_causal_attention_mask_with_cache_position(
658
+ attention_mask: torch.Tensor,
659
+ sequence_length: int,
660
+ target_length: int,
661
+ dtype: torch.dtype,
662
+ device: torch.device,
663
+ cache_position: torch.Tensor,
664
+ batch_size: int,
665
+ config: MistralConfig,
666
+ past_key_values: Cache,
667
+ ):
668
+ """
669
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
670
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
671
+
672
+ Args:
673
+ attention_mask (`torch.Tensor`):
674
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
675
+ sequence_length (`int`):
676
+ The sequence length being processed.
677
+ target_length (`int`):
678
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
679
+ dtype (`torch.dtype`):
680
+ The dtype to use for the 4D attention mask.
681
+ device (`torch.device`):
682
+ The device to place the 4D attention mask on.
683
+ cache_position (`torch.Tensor`):
684
+ Indices depicting the position of the input sequence tokens in the sequence.
685
+ batch_size (`torch.Tensor`):
686
+ Batch size.
687
+ config (`MistralConfig`):
688
+ The model's configuration class
689
+ past_key_values (`Cache`):
690
+ The cache class that is being used currently to generate
691
+ """
692
+ if attention_mask is not None and attention_mask.dim() == 4:
693
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
694
+ causal_mask = attention_mask
695
+ else:
696
+ min_dtype = torch.finfo(dtype).min
697
+ causal_mask = torch.full(
698
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
699
+ )
700
+ diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
701
+ if config.sliding_window is not None:
702
+ # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
703
+ # the check is needed to verify if the current checkpoint was trained with sliding window or not
704
+ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
705
+ sliding_attend_mask = torch.arange(target_length, device=device) <= (
706
+ cache_position.reshape(-1, 1) - config.sliding_window
707
+ )
708
+ diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
709
+ causal_mask *= diagonal_attend_mask
710
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
711
+ if attention_mask is not None:
712
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
713
+ if attention_mask.shape[-1] > target_length:
714
+ attention_mask = attention_mask[:, :target_length]
715
+ mask_length = attention_mask.shape[-1]
716
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
717
+ padding_mask = padding_mask == 0
718
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
719
+ padding_mask, min_dtype
720
+ )
721
+ return causal_mask
722
+
723
+
724
+ class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
725
+
726
+
727
+ class MistralForCausalLM(MistralPreTrainedModel, GenerationMixin):
728
+ _tied_weights_keys = ["lm_head.weight"]
729
+ _tp_plan = {"lm_head": "colwise_rep"}
730
+
731
+ def __init__(self, config):
732
+ super().__init__(config)
733
+ self.model = MistralModel(config)
734
+ self.vocab_size = config.vocab_size
735
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
736
+
737
+ # Initialize weights and apply final processing
738
+ self.post_init()
739
+
740
+ def get_input_embeddings(self):
741
+ return self.model.embed_tokens
742
+
743
+ def set_input_embeddings(self, value):
744
+ self.model.embed_tokens = value
745
+
746
+ def get_output_embeddings(self):
747
+ return self.lm_head
748
+
749
+ def set_output_embeddings(self, new_embeddings):
750
+ self.lm_head = new_embeddings
751
+
752
+ def set_decoder(self, decoder):
753
+ self.model = decoder
754
+
755
+ def get_decoder(self):
756
+ return self.model
757
+
758
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
759
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
760
+ def forward(
761
+ self,
762
+ input_ids: torch.LongTensor = None,
763
+ attention_mask: Optional[torch.Tensor] = None,
764
+ position_ids: Optional[torch.LongTensor] = None,
765
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
766
+ inputs_embeds: Optional[torch.FloatTensor] = None,
767
+ labels: Optional[torch.LongTensor] = None,
768
+ use_cache: Optional[bool] = None,
769
+ output_attentions: Optional[bool] = None,
770
+ output_hidden_states: Optional[bool] = None,
771
+ return_dict: Optional[bool] = None,
772
+ cache_position: Optional[torch.LongTensor] = None,
773
+ num_logits_to_keep: int = 0,
774
+ **kwargs: Unpack[KwargsForCausalLM],
775
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
776
+ r"""
777
+ Args:
778
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
779
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
780
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
781
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
782
+
783
+ num_logits_to_keep (`int`, *optional*):
784
+ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
785
+ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
786
+ token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
787
+
788
+ Returns:
789
+
790
+ Example:
791
+
792
+ ```python
793
+ >>> from transformers import AutoTokenizer, MistralForCausalLM
794
+
795
+ >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
796
+ >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
797
+
798
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
799
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
800
+
801
+ >>> # Generate
802
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
803
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
804
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
805
+ ```"""
806
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
807
+ output_hidden_states = (
808
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
809
+ )
810
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
811
+
812
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
813
+ outputs = self.model(
814
+ input_ids=input_ids,
815
+ attention_mask=attention_mask,
816
+ position_ids=position_ids,
817
+ past_key_values=past_key_values,
818
+ inputs_embeds=inputs_embeds,
819
+ use_cache=use_cache,
820
+ output_attentions=output_attentions,
821
+ output_hidden_states=output_hidden_states,
822
+ return_dict=return_dict,
823
+ cache_position=cache_position,
824
+ **kwargs,
825
+ )
826
+
827
+ hidden_states = outputs[0]
828
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
829
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
830
+
831
+ loss = None
832
+ if labels is not None:
833
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
834
+
835
+ if not return_dict:
836
+ output = (logits,) + outputs[1:]
837
+ return (loss,) + output if loss is not None else output
838
+
839
+ return CausalLMOutputWithPast(
840
+ loss=loss,
841
+ logits=logits,
842
+ past_key_values=outputs.past_key_values,
843
+ hidden_states=outputs.hidden_states,
844
+ attentions=outputs.attentions,
845
+ )
846
+
847
+
848
+ @add_start_docstrings(
849
+ """
850
+ The Mistral Model transformer with a token classification head on top (a linear layer on top of the hidden-states
851
+ output) e.g. for Named-Entity-Recognition (NER) tasks.
852
+ """,
853
+ MISTRAL_START_DOCSTRING,
854
+ )
855
+ class MistralForTokenClassification(MistralPreTrainedModel):
856
+ def __init__(self, config):
857
+ super().__init__(config)
858
+ self.num_labels = config.num_labels
859
+ self.model = MistralModel(config)
860
+ if getattr(config, "classifier_dropout", None) is not None:
861
+ classifier_dropout = config.classifier_dropout
862
+ elif getattr(config, "hidden_dropout", None) is not None:
863
+ classifier_dropout = config.hidden_dropout
864
+ else:
865
+ classifier_dropout = 0.1
866
+ self.dropout = nn.Dropout(classifier_dropout)
867
+ self.score = nn.Linear(config.hidden_size, config.num_labels)
868
+
869
+ # Initialize weights and apply final processing
870
+ self.post_init()
871
+
872
+ def get_input_embeddings(self):
873
+ return self.model.embed_tokens
874
+
875
+ def set_input_embeddings(self, value):
876
+ self.model.embed_tokens = value
877
+
878
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
879
+ @add_code_sample_docstrings(
880
+ checkpoint=_CHECKPOINT_FOR_DOC,
881
+ output_type=TokenClassifierOutput,
882
+ config_class=_CONFIG_FOR_DOC,
883
+ )
884
+ def forward(
885
+ self,
886
+ input_ids: Optional[torch.LongTensor] = None,
887
+ attention_mask: Optional[torch.Tensor] = None,
888
+ position_ids: Optional[torch.LongTensor] = None,
889
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
890
+ inputs_embeds: Optional[torch.FloatTensor] = None,
891
+ labels: Optional[torch.LongTensor] = None,
892
+ use_cache: Optional[bool] = None,
893
+ output_attentions: Optional[bool] = None,
894
+ output_hidden_states: Optional[bool] = None,
895
+ return_dict: Optional[bool] = None,
896
+ ) -> Union[Tuple, TokenClassifierOutput]:
897
+ r"""
898
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
899
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
900
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
901
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
902
+ """
903
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
904
+
905
+ outputs = self.model(
906
+ input_ids,
907
+ attention_mask=attention_mask,
908
+ position_ids=position_ids,
909
+ past_key_values=past_key_values,
910
+ inputs_embeds=inputs_embeds,
911
+ use_cache=use_cache,
912
+ output_attentions=output_attentions,
913
+ output_hidden_states=output_hidden_states,
914
+ return_dict=return_dict,
915
+ )
916
+ sequence_output = outputs[0]
917
+ sequence_output = self.dropout(sequence_output)
918
+ logits = self.score(sequence_output)
919
+
920
+ loss = None
921
+ if labels is not None:
922
+ loss = self.loss_function(logits, labels, self.config)
923
+
924
+ if not return_dict:
925
+ output = (logits,) + outputs[2:]
926
+ return ((loss,) + output) if loss is not None else output
927
+
928
+ return TokenClassifierOutput(
929
+ loss=loss,
930
+ logits=logits,
931
+ hidden_states=outputs.hidden_states,
932
+ attentions=outputs.attentions,
933
+ )
934
+
935
+
936
+ @add_start_docstrings(
937
+ """
938
+ The Mistral Model transformer with a sequence classification head on top (linear layer).
939
+
940
+ [`MistralForSequenceClassification`] uses the last token in order to do the classification, as other causal models
941
+ (e.g. GPT-2) do.
942
+
943
+ Since it does classification on the last token, it requires to know the position of the last token. If a
944
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
945
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
946
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
947
+ each row of the batch).
948
+ """,
949
+ MISTRAL_START_DOCSTRING,
950
+ )
951
+ class MistralForSequenceClassification(MistralPreTrainedModel):
952
+ def __init__(self, config):
953
+ super().__init__(config)
954
+ self.num_labels = config.num_labels
955
+ self.model = MistralModel(config)
956
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
957
+
958
+ # Initialize weights and apply final processing
959
+ self.post_init()
960
+
961
+ def get_input_embeddings(self):
962
+ return self.model.embed_tokens
963
+
964
+ def set_input_embeddings(self, value):
965
+ self.model.embed_tokens = value
966
+
967
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
968
+ def forward(
969
+ self,
970
+ input_ids: Optional[torch.LongTensor] = None,
971
+ attention_mask: Optional[torch.Tensor] = None,
972
+ position_ids: Optional[torch.LongTensor] = None,
973
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
974
+ inputs_embeds: Optional[torch.FloatTensor] = None,
975
+ labels: Optional[torch.LongTensor] = None,
976
+ use_cache: Optional[bool] = None,
977
+ output_attentions: Optional[bool] = None,
978
+ output_hidden_states: Optional[bool] = None,
979
+ return_dict: Optional[bool] = None,
980
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
981
+ r"""
982
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
983
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
984
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
985
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
986
+ """
987
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
988
+
989
+ transformer_outputs = self.model(
990
+ input_ids,
991
+ attention_mask=attention_mask,
992
+ position_ids=position_ids,
993
+ past_key_values=past_key_values,
994
+ inputs_embeds=inputs_embeds,
995
+ use_cache=use_cache,
996
+ output_attentions=output_attentions,
997
+ output_hidden_states=output_hidden_states,
998
+ return_dict=return_dict,
999
+ )
1000
+ hidden_states = transformer_outputs[0]
1001
+ logits = self.score(hidden_states)
1002
+
1003
+ if input_ids is not None:
1004
+ batch_size = input_ids.shape[0]
1005
+ else:
1006
+ batch_size = inputs_embeds.shape[0]
1007
+
1008
+ if self.config.pad_token_id is None and batch_size != 1:
1009
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1010
+ if self.config.pad_token_id is None:
1011
+ sequence_lengths = -1
1012
+ else:
1013
+ if input_ids is not None:
1014
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1015
+ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
1016
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1017
+ sequence_lengths = sequence_lengths.to(logits.device)
1018
+ else:
1019
+ sequence_lengths = -1
1020
+
1021
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1022
+
1023
+ loss = None
1024
+ if labels is not None:
1025
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
1026
+
1027
+ if not return_dict:
1028
+ output = (pooled_logits,) + transformer_outputs[1:]
1029
+ return ((loss,) + output) if loss is not None else output
1030
+
1031
+ return SequenceClassifierOutputWithPast(
1032
+ loss=loss,
1033
+ logits=pooled_logits,
1034
+ past_key_values=transformer_outputs.past_key_values,
1035
+ hidden_states=transformer_outputs.hidden_states,
1036
+ attentions=transformer_outputs.attentions,
1037
+ )
1038
+
1039
+
1040
+ @add_start_docstrings(
1041
+ """
1042
+ The Mistral Model transformer with a span classification head on top for extractive question-answering tasks like
1043
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1044
+ """,
1045
+ MISTRAL_START_DOCSTRING,
1046
+ )
1047
+ class MistralForQuestionAnswering(MistralPreTrainedModel):
1048
+ base_model_prefix = "model"
1049
+
1050
+ def __init__(self, config):
1051
+ super().__init__(config)
1052
+ self.qa_outputs = nn.Linear(config.hidden_size, 2)
1053
+ self.model = MistralModel(config) # diff with Llama: transformer->model
1054
+
1055
+ # Initialize weights and apply final processing
1056
+ self.post_init()
1057
+
1058
+ def get_input_embeddings(self):
1059
+ return self.model.embed_tokens
1060
+
1061
+ def set_input_embeddings(self, value):
1062
+ self.model.embed_tokens = value
1063
+
1064
+ @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING)
1065
+ def forward(
1066
+ self,
1067
+ input_ids: Optional[torch.LongTensor] = None,
1068
+ attention_mask: Optional[torch.FloatTensor] = None,
1069
+ position_ids: Optional[torch.LongTensor] = None,
1070
+ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
1071
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1072
+ start_positions: Optional[torch.LongTensor] = None,
1073
+ end_positions: Optional[torch.LongTensor] = None,
1074
+ output_attentions: Optional[bool] = None,
1075
+ output_hidden_states: Optional[bool] = None,
1076
+ return_dict: Optional[bool] = None,
1077
+ **kwargs,
1078
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1079
+ r"""
1080
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1081
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1082
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1083
+ are not taken into account for computing the loss.
1084
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1085
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1086
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1087
+ are not taken into account for computing the loss.
1088
+ """
1089
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1090
+
1091
+ outputs = self.model(
1092
+ input_ids,
1093
+ attention_mask=attention_mask,
1094
+ position_ids=position_ids,
1095
+ past_key_values=past_key_values,
1096
+ inputs_embeds=inputs_embeds,
1097
+ output_attentions=output_attentions,
1098
+ output_hidden_states=output_hidden_states,
1099
+ return_dict=return_dict,
1100
+ )
1101
+
1102
+ sequence_output = outputs[0]
1103
+
1104
+ logits = self.qa_outputs(sequence_output)
1105
+ start_logits, end_logits = logits.split(1, dim=-1)
1106
+ start_logits = start_logits.squeeze(-1).contiguous()
1107
+ end_logits = end_logits.squeeze(-1).contiguous()
1108
+
1109
+ loss = None
1110
+ if start_positions is not None and end_positions is not None:
1111
+ loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
1112
+
1113
+ if not return_dict:
1114
+ output = (start_logits, end_logits) + outputs[2:]
1115
+ return ((loss,) + output) if loss is not None else output
1116
+
1117
+ return QuestionAnsweringModelOutput(
1118
+ loss=loss,
1119
+ start_logits=start_logits,
1120
+ end_logits=end_logits,
1121
+ hidden_states=outputs.hidden_states,
1122
+ attentions=outputs.attentions,
1123
+ )
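
Since the commit ships a standalone `modeling_mistral.py` in the repository, a brief sketch of how a Hub repo with custom modeling code is typically loaded may be useful. This assumes the repo's `config.json` maps the auto classes to this file; the repo id below is a hypothetical placeholder, not this repository's actual name:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical placeholder repo id; substitute the actual repository.
repo_id = "your-username/your-mistral-repo"

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True lets transformers import the repo's own
# modeling_mistral.py instead of the library's built-in Mistral classes.
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello, world", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```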