GrennKren commited on
Commit
407ac25
1 Parent(s): 1ae9363

Upload modeling_exaone.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_exaone.py +1394 -0
modeling_exaone.py ADDED
@@ -0,0 +1,1394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The LG AI Research EXAONE Lab.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
6
+ # and OPT implementations in this library. It has been modified from its
7
+ # original forms to accommodate minor architectural differences compared
8
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ """LG AI Research EXAONE Lab"""
22
+
23
+ import math
24
+ from typing import Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.utils.checkpoint
28
+ from packaging import version
29
+ from torch import nn
30
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.cache_utils import Cache, DynamicCache, StaticCache
34
+ from transformers.generation import GenerationMixin
35
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
36
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
37
+ from transformers.modeling_outputs import (
38
+ BaseModelOutputWithPast,
39
+ BaseModelOutputWithPastAndCrossAttentions,
40
+ CausalLMOutputWithPast,
41
+ QuestionAnsweringModelOutput,
42
+ SequenceClassifierOutputWithPast,
43
+ )
44
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
45
+ from transformers.modeling_utils import PreTrainedModel
46
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
47
+ from transformers.utils import (
48
+ add_code_sample_docstrings,
49
+ add_start_docstrings,
50
+ add_start_docstrings_to_model_forward,
51
+ is_flash_attn_2_available,
52
+ logging,
53
+ )
54
+ from .configuration_exaone import ExaoneConfig
55
+
56
+
57
+ if is_flash_attn_2_available():
58
+ try:
59
+ import flash_attn
60
+
61
+ if version.parse(flash_attn.__version__) > version.parse("2.4.2"):
62
+ from flash_attn.ops.triton.layer_norm import rms_norm_fn
63
+ else:
64
+ from flash_attn.ops.triton.layernorm import rms_norm_fn
65
+ except ImportError:
66
+ pass
67
+
68
+
69
+ logger = logging.get_logger(__name__)
70
+
71
+ _CHECKPOINT_FOR_DOC = "exaone"
72
+ _CONFIG_FOR_DOC = "ExaoneConfig"
73
+
74
+ EXAONE_PRETRAINED_MODEL_ARCHIVE_LIST = [
75
+ "exaone",
76
+ ]
77
+
78
+
79
+ @torch.jit.script
80
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
81
+ """
82
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
83
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
84
+ """
85
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
86
+ if n_rep == 1:
87
+ return hidden_states
88
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
89
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
90
+
91
+
92
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
93
+ """Applies Rotary Position Embedding to the query and key tensors.
94
+
95
+ Args:
96
+ q (`torch.Tensor`): The query tensor.
97
+ k (`torch.Tensor`): The key tensor.
98
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
99
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
100
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
101
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
102
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
103
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
104
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
105
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
106
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
107
+ Returns:
108
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
109
+ """
110
+ cos = cos.unsqueeze(unsqueeze_dim)
111
+ sin = sin.unsqueeze(unsqueeze_dim)
112
+ q_embed = (q * cos) + (rotate_half(q) * sin)
113
+ k_embed = (k * cos) + (rotate_half(k) * sin)
114
+ return q_embed, k_embed
115
+
116
+
117
+ def rotate_half(x):
118
+ """Rotates half the hidden dims of the input."""
119
+ x1 = x[..., : x.shape[-1] // 2]
120
+ x2 = x[..., x.shape[-1] // 2 :]
121
+ return torch.cat((-x2, x1), dim=-1)
122
+
123
+
124
+ def _prepare_4d_causal_attention_mask_with_cache_position(
125
+ attention_mask: torch.Tensor,
126
+ sequence_length: int,
127
+ target_length: int,
128
+ dtype: torch.dtype,
129
+ device: torch.device,
130
+ min_dtype: float,
131
+ cache_position: torch.Tensor,
132
+ batch_size: int,
133
+ ):
134
+ """
135
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
136
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
137
+
138
+ Args:
139
+ attention_mask (`torch.Tensor`):
140
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
141
+ sequence_length (`int`):
142
+ The sequence length being processed.
143
+ target_length (`int`):
144
+ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
145
+ dtype (`torch.dtype`):
146
+ The dtype to use for the 4D attention mask.
147
+ device (`torch.device`):
148
+ The device to plcae the 4D attention mask on.
149
+ min_dtype (`float`):
150
+ The minimum value representable with the dtype `dtype`.
151
+ cache_position (`torch.Tensor`):
152
+ Indices depicting the position of the input sequence tokens in the sequence.
153
+ batch_size (`torch.Tensor`):
154
+ Batch size.
155
+ """
156
+ if attention_mask is not None and attention_mask.dim() == 4:
157
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
158
+ causal_mask = attention_mask
159
+ else:
160
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
161
+ if sequence_length != 1:
162
+ causal_mask = torch.triu(causal_mask, diagonal=1)
163
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
164
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
165
+ if attention_mask is not None:
166
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
167
+ mask_length = attention_mask.shape[-1]
168
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
169
+ padding_mask = padding_mask == 0
170
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
171
+ padding_mask, min_dtype
172
+ )
173
+
174
+ return causal_mask
175
+
176
+
177
+ class ExaoneRMSNorm(torch.nn.Module):
178
+ def __init__(self, hidden_size, eps=1e-6):
179
+ super().__init__()
180
+ self.eps = eps
181
+ self.weight = torch.nn.Parameter(torch.ones(hidden_size))
182
+
183
+ def forward(self, hidden_states):
184
+ input_dtype = hidden_states.dtype
185
+ hidden_states = hidden_states.to(torch.float32)
186
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
187
+ hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
188
+ return self.weight * hidden_states.to(input_dtype)
189
+
190
+
191
+ class ExaoneTritonRMSNorm(torch.nn.Module):
192
+ def __init__(
193
+ self,
194
+ hidden_size: int = 0,
195
+ eps: float = 1e-5,
196
+ ):
197
+ super().__init__()
198
+ self.eps = eps
199
+ self.drop = None
200
+ self.weight = torch.nn.Parameter(torch.empty(hidden_size))
201
+ self.register_parameter("bias", None)
202
+ self.reset_parameters()
203
+
204
+ def reset_parameters(self):
205
+ torch.nn.init.ones_(self.weight)
206
+
207
+ def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
208
+ return rms_norm_fn(
209
+ x,
210
+ self.weight,
211
+ self.bias,
212
+ residual=residual,
213
+ eps=self.eps,
214
+ dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
215
+ prenorm=prenorm,
216
+ residual_in_fp32=residual_in_fp32,
217
+ )
218
+
219
+
220
+ ALL_LAYERNORM_LAYERS.append(ExaoneRMSNorm)
221
+ ALL_LAYERNORM_LAYERS.append(ExaoneTritonRMSNorm)
222
+
223
+
224
+ class ExaoneRotaryEmbedding(nn.Module):
225
+ def __init__(self, config: ExaoneConfig, device=None):
226
+ super().__init__()
227
+ if config.rope_scaling is not None:
228
+ self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
229
+ else:
230
+ self.rope_type = "default"
231
+ self.rope_theta = config.rope_theta
232
+ self.max_seq_len = config.max_position_embeddings
233
+ self.original_max_seq_len = config.max_position_embeddings
234
+
235
+ self.config = config
236
+ if self.rope_type not in ROPE_INIT_FUNCTIONS:
237
+ raise KeyError(f"The EXAONE model does not support RoPE type: {self.rope_type}")
238
+ self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
239
+
240
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
241
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
242
+ self.original_inv_freq = self.inv_freq
243
+
244
+ def _update_freq(self, position_ids, device):
245
+ """
246
+ dynamic RoPE layers should recompute `inv_freq` in the following situations:
247
+ 1 - growing beyond the cached sequence length (allow scaling)
248
+ 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
249
+ """
250
+ seq_len = torch.max(position_ids) + 1
251
+ if seq_len > self.max_seq_len: # expand to seq_len
252
+ inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
253
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
254
+ self.max_seq_len = seq_len
255
+
256
+ if seq_len < self.original_max_seq_len and self.max_seq_len > self.original_max_seq_len: # reset to original
257
+ self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
258
+ self.max_seq_len = self.original_max_seq_len
259
+
260
+ @torch.no_grad()
261
+ def forward(self, x, position_ids):
262
+ if "dynamic" in self.rope_type:
263
+ self._update_freq(position_ids, device=x.device)
264
+
265
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
266
+ position_ids_expanded = position_ids[:, None, :].float()
267
+
268
+ device_type = x.device.type
269
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
270
+ with torch.autocast(device_type=device_type, enabled=False):
271
+ freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
272
+ emb = torch.cat((freqs, freqs), dim=-1)
273
+ cos, sin = emb.cos(), emb.sin()
274
+
275
+ cos, sin = cos * self.attention_scaling, sin * self.attention_scaling
276
+ return cos.to(x.dtype), sin.to(x.dtype)
277
+
278
+
279
+ class ExaoneSelfAttention(nn.Module):
280
+ def __init__(self, config: ExaoneConfig, layer_idx: Optional[int] = None):
281
+ super().__init__()
282
+ self.config = config
283
+ self.layer_idx = layer_idx
284
+ self.embed_dim = config.hidden_size
285
+ self.num_heads = config.num_attention_heads
286
+ self.head_dim = self.embed_dim // self.num_heads
287
+ self.num_key_value_heads = config.num_key_value_heads
288
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
289
+ self.attention_dropout_rate = config.attention_dropout
290
+
291
+ if self.head_dim * self.num_heads != self.embed_dim:
292
+ raise ValueError(
293
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
294
+ )
295
+
296
+ self.rotary = ExaoneRotaryEmbedding(config)
297
+
298
+ self.k_proj = nn.Linear(self.embed_dim, self.num_key_value_heads * self.head_dim, bias=False)
299
+ self.v_proj = nn.Linear(self.embed_dim, self.num_key_value_heads * self.head_dim, bias=False)
300
+ self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
301
+ self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
302
+
303
+ def forward(
304
+ self,
305
+ hidden_states: torch.Tensor,
306
+ attention_mask: Optional[torch.Tensor] = None,
307
+ position_ids: Optional[torch.LongTensor] = None,
308
+ past_key_value: Optional[Cache] = None,
309
+ output_attentions: Optional[bool] = False,
310
+ use_cache: Optional[bool] = False,
311
+ cache_position: Optional[torch.LongTensor] = None,
312
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
313
+ **kwargs,
314
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
315
+ bsz, q_len, _ = hidden_states.size()
316
+ query_states = self.q_proj(hidden_states)
317
+ key_states = self.k_proj(hidden_states)
318
+ value_states = self.v_proj(hidden_states)
319
+
320
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
321
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
322
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
323
+
324
+ if position_embeddings is None:
325
+ cos, sin = self.rotary(value_states, position_ids=position_ids)
326
+ else:
327
+ cos, sin = position_embeddings
328
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
329
+
330
+ if past_key_value is not None:
331
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
332
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
333
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
334
+
335
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
336
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
337
+
338
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
339
+
340
+ if attention_mask is not None:
341
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
342
+ attn_weights = attn_weights + causal_mask
343
+
344
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
345
+ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout_rate, training=self.training)
346
+ attn_output = torch.matmul(attn_weights, value_states)
347
+
348
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
349
+ raise ValueError(
350
+ f"Attention outputs should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
351
+ f" {attn_output.size()}"
352
+ )
353
+
354
+ attn_output = attn_output.transpose(1, 2).contiguous()
355
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
356
+
357
+ attn_output = self.out_proj(attn_output)
358
+
359
+ if not output_attentions:
360
+ attn_weights = None
361
+
362
+ return attn_output, attn_weights, past_key_value
363
+
364
+
365
+ class ExaoneFlashAttention(ExaoneSelfAttention):
366
+ def __init__(self, *args, **kwargs):
367
+ super().__init__(*args, **kwargs)
368
+
369
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
370
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
371
+
372
+ def forward(
373
+ self,
374
+ hidden_states: torch.Tensor,
375
+ attention_mask: Optional[torch.Tensor] = None,
376
+ position_ids: Optional[torch.LongTensor] = None,
377
+ past_key_value: Optional[Cache] = None,
378
+ output_attentions: Optional[bool] = False,
379
+ use_cache: Optional[bool] = False,
380
+ cache_position: Optional[torch.LongTensor] = None,
381
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
382
+ **kwargs,
383
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
384
+ if isinstance(past_key_value, StaticCache):
385
+ raise ValueError(
386
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
387
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
388
+ )
389
+
390
+ output_attentions = False
391
+
392
+ bsz, q_len, h_size = hidden_states.size()
393
+
394
+ query_states = self.q_proj(hidden_states)
395
+ key_states = self.k_proj(hidden_states)
396
+ value_states = self.v_proj(hidden_states)
397
+
398
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
399
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
400
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
401
+
402
+ if position_embeddings is None:
403
+ cos, sin = self.rotary(value_states, position_ids=position_ids)
404
+ else:
405
+ cos, sin = position_embeddings
406
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
407
+
408
+ if past_key_value is not None:
409
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
410
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
411
+ # Only update cache as shape of [bsz, n_head, q_len, head_dim]
412
+ # TODO: need to be fixed when transformers' KV cache layout is changed
413
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
414
+
415
+ query_states = query_states.transpose(1, 2)
416
+ key_states = key_states.transpose(1, 2)
417
+ value_states = value_states.transpose(1, 2)
418
+
419
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
420
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
421
+ # cast them back in the correct dtype just to be sure everything works as expected.
422
+ input_dtype = query_states.dtype
423
+ if input_dtype == torch.float32:
424
+ if torch.is_autocast_enabled():
425
+ target_dtype = torch.get_autocast_gpu_dtype()
426
+ # Handle the case where the model is quantized
427
+ elif hasattr(self.config, "_pre_quantization_dtype"):
428
+ target_dtype = self.config._pre_quantization_dtype
429
+ else:
430
+ target_dtype = self.q_proj.weight.dtype
431
+
432
+ logger.warning_once(
433
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
434
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
435
+ f" {target_dtype}."
436
+ )
437
+
438
+ query_states = query_states.to(target_dtype)
439
+ key_states = key_states.to(target_dtype)
440
+ value_states = value_states.to(target_dtype)
441
+
442
+ dropout_rate = self.attention_dropout_rate if self.training else 0.0
443
+
444
+ attn_output = _flash_attention_forward(
445
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate, is_causal=True
446
+ )
447
+
448
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
449
+ attn_output = self.out_proj(attn_output)
450
+
451
+ if not output_attentions:
452
+ attn_weights = None
453
+
454
+ return attn_output, attn_weights, past_key_value
455
+
456
+
457
+ class ExaoneSdpaAttention(ExaoneSelfAttention):
458
+ def __init__(self, *args, **kwargs):
459
+ super().__init__(*args, **kwargs)
460
+
461
+ def forward(
462
+ self,
463
+ hidden_states: torch.Tensor,
464
+ attention_mask: Optional[torch.Tensor] = None,
465
+ position_ids: Optional[torch.LongTensor] = None,
466
+ past_key_value: Optional[Cache] = None,
467
+ output_attentions: Optional[bool] = False,
468
+ use_cache: Optional[bool] = False,
469
+ cache_position: Optional[torch.LongTensor] = None,
470
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
471
+ **kwargs,
472
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
473
+ if output_attentions:
474
+ logger.warning_once(
475
+ "ExaoneModel is using ExaoneSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
476
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
477
+ )
478
+ return super().forward(
479
+ hidden_states=hidden_states,
480
+ attention_mask=attention_mask,
481
+ position_ids=position_ids,
482
+ past_key_value=past_key_value,
483
+ output_attentions=output_attentions,
484
+ use_cache=use_cache,
485
+ cache_position=cache_position,
486
+ position_embeddings=position_embeddings,
487
+ **kwargs,
488
+ )
489
+
490
+ bsz, q_len, _ = hidden_states.size()
491
+
492
+ query_states = self.q_proj(hidden_states)
493
+ key_states = self.k_proj(hidden_states)
494
+ value_states = self.v_proj(hidden_states)
495
+
496
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
497
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
498
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
499
+
500
+ if position_embeddings is None:
501
+ cos, sin = self.rotary(value_states, position_ids=position_ids)
502
+ else:
503
+ cos, sin = position_embeddings
504
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
505
+
506
+ if past_key_value is not None:
507
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
508
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
509
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
510
+
511
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
512
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
513
+
514
+ causal_mask = attention_mask
515
+ if attention_mask is not None:
516
+ causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
517
+
518
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
519
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
520
+ if query_states.device.type == "cuda" and causal_mask is not None:
521
+ query_states = query_states.contiguous()
522
+ key_states = key_states.contiguous()
523
+ value_states = value_states.contiguous()
524
+
525
+ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
526
+ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
527
+ is_causal = True if causal_mask is None and q_len > 1 else False
528
+
529
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
530
+ query_states,
531
+ key_states,
532
+ value_states,
533
+ attn_mask=causal_mask,
534
+ dropout_p=self.attention_dropout_rate if self.training else 0.0,
535
+ is_causal=is_causal,
536
+ )
537
+
538
+ attn_output = attn_output.transpose(1, 2).contiguous()
539
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
540
+
541
+ attn_output = self.out_proj(attn_output)
542
+
543
+ return attn_output, None, past_key_value
544
+
545
+
546
+ class ExaoneAttention(nn.Module):
547
+ def __init__(self, config, layer_id=0):
548
+ super().__init__()
549
+ self.layer_id = layer_id
550
+ if "flash" in config._attn_implementation:
551
+ self.attention = ExaoneFlashAttention(config, self.layer_id)
552
+ elif "sdpa" in config._attn_implementation:
553
+ self.attention = ExaoneSdpaAttention(config, self.layer_id)
554
+ else:
555
+ self.attention = ExaoneSelfAttention(config, self.layer_id)
556
+
557
+ def forward(
558
+ self,
559
+ hidden_states: torch.Tensor,
560
+ attention_mask: Optional[torch.Tensor] = None,
561
+ position_ids: Optional[torch.LongTensor] = None,
562
+ past_key_value: Optional[Cache] = None,
563
+ output_attentions: Optional[bool] = False,
564
+ use_cache: Optional[bool] = False,
565
+ cache_position: Optional[torch.LongTensor] = None,
566
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
567
+ **kwargs,
568
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
569
+ return self.attention(
570
+ hidden_states=hidden_states,
571
+ attention_mask=attention_mask,
572
+ position_ids=position_ids,
573
+ past_key_value=past_key_value,
574
+ output_attentions=output_attentions,
575
+ use_cache=use_cache,
576
+ cache_position=cache_position,
577
+ position_embeddings=position_embeddings,
578
+ **kwargs,
579
+ )
580
+
581
+
582
+ class ExaoneGatedMLP(nn.Module):
583
+ def __init__(self, intermediate_size, config):
584
+ super().__init__()
585
+ self.config = config
586
+ embed_dim = config.hidden_size
587
+ self.c_fc_0 = nn.Linear(embed_dim, intermediate_size, bias=False)
588
+ self.c_fc_1 = nn.Linear(embed_dim, intermediate_size, bias=False)
589
+ self.c_proj = nn.Linear(intermediate_size, embed_dim, bias=False)
590
+ self.act = ACT2FN[config.activation_function]
591
+
592
+ def forward(self, hidden_states):
593
+ output_proj = self.c_proj(self.act(self.c_fc_0(hidden_states)) * self.c_fc_1(hidden_states))
594
+ return output_proj
595
+
596
+
597
+ class ExaoneBlock(nn.Module):
598
+ def __init__(self, config, layer_id):
599
+ super().__init__()
600
+ self.config = config
601
+ hidden_size = config.hidden_size
602
+ inner_dim = config.intermediate_size if config.intermediate_size is not None else 4 * hidden_size
603
+ self.ln_1 = ExaoneRMSNorm(hidden_size=hidden_size, eps=config.layer_norm_epsilon)
604
+ self.attn = ExaoneAttention(config, layer_id)
605
+ self.ln_2 = ExaoneRMSNorm(hidden_size=hidden_size, eps=config.layer_norm_epsilon)
606
+ self.mlp = ExaoneGatedMLP(inner_dim, config)
607
+
608
+ def forward(
609
+ self,
610
+ hidden_states: torch.Tensor,
611
+ attention_mask: Optional[torch.Tensor] = None,
612
+ position_ids: Optional[torch.LongTensor] = None,
613
+ past_key_value: Optional[Cache] = None,
614
+ output_attentions: Optional[bool] = False,
615
+ use_cache: Optional[bool] = False,
616
+ cache_position: Optional[torch.LongTensor] = None,
617
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
618
+ **kwargs,
619
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
620
+ residual = hidden_states
621
+ hidden_states = self.ln_1(hidden_states)
622
+
623
+ hidden_states, self_attn_weights, present_key_value = self.attn(
624
+ hidden_states=hidden_states,
625
+ attention_mask=attention_mask,
626
+ position_ids=position_ids,
627
+ past_key_value=past_key_value,
628
+ output_attentions=output_attentions,
629
+ use_cache=use_cache,
630
+ cache_position=cache_position,
631
+ position_embeddings=position_embeddings,
632
+ **kwargs,
633
+ )
634
+ # residual connection
635
+ hidden_states = residual + hidden_states
636
+
637
+ residual = hidden_states
638
+ hidden_states = self.ln_2(hidden_states)
639
+ hidden_states = self.mlp(hidden_states)
640
+
641
+ hidden_states = residual + hidden_states
642
+
643
+ outputs = (hidden_states,)
644
+
645
+ if output_attentions:
646
+ outputs += (self_attn_weights,)
647
+
648
+ if use_cache:
649
+ outputs += (present_key_value,)
650
+
651
+ return outputs
652
+
653
+
654
+ class ExaonePreTrainedModel(PreTrainedModel):
655
+ """
656
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
657
+ models.
658
+ """
659
+
660
+ config_class = ExaoneConfig
661
+ base_model_prefix = "transformer"
662
+ supports_gradient_checkpointing = True
663
+ _no_split_modules = ["ExaoneBlock"]
664
+ _skip_keys_device_placement = "past_key_values"
665
+ _supports_flash_attn_2 = True
666
+ _supports_sdpa = True
667
+ _supports_cache_class = True
668
+
669
+ def __init__(self, *inputs, **kwargs):
670
+ super().__init__(*inputs, **kwargs)
671
+
672
+ def _init_weights(self, module):
673
+ """Initialize the weights."""
674
+ if isinstance(module, (nn.Linear,)):
675
+ # Slightly different from the TF version which uses truncated_normal for initialization
676
+ # cf https://github.com/pytorch/pytorch/pull/5617
677
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
678
+ if module.bias is not None:
679
+ module.bias.data.zero_()
680
+ elif isinstance(module, nn.Embedding):
681
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
682
+ if module.padding_idx is not None:
683
+ module.weight.data[module.padding_idx].zero_()
684
+ elif isinstance(module, ExaoneRMSNorm):
685
+ module.weight.data.fill_(1.0)
686
+
687
+
688
+ EXAONE_START_DOCSTRING = r"""
689
+
690
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
691
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
692
+ etc.)
693
+
694
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
695
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
696
+ and behavior.
697
+
698
+ Parameters:
699
+ config ([`ExaoneConfig`]): Model configuration class with all the parameters of the model.
700
+ Initializing with a config file does not load the weights associated with the model, only the
701
+ configuration. Check out the `PreTrainedModel.from_pretrained` method to load the model weights.
702
+ """
703
+
704
+ EXAONE_INPUTS_DOCSTRING = r"""
705
+ Args:
706
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
707
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else
708
+ `past_key_values.get_seq_length()` (`sequence_length` of input past key value states). Indices of input
709
+ sequence tokens in the vocabulary.
710
+
711
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be
712
+ passed as `input_ids`.
713
+
714
+ `What are input IDs? <../glossary.html#input-ids>`__
715
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
716
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
717
+
718
+ - 1 for tokens that are **not masked**,
719
+ - 0 for tokens that are **masked**.
720
+
721
+ `What are attention masks? <../glossary.html#attention-mask>`__
722
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
723
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
724
+ config.max_position_embeddings - 1]`.
725
+
726
+ `What are position IDs? <../glossary.html#position-ids>`_
727
+ past_key_values (`Cache`, *optional*):
728
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
729
+ `past_key_values` output below). Can be used to speed up sequential decoding. This typically consists
730
+ in the `past_key_values` returned by the model at a previous stage of decoding, when `use_cache=True` or
731
+ `config.use_cache=True`.
732
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
733
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
734
+ This is useful if you want more control over how to convert `input_ids` indices into associated
735
+ vectors than the model's internal embedding lookup matrix.
736
+
737
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
738
+ `past_key_values`).
739
+ use_cache (`bool`, *optional*):
740
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up
741
+ decoding (see `past_key_values`).
742
+ output_attentions (`bool`, *optional*):
743
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
744
+ tensors for more detail.
745
+ output_hidden_states (`bool`, *optional*):
746
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
747
+ more detail.
748
+ return_dict (`bool`, *optional*):
749
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
750
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
751
+ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
752
+ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
753
+ the complete sequence length.
754
+ """
755
+
756
+
757
+ @add_start_docstrings(
758
+ "The bare EXAONE Model transformer outputting raw hidden-states without any specific head on top.",
759
+ EXAONE_START_DOCSTRING,
760
+ )
761
+ class ExaoneModel(ExaonePreTrainedModel):
762
+ def __init__(self, config):
763
+ super().__init__(config)
764
+ self.config = config
765
+ self.embed_dim = config.hidden_size
766
+ self.wte = nn.Embedding(config.vocab_size, self.embed_dim, self.config.pad_token_id)
767
+ self.drop = nn.Dropout(float(config.embed_dropout))
768
+ self.h = nn.ModuleList([ExaoneBlock(config, layer_id=i) for i in range(config.num_layers)])
769
+ self.ln_f = ExaoneRMSNorm(hidden_size=self.embed_dim, eps=config.layer_norm_epsilon)
770
+ self.rotary = ExaoneRotaryEmbedding(config)
771
+ self.gradient_checkpointing = False
772
+ # Initialize weights and apply final processing
773
+ self.post_init()
774
+
775
+ def get_input_embeddings(self):
776
+ return self.wte
777
+
778
+ def set_input_embeddings(self, new_embeddings):
779
+ self.wte = new_embeddings
780
+
781
+ @add_start_docstrings_to_model_forward(EXAONE_INPUTS_DOCSTRING)
782
+ @add_code_sample_docstrings(
783
+ checkpoint=_CHECKPOINT_FOR_DOC,
784
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
785
+ config_class=_CONFIG_FOR_DOC,
786
+ )
787
+ def forward(
788
+ self,
789
+ input_ids: Optional[torch.Tensor] = None,
790
+ attention_mask: Optional[torch.Tensor] = None,
791
+ position_ids: Optional[torch.Tensor] = None,
792
+ past_key_values: Optional[Cache] = None,
793
+ inputs_embeds: Optional[torch.Tensor] = None,
794
+ use_cache: Optional[bool] = None,
795
+ output_attentions: Optional[bool] = None,
796
+ output_hidden_states: Optional[bool] = None,
797
+ return_dict: Optional[bool] = None,
798
+ cache_position: Optional[torch.LongTensor] = None,
799
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
800
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
801
+ output_hidden_states = (
802
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
803
+ )
804
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
805
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
806
+
807
+ if self.gradient_checkpointing and self.training:
808
+ if use_cache:
809
+ logger.warning_once(
810
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
811
+ )
812
+ use_cache = False
813
+
814
+ if input_ids is not None and inputs_embeds is not None:
815
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
816
+ elif input_ids is not None:
817
+ batch_size, seq_length = input_ids.shape[:2]
818
+ elif inputs_embeds is not None:
819
+ batch_size, seq_length = inputs_embeds.shape[:2]
820
+ else:
821
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
822
+
823
+ return_legacy_cache = False
824
+ if (
825
+ use_cache and not isinstance(past_key_values, Cache) and not self.training
826
+ ): # kept for BC (non `Cache` `past_key_values` inputs)
827
+ return_legacy_cache = True
828
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
829
+ logger.warning_once(
830
+ "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
831
+ "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
832
+ )
833
+
834
+ if inputs_embeds is None:
835
+ inputs_embeds = self.wte(input_ids)
836
+
837
+ if cache_position is None:
838
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
839
+ cache_position = torch.arange(
840
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
841
+ )
842
+ if position_ids is None:
843
+ position_ids = cache_position.unsqueeze(0)
844
+
845
+ causal_mask = self._update_causal_mask(
846
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
847
+ )
848
+
849
+ hidden_states = inputs_embeds
850
+ hidden_states = self.drop(hidden_states)
851
+
852
+ position_embeddings = self.rotary(hidden_states, position_ids)
853
+
854
+ all_hidden_states = () if output_hidden_states else None
855
+ all_self_attns = () if output_attentions else None
856
+ next_decoder_cache = None
857
+
858
+ for block in self.h:
859
+ if output_hidden_states:
860
+ all_hidden_states = all_hidden_states + (hidden_states,)
861
+
862
+ if self.gradient_checkpointing and self.training:
863
+ outputs = self._gradient_checkpointing_func(
864
+ block.__call__,
865
+ hidden_states,
866
+ causal_mask,
867
+ position_ids,
868
+ past_key_values,
869
+ output_attentions,
870
+ use_cache,
871
+ cache_position,
872
+ position_embeddings,
873
+ )
874
+ else:
875
+ outputs = block(
876
+ hidden_states,
877
+ attention_mask=causal_mask,
878
+ position_ids=position_ids,
879
+ past_key_value=past_key_values,
880
+ output_attentions=output_attentions,
881
+ use_cache=use_cache,
882
+ cache_position=cache_position,
883
+ position_embeddings=position_embeddings,
884
+ )
885
+
886
+ hidden_states = outputs[0]
887
+ if use_cache:
888
+ next_decoder_cache = outputs[2 if output_attentions else 1]
889
+
890
+ if output_attentions:
891
+ all_self_attns += (outputs[1],)
892
+
893
+ hidden_states = self.ln_f(hidden_states)
894
+ # Add last hidden state
895
+ if output_hidden_states:
896
+ all_hidden_states += (hidden_states,)
897
+
898
+ next_cache = None
899
+ if use_cache:
900
+ next_cache = next_decoder_cache.to_legacy_cache() if return_legacy_cache else next_decoder_cache
901
+ if not return_dict:
902
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
903
+
904
+ return BaseModelOutputWithPast(
905
+ last_hidden_state=hidden_states,
906
+ past_key_values=next_cache,
907
+ hidden_states=all_hidden_states,
908
+ attentions=all_self_attns,
909
+ )
910
+
911
+ def _update_causal_mask(
912
+ self,
913
+ attention_mask: torch.Tensor,
914
+ input_tensor: torch.Tensor,
915
+ cache_position: torch.Tensor,
916
+ past_key_values: Cache,
917
+ output_attentions: bool,
918
+ ):
919
+ # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
920
+ # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
921
+ # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
922
+ # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
923
+
924
+ if self.config._attn_implementation == "flash_attention_2":
925
+ if attention_mask is not None and 0.0 in attention_mask:
926
+ return attention_mask
927
+ return None
928
+
929
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
930
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
931
+ # to infer the attention mask.
932
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
933
+ using_static_cache = isinstance(past_key_values, StaticCache)
934
+
935
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
936
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
937
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
938
+ attention_mask,
939
+ inputs_embeds=input_tensor,
940
+ past_key_values_length=past_seen_tokens,
941
+ is_training=self.training,
942
+ ):
943
+ return None
944
+
945
+ dtype, device = input_tensor.dtype, input_tensor.device
946
+ min_dtype = torch.finfo(dtype).min
947
+ sequence_length = input_tensor.shape[1]
948
+ if using_static_cache:
949
+ target_length = past_key_values.get_max_length()
950
+ else:
951
+ target_length = (
952
+ attention_mask.shape[-1]
953
+ if isinstance(attention_mask, torch.Tensor)
954
+ else past_seen_tokens + sequence_length + 1
955
+ )
956
+
957
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
958
+ causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
959
+ attention_mask,
960
+ sequence_length=sequence_length,
961
+ target_length=target_length,
962
+ dtype=dtype,
963
+ device=device,
964
+ min_dtype=min_dtype,
965
+ cache_position=cache_position,
966
+ batch_size=input_tensor.shape[0],
967
+ )
968
+
969
+ if (
970
+ self.config._attn_implementation == "sdpa"
971
+ and attention_mask is not None
972
+ and attention_mask.device.type == "cuda"
973
+ and not output_attentions
974
+ ):
975
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
976
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
977
+ # Details: https://github.com/pytorch/pytorch/issues/110213
978
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
979
+
980
+ return causal_mask
981
+
982
+
983
+ @add_start_docstrings(
984
+ """
985
+ The EXAONE Model transformer with a language modeling head on top (linear layer with weights tied to the input
986
+ embeddings).
987
+ """,
988
+ EXAONE_START_DOCSTRING,
989
+ )
990
+ class ExaoneForCausalLM(ExaonePreTrainedModel, GenerationMixin):
991
+ _tied_weights_keys = ["lm_head.weight"]
992
+
993
+ def __init__(self, config):
994
+ super().__init__(config)
995
+ self.transformer = ExaoneModel(config)
996
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
997
+ self.config = config
998
+ # Initialize weights and apply final processing
999
+ self.post_init()
1000
+
1001
+ def get_output_embeddings(self):
1002
+ return self.lm_head
1003
+
1004
+ def set_output_embeddings(self, new_embeddings):
1005
+ self.lm_head = new_embeddings
1006
+
1007
+ @add_start_docstrings_to_model_forward(EXAONE_INPUTS_DOCSTRING)
1008
+ @add_code_sample_docstrings(
1009
+ checkpoint=_CHECKPOINT_FOR_DOC,
1010
+ output_type=BaseModelOutputWithPast,
1011
+ config_class=_CONFIG_FOR_DOC,
1012
+ )
1013
+ def forward(
1014
+ self,
1015
+ input_ids: Optional[torch.Tensor] = None,
1016
+ attention_mask: Optional[torch.Tensor] = None,
1017
+ position_ids: Optional[torch.Tensor] = None,
1018
+ past_key_values: Optional[Cache] = None,
1019
+ inputs_embeds: Optional[torch.Tensor] = None,
1020
+ labels: Optional[torch.Tensor] = None,
1021
+ use_cache: Optional[bool] = None,
1022
+ output_attentions: Optional[bool] = None,
1023
+ output_hidden_states: Optional[bool] = None,
1024
+ return_dict: Optional[bool] = None,
1025
+ cache_position: Optional[torch.LongTensor] = None,
1026
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
1027
+ r"""
1028
+ Args:
1029
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1030
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
1031
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
1032
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
1033
+
1034
+ Example:
1035
+
1036
+ ```python
1037
+ >>> from transformers import AutoModelForCausalLM, AutoTokenizer
1038
+
1039
+ >>> model = AutoModelForCausalLM.from_pretrained("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
1040
+ trust_remote_code=True)
1041
+ >>> tokenizer = AutoTokenizer.from_pretrained("LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct")
1042
+
1043
+ >>> prompt = "Explain how wonderful you are"
1044
+ >>> messages = [
1045
+ {"role": "system", "content": "You are a helpful assistant."},
1046
+ {"role": "user", "content": prompt}
1047
+ ]
1048
+ >>> input_ids = tokenizer.apply_chat_template(
1049
+ messages,
1050
+ tokenize=True,
1051
+ add_generation_prompt=True,
1052
+ return_tensors="pt"
1053
+ )
1054
+
1055
+ >>> output = model.generate(input_ids, max_new_tokens=128)
1056
+ >>> tokenizer.decode(output[0], skip_special_tokens=True)
1057
+ "[|system|]You are a helpful assistant.\n[|user|]Explain how wonderful you are\n[|assistant|]Thank you for your kind words! I'm here to assist you with information, answer questions, and help you in any way I can. My goal is to provide accurate, helpful, and timely responses. Whether you need help with a specific task, want to learn something new, or just need someone to talk to, I'm here for you. How can I assist you today?"
1058
+ ```
1059
+ """
1060
+
1061
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1062
+ output_hidden_states = (
1063
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1064
+ )
1065
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1066
+ transformer_outputs = self.transformer(
1067
+ input_ids,
1068
+ attention_mask=attention_mask,
1069
+ past_key_values=past_key_values,
1070
+ position_ids=position_ids,
1071
+ inputs_embeds=inputs_embeds,
1072
+ use_cache=use_cache,
1073
+ output_attentions=output_attentions,
1074
+ output_hidden_states=output_hidden_states,
1075
+ return_dict=return_dict,
1076
+ cache_position=cache_position,
1077
+ )
1078
+ hidden_states = transformer_outputs[0]
1079
+ lm_logits = self.lm_head(hidden_states)
1080
+ lm_logits = lm_logits.float()
1081
+ loss = None
1082
+ if labels is not None:
1083
+ lm_logits = lm_logits.to(torch.float32)
1084
+
1085
+ # Shift so that tokens < n predict n
1086
+ shift_logits = lm_logits[..., :-1, :].contiguous()
1087
+ shift_labels = labels[..., 1:].contiguous()
1088
+ # Flatten the tokens
1089
+ loss_fct = CrossEntropyLoss()
1090
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
1091
+
1092
+ lm_logits = lm_logits.to(hidden_states.dtype)
1093
+ loss = loss.to(hidden_states.dtype)
1094
+
1095
+ if not return_dict:
1096
+ output = (lm_logits,) + transformer_outputs[1:]
1097
+ return ((loss,) + output) if loss is not None else output
1098
+
1099
+ return CausalLMOutputWithPast(
1100
+ loss=loss,
1101
+ logits=lm_logits,
1102
+ past_key_values=transformer_outputs.past_key_values,
1103
+ hidden_states=transformer_outputs.hidden_states,
1104
+ attentions=transformer_outputs.attentions,
1105
+ )
1106
+
1107
+ def prepare_inputs_for_generation(
1108
+ self,
1109
+ input_ids,
1110
+ past_key_values=None,
1111
+ attention_mask=None,
1112
+ inputs_embeds=None,
1113
+ cache_position=None,
1114
+ position_ids=None,
1115
+ use_cache=True,
1116
+ **kwargs,
1117
+ ):
1118
+ # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
1119
+ # Exception 1: when passing input_embeds, input_ids may be missing entries
1120
+ # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
1121
+ if past_key_values is not None:
1122
+ if inputs_embeds is not None: # Exception 1
1123
+ input_ids = input_ids[:, -cache_position.shape[0] :]
1124
+ elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2)
1125
+ input_ids = input_ids[:, cache_position]
1126
+
1127
+ if attention_mask is not None and position_ids is None:
1128
+ # create position_ids on the fly for batch generation
1129
+ position_ids = attention_mask.long().cumsum(-1) - 1
1130
+ position_ids.masked_fill_(attention_mask == 0, 1)
1131
+ if past_key_values:
1132
+ position_ids = position_ids[:, -input_ids.shape[1] :]
1133
+
1134
+ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
1135
+ position_ids = position_ids.clone(memory_format=torch.contiguous_format)
1136
+
1137
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1138
+ if inputs_embeds is not None and cache_position[0] == 0:
1139
+ model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
1140
+ else:
1141
+ model_inputs = {"input_ids": input_ids, "inputs_embeds": None}
1142
+
1143
+ if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2:
1144
+ if inputs_embeds is not None:
1145
+ batch_size, sequence_length, _ = inputs_embeds.shape
1146
+ device = inputs_embeds.device
1147
+ else:
1148
+ batch_size, sequence_length = input_ids.shape
1149
+ device = input_ids.device
1150
+
1151
+ dtype = self.lm_head.weight.dtype
1152
+ min_dtype = torch.finfo(dtype).min
1153
+
1154
+ attention_mask = _prepare_4d_causal_attention_mask_with_cache_position(
1155
+ attention_mask,
1156
+ sequence_length=sequence_length,
1157
+ target_length=past_key_values.get_max_length(),
1158
+ dtype=dtype,
1159
+ device=device,
1160
+ min_dtype=min_dtype,
1161
+ cache_position=cache_position,
1162
+ batch_size=batch_size,
1163
+ )
1164
+
1165
+ model_inputs.update(
1166
+ {
1167
+ "position_ids": position_ids,
1168
+ "cache_position": cache_position,
1169
+ "past_key_values": past_key_values,
1170
+ "use_cache": use_cache,
1171
+ "attention_mask": attention_mask,
1172
+ }
1173
+ )
1174
+ return model_inputs
1175
+
1176
+
1177
+ @add_start_docstrings(
1178
+ """
1179
+ The EXAONE Model transformer with a sequence classification head on top (linear layer).
1180
+
1181
+ [`ExaoneForSequenceClassification`] uses the last token in order to do the classification, as
1182
+ other causal models (e.g. GPT-1) do.
1183
+
1184
+ Since it does classification on the last token, it requires to know the position of the last token. If a
1185
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each
1186
+ row. If no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot
1187
+ guess the padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take
1188
+ the last value in each row of the batch).
1189
+ """,
1190
+ EXAONE_START_DOCSTRING,
1191
+ )
1192
+ class ExaoneForSequenceClassification(ExaonePreTrainedModel):
1193
+ def __init__(self, config):
1194
+ super().__init__(config)
1195
+ self.num_labels = config.num_labels
1196
+ self.transformer = ExaoneModel(config)
1197
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
1198
+
1199
+ # Initialize weights and apply final processing
1200
+ self.post_init()
1201
+
1202
+ @add_start_docstrings_to_model_forward(EXAONE_INPUTS_DOCSTRING)
1203
+ @add_code_sample_docstrings(
1204
+ checkpoint=_CHECKPOINT_FOR_DOC,
1205
+ output_type=SequenceClassifierOutputWithPast,
1206
+ config_class=_CONFIG_FOR_DOC,
1207
+ )
1208
+ def forward(
1209
+ self,
1210
+ input_ids: Optional[torch.Tensor] = None,
1211
+ attention_mask: Optional[torch.Tensor] = None,
1212
+ position_ids: Optional[torch.Tensor] = None,
1213
+ past_key_values: Optional[Cache] = None,
1214
+ inputs_embeds: Optional[torch.Tensor] = None,
1215
+ labels: Optional[torch.Tensor] = None,
1216
+ use_cache: Optional[bool] = None,
1217
+ output_attentions: Optional[bool] = None,
1218
+ output_hidden_states: Optional[bool] = None,
1219
+ return_dict: Optional[bool] = None,
1220
+ ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1221
+ r"""
1222
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1223
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1224
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1225
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1226
+ """
1227
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1228
+
1229
+ transformer_outputs = self.transformer(
1230
+ input_ids,
1231
+ attention_mask=attention_mask,
1232
+ position_ids=position_ids,
1233
+ past_key_values=past_key_values,
1234
+ inputs_embeds=inputs_embeds,
1235
+ use_cache=use_cache,
1236
+ output_attentions=output_attentions,
1237
+ output_hidden_states=output_hidden_states,
1238
+ return_dict=return_dict,
1239
+ )
1240
+ hidden_states = transformer_outputs[0]
1241
+ logits = self.score(hidden_states)
1242
+
1243
+ if input_ids is not None:
1244
+ batch_size, sequence_length = input_ids.shape[:2]
1245
+ else:
1246
+ batch_size, sequence_length = inputs_embeds.shape[:2]
1247
+
1248
+ if self.config.pad_token_id is None and batch_size != 1:
1249
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
1250
+ if self.config.pad_token_id is None:
1251
+ sequence_lengths = -1
1252
+ else:
1253
+ if input_ids is not None:
1254
+ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
1255
+ sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
1256
+ sequence_lengths = sequence_lengths % input_ids.shape[-1]
1257
+ sequence_lengths = sequence_lengths.to(logits.device)
1258
+ else:
1259
+ sequence_lengths = -1
1260
+ logger.warning(
1261
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
1262
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
1263
+ )
1264
+
1265
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
1266
+
1267
+ loss = None
1268
+ if labels is not None:
1269
+ labels = labels.to(logits.device)
1270
+ if self.config.problem_type is None:
1271
+ if self.num_labels == 1:
1272
+ self.config.problem_type = "regression"
1273
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1274
+ self.config.problem_type = "single_label_classification"
1275
+ else:
1276
+ self.config.problem_type = "multi_label_classification"
1277
+
1278
+ if self.config.problem_type == "regression":
1279
+ loss_fct = MSELoss()
1280
+ if self.num_labels == 1:
1281
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
1282
+ else:
1283
+ loss = loss_fct(pooled_logits, labels)
1284
+ elif self.config.problem_type == "single_label_classification":
1285
+ loss_fct = CrossEntropyLoss()
1286
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
1287
+ elif self.config.problem_type == "multi_label_classification":
1288
+ loss_fct = BCEWithLogitsLoss()
1289
+ loss = loss_fct(pooled_logits, labels)
1290
+ if not return_dict:
1291
+ output = (pooled_logits,) + transformer_outputs[1:]
1292
+ return ((loss,) + output) if loss is not None else output
1293
+
1294
+ return SequenceClassifierOutputWithPast(
1295
+ loss=loss,
1296
+ logits=pooled_logits,
1297
+ past_key_values=transformer_outputs.past_key_values,
1298
+ hidden_states=transformer_outputs.hidden_states,
1299
+ attentions=transformer_outputs.attentions,
1300
+ )
1301
+
1302
+
1303
+ @add_start_docstrings(
1304
+ """
1305
+ The EXAONE Model transformer with a span classification head on top for extractive question-answering tasks like
1306
+ SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1307
+ """,
1308
+ EXAONE_START_DOCSTRING,
1309
+ )
1310
+ class ExaoneForQuestionAnswering(ExaonePreTrainedModel):
1311
+ def __init__(self, config):
1312
+ super().__init__(config)
1313
+ self.num_labels = config.num_labels
1314
+ self.transformer = ExaoneModel(config)
1315
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1316
+
1317
+ # Model parallel
1318
+ self.model_parallel = False
1319
+ self.device_map = None
1320
+
1321
+ # Initialize weights and apply final processing
1322
+ self.post_init()
1323
+
1324
+ def forward(
1325
+ self,
1326
+ input_ids: Optional[torch.LongTensor] = None,
1327
+ attention_mask: Optional[torch.FloatTensor] = None,
1328
+ position_ids: Optional[torch.LongTensor] = None,
1329
+ past_key_values: Optional[Cache] = None,
1330
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1331
+ start_positions: Optional[torch.LongTensor] = None,
1332
+ end_positions: Optional[torch.LongTensor] = None,
1333
+ output_attentions: Optional[bool] = None,
1334
+ output_hidden_states: Optional[bool] = None,
1335
+ return_dict: Optional[bool] = None,
1336
+ ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
1337
+ r"""
1338
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1339
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1340
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the
1341
+ sequence are not taken into account for computing the loss.
1342
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1343
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1344
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the
1345
+ sequence are not taken into account for computing the loss.
1346
+ """
1347
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1348
+
1349
+ outputs = self.transformer(
1350
+ input_ids,
1351
+ attention_mask=attention_mask,
1352
+ position_ids=position_ids,
1353
+ past_key_values=past_key_values,
1354
+ inputs_embeds=inputs_embeds,
1355
+ output_attentions=output_attentions,
1356
+ output_hidden_states=output_hidden_states,
1357
+ return_dict=return_dict,
1358
+ )
1359
+
1360
+ sequence_output = outputs[0]
1361
+
1362
+ logits = self.qa_outputs(sequence_output)
1363
+ start_logits, end_logits = logits.split(1, dim=-1)
1364
+ start_logits = start_logits.squeeze(-1).contiguous()
1365
+ end_logits = end_logits.squeeze(-1).contiguous()
1366
+
1367
+ total_loss = None
1368
+ if start_positions is not None and end_positions is not None:
1369
+ # If we are on multi-GPU, split add a dimension
1370
+ if len(start_positions.size()) > 1:
1371
+ start_positions = start_positions.squeeze(-1).to(start_logits.device)
1372
+ if len(end_positions.size()) > 1:
1373
+ end_positions = end_positions.squeeze(-1).to(end_logits.device)
1374
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1375
+ ignored_index = start_logits.size(1)
1376
+ start_positions = start_positions.clamp(0, ignored_index)
1377
+ end_positions = end_positions.clamp(0, ignored_index)
1378
+
1379
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1380
+ start_loss = loss_fct(start_logits, start_positions)
1381
+ end_loss = loss_fct(end_logits, end_positions)
1382
+ total_loss = (start_loss + end_loss) / 2
1383
+
1384
+ if not return_dict:
1385
+ output = (start_logits, end_logits) + outputs[2:]
1386
+ return ((total_loss,) + output) if total_loss is not None else output
1387
+
1388
+ return QuestionAnsweringModelOutput(
1389
+ loss=total_loss,
1390
+ start_logits=start_logits,
1391
+ end_logits=end_logits,
1392
+ hidden_states=outputs.hidden_states,
1393
+ attentions=outputs.attentions,
1394
+ )