Aekanun committed
Commit 4547d7b · 1 Parent(s): 29301f8

compiled Conv

unsloth_compiled_cache/Conv1d.py ADDED
@@ -0,0 +1,24 @@
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ from torch.nn import functional as F
+
+
+ def forward(self, input: Tensor) -> Tensor:
+     return self._conv_forward(input, self.weight, self.bias).to(input.dtype)
unsloth_compiled_cache/Conv2d.py ADDED
@@ -0,0 +1,24 @@
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ from torch.nn import functional as F
+
+
+ def forward(self, input: Tensor) -> Tensor:
+     return self._conv_forward(input, self.weight, self.bias).to(input.dtype)
unsloth_compiled_cache/Conv3d.py ADDED
@@ -0,0 +1,24 @@
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ from torch.nn import functional as F
+
+
+ def forward(self, input: Tensor) -> Tensor:
+     return self._conv_forward(input, self.weight, self.bias).to(input.dtype)
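Note (not part of the diff): the three Conv cache files above share one dtype-preserving forward, which runs the stock convolution and then casts the result back to the input's dtype. Below is a minimal sketch of how such a forward could be bound onto an existing nn.Conv2d instance; the types.MethodType patching is an assumption for illustration only, not necessarily how Unsloth's loader attaches the cached code.

import types
import torch
from torch import Tensor, nn

def forward(self, input: Tensor) -> Tensor:
    # Same body as the cached Conv modules: run the convolution, then cast
    # the result back to the input's dtype.
    return self._conv_forward(input, self.weight, self.bias).to(input.dtype)

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
conv.forward = types.MethodType(forward, conv)  # hypothetical patching step

x = torch.randn(1, 3, 16, 16)
print(conv(x).dtype)  # torch.float32 -- the output follows the input's dtype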
unsloth_compiled_cache/ConvTranspose1d.py ADDED
@@ -0,0 +1,36 @@
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ from torch.nn import functional as F
+ from transformers.models.mllama.modeling_mllama import (F, List, Optional, Tuple, nn)
+
+ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
+     if self.padding_mode != 'zeros':
+         raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')
+
+     assert isinstance(self.padding, tuple)
+     # One cannot replace List by Tuple or Sequence in "_output_padding" because
+     # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+     num_spatial_dims = 1
+     output_padding = self._output_padding(
+         input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
+         num_spatial_dims, self.dilation) # type: ignore[arg-type]
+     return F.conv_transpose1d(
+         input, self.weight, self.bias, self.stride, self.padding,
+         output_padding, self.groups, self.dilation).to(input.dtype)
unsloth_compiled_cache/ConvTranspose2d.py ADDED
@@ -0,0 +1,37 @@
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ from torch.nn import functional as F
+ from transformers.models.mllama.modeling_mllama import (F, List, Optional, Tuple, nn)
+
+ def forward(self, input: Tensor, output_size: Optional[List[int]] = None) -> Tensor:
+     if self.padding_mode != 'zeros':
+         raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')
+
+     assert isinstance(self.padding, tuple)
+     # One cannot replace List by Tuple or Sequence in "_output_padding" because
+     # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
+     num_spatial_dims = 2
+     output_padding = self._output_padding(
+         input, output_size, self.stride, self.padding, self.kernel_size, # type: ignore[arg-type]
+         num_spatial_dims, self.dilation) # type: ignore[arg-type]
+
+     return F.conv_transpose2d(
+         input, self.weight, self.bias, self.stride, self.padding,
+         output_padding, self.groups, self.dilation).to(input.dtype)
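Note (not part of the diff): the transposed-convolution forwards accept an extra output_size argument because, with stride > 1, several output shapes map to the same input shape; the module's _output_padding helper resolves that ambiguity. A short standalone illustration using the stock nn.ConvTranspose2d:

import torch
from torch import nn

# With stride=2 and kernel_size=3, a 5x5 input can legitimately produce either
# an 11x11 or a 12x12 output; output_size selects one via the computed output_padding.
deconv = nn.ConvTranspose2d(4, 4, kernel_size=3, stride=2)
x = torch.randn(1, 4, 5, 5)

print(deconv(x).shape)                        # torch.Size([1, 4, 11, 11])
print(deconv(x, output_size=(12, 12)).shape)  # torch.Size([1, 4, 12, 12])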
unsloth_compiled_cache/unsloth_compiled_module_mllama.py ADDED
@@ -0,0 +1,782 @@
+
+ # Unsloth Zoo - Utilities for Unsloth
+ # Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
+ #
+ # This program is free software: you can redistribute it and/or modify
+ # it under the terms of the GNU Lesser General Public License as published by
+ # the Free Software Foundation, either version 3 of the License, or
+ # (at your option) any later version.
+ #
+ # This program is distributed in the hope that it will be useful,
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ # GNU General Public License for more details.
+ #
+ # You should have received a copy of the GNU Lesser General Public License
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+ import torch
+ from unsloth_zoo.loss_utils import fused_linear_cross_entropy
+
+ scaled_dot_product_attention = torch.nn.functional.scaled_dot_product_attention
+ @torch.compiler.disable(recursive = False)
+ def disable_compile_scaled_dot_product_attention(*args, **kwargs):
+     return scaled_dot_product_attention(*args, **kwargs)
+ pass
+
+
+ torch_compile_options = {'epilogue_fusion': True, 'max_autotune': False, 'shape_padding': True, 'trace.enabled': False, 'triton.cudagraphs': False}
+ from torch import Tensor
+ import torch
+ from torch.nn import functional as F
+ from transformers.models.mllama.modeling_mllama import (F, math, Optional, Tuple, torch, nn, ACT2FN, Cache, ROPE_INIT_FUNCTIONS, MllamaTextConfig, MllamaVisionConfig)
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def _prepare_cross_attention_mask(cross_attention_mask: torch.Tensor,
+     num_vision_tokens: int,
+     dtype: str,) -> Tuple[torch.Tensor, torch.Tensor]:
+     # reshape so it can be used by attn module
+     batch_size, text_total_length, *_ = cross_attention_mask.shape
+     cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3)
+     cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
+     cross_attention_mask = cross_attention_mask.unsqueeze(1)
+
+     # invert the mask
+     inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
+     cross_attention_mask = inverted_cross_attn_mask.masked_fill(inverted_cross_attn_mask.to(torch.bool), torch.finfo(dtype).min)
+
+     # apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
+     # last dimension contains negative infinity values, otherwise it's 1
+     negative_inf_value = torch.finfo(dtype).min
+     full_text_row_masked_out_mask = ((cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None])
+     cross_attention_mask *= full_text_row_masked_out_mask
+
+     return cross_attention_mask!=torch.finfo(cross_attention_mask.dtype).min, full_text_row_masked_out_mask
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def _prepare_aspect_ratio_attention_mask(aspect_ratio_mask: torch.Tensor,
+     num_patches: int,
+     target_length: int,
+     dtype: torch.dtype,) -> torch.Tensor:
+     # Expand aspect ratio mask to target_length
+     batch_size, max_num_tiles = aspect_ratio_mask.shape
+     attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
+     attention_mask = attention_mask.repeat(1, 1, target_length, 1)
+
+     # Mask padding patches
+     pad_patches = target_length - num_patches
+     attention_mask[:, :, -pad_patches:] = 0
+
+     # Invert the mask (0 -> 1, 1 -> 0)
+     attention_mask = 1 - attention_mask
+
+     # Reshape to 2D and create 4D attention mask
+     # (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
+     attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1)
+     attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.finfo(dtype).min
+     attention_mask = attention_mask.unsqueeze(1)
+
+     return attention_mask!=torch.finfo(attention_mask.dtype).min
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def MllamaPrecomputedAspectRatioEmbedding_forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+     embeddings = self.embedding(aspect_ratio_ids)
+     embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
+
+     if self.is_gated:
+         embeddings = embeddings * self.gate.tanh()
+
+     hidden_state = hidden_state + embeddings
+     return hidden_state
+
+ class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
+     def __init__(self, config: MllamaVisionConfig, is_gated: bool = True):
+         super().__init__()
+         self.max_num_tiles = config.max_num_tiles
+         self.hidden_size = config.hidden_size
+         self.max_aspect_ratio_id = config.max_aspect_ratio_id
+         self.is_gated = is_gated
+
+         self.embedding = nn.Embedding(self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size)
+         if is_gated:
+             self.gate = nn.Parameter(torch.zeros(1))
+
+     def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+         return MllamaPrecomputedAspectRatioEmbedding_forward(self, hidden_state, aspect_ratio_ids)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def MllamaPrecomputedPositionEmbedding_forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+     # position embeddings
+     gated_position_embedding = (1 - self.gate.tanh()) * self.embedding
+     hidden_state = hidden_state + gated_position_embedding.view(1, 1, self.num_patches, self.hidden_size)
+
+     # precomputed tile position embeddings
+     tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
+     batch_size = hidden_state.shape[0]
+     tile_position_embedding = tile_position_embedding.reshape(
+         batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
+     )
+     gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
+     hidden_state = hidden_state + gated_tile_position_embedding
+
+     return hidden_state
+
+ class MllamaPrecomputedPositionEmbedding(nn.Module):
+     def __init__(self, config: MllamaVisionConfig):
+         super().__init__()
+         self.max_num_tiles = config.max_num_tiles
+         self.max_aspect_ratio_id = config.max_aspect_ratio_id
+         self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
+         self.hidden_size = config.hidden_size
+         self.scale = config.hidden_size**-0.5
+
+         self.gate = nn.Parameter(torch.zeros(1))
+
+         # position embedding
+         position_embedding = torch.randn(self.num_patches, self.hidden_size)
+         self.embedding = nn.Parameter(self.scale * position_embedding)
+
+         # tile position embedding
+         self.tile_embedding = nn.Embedding(
+             self.max_aspect_ratio_id + 1, self.max_num_tiles * self.num_patches * self.hidden_size
+         )
+
+     def forward(self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor:
+         return MllamaPrecomputedPositionEmbedding_forward(self, hidden_state, aspect_ratio_ids)
+
+
+ @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def MllamaVisionMLP_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+     hidden_states = self.fc1(hidden_states)
+     hidden_states = self.activation_fn(hidden_states)
+     hidden_states = self.fc2(hidden_states)
+     return hidden_states
+
+ class MllamaVisionMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.activation_fn = ACT2FN[config.hidden_act]
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         return MllamaVisionMLP_forward(self, hidden_states)
+
+
+ @torch.compiler.disable(recursive = False)
+ def MllamaVisionAttention_forward(
+     self,
+     hidden_state: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     output_attentions: bool = None,
+ ) -> torch.Tensor:
+     query = self.q_proj(hidden_state)
+     key = self.k_proj(hidden_state)
+     value = self.v_proj(hidden_state)
+
+     batch_size, q_seq_len, _ = query.shape
+     _, kv_seq_len, _ = key.shape
+
+     query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+     key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+     value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+     attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+     if attention_mask is not None: # no matter the length, we just slice it
+         causal_mask = attention_mask[:, :, :, : key.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     # upcast attention to fp32
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_output = torch.matmul(attn_weights, value)
+
+     attn_output = attn_output.transpose(1, 2).contiguous()
+     attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
+
+     output = self.o_proj(attn_output)
+
+     if not output_attentions:
+         attn_weights = None
+
+     return output, attn_weights
+
+ class MllamaVisionAttention(nn.Module):
+     def __init__(self, config: MllamaVisionConfig):
+         super().__init__()
+
+         self.embed_dim = config.hidden_size
+         self.num_heads = config.attention_heads
+         self.head_dim = config.hidden_size // config.attention_heads
+
+         self.q_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.embed_dim, self.num_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.embed_dim, bias=False)
+
+     def forward(
+         self,
+         hidden_state: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = None,
+     ) -> torch.Tensor:
+         return MllamaVisionAttention_forward(self, hidden_state, attention_mask, output_attentions)
+
+
+ @torch.compiler.disable(recursive = False)
+ def MllamaVisionSdpaAttention_forward(
+     self,
+     hidden_state: torch.Tensor,
+     attention_mask: Optional[torch.Tensor] = None,
+     output_attentions: bool = None,
+ ) -> torch.Tensor:
+     # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+     if output_attentions: raise RuntimeError('Unsloth: Not supported')
+
+     query = self.q_proj(hidden_state)
+     key = self.k_proj(hidden_state)
+     value = self.v_proj(hidden_state)
+
+     batch_size, q_seq_len, _ = query.shape
+     _, kv_seq_len, _ = key.shape
+
+     query = query.view(batch_size, q_seq_len, self.num_heads, self.head_dim)
+     key = key.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
+     value = value.view(batch_size, kv_seq_len, self.num_heads, self.head_dim)
+
+     query = query.transpose(1, 2)
+     key = key.transpose(1, 2)
+     value = value.transpose(1, 2)
+
+     attn_output = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
+
+     attn_output = attn_output.transpose(1, 2).contiguous()
+     attn_output = attn_output.reshape(batch_size, q_seq_len, -1)
+
+     output = self.o_proj(attn_output)
+
+     return output, None
+
+ class MllamaVisionSdpaAttention(MllamaVisionAttention):
+     # Adapted from MllamaVisionAttention
+     def forward(
+         self,
+         hidden_state: torch.Tensor,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = None,
+     ) -> torch.Tensor:
+         return MllamaVisionSdpaAttention_forward(self, hidden_state, attention_mask, output_attentions)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def MllamaTextRMSNorm_forward(self, hidden_states):
+     input_dtype = hidden_states.dtype
+     hidden_states = hidden_states.to(torch.float32)
+     variance = hidden_states.pow(2).mean(-1, keepdim=True)
+     hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+     return self.weight * hidden_states.to(input_dtype)
+
+ class MllamaTextRMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         """
+         MllamaTextRMSNorm is equivalent to T5LayerNorm
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         return MllamaTextRMSNorm_forward(self, hidden_states)
+
+     def extra_repr(self):
+         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+ @torch.compiler.disable(recursive = False)
+ def MllamaTextCrossAttention_forward(
+     self,
+     hidden_states: torch.Tensor,
+     cross_attention_states: Optional[torch.Tensor] = None,
+     past_key_value: Optional[Cache] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     output_attentions: bool = False,
+     use_cache: bool = None,
+     cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+     """Input shape: Batch x Time x Channel"""
+     bsz, q_len, _ = hidden_states.size()
+     query_states = self.q_proj(hidden_states)
+     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+     query_states = self.q_norm(query_states)
+
+     if cross_attention_states is not None:
+         key_states = self.k_proj(cross_attention_states)
+         value_states = self.v_proj(cross_attention_states)
+         key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         key_states = repeat_kv(key_states, self.num_key_value_groups)
+         value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+         key_states = self.k_norm(key_states)
+         if past_key_value is not None:
+             # if we have a new image + new tokens, we only computed key_states on that new image
+             # we still update the cross key states, past_image, new_image. And use it!
+             key_states, value_states = past_key_value.update(
+                 key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+             )
+     elif cache_position[0] != 0:
+         key_states, value_states = (
+             past_key_value.key_cache[self.layer_idx],
+             past_key_value.value_cache[self.layer_idx],
+         )
+     else:
+         raise ValueError(
+             "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
+         )
+
+     attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+     if attention_mask is not None: # no matter the length, we just slice it
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+     attn_output = attn_output.reshape(bsz, q_len, -1)
+     attn_output = self.o_proj(attn_output)
+
+     if not output_attentions:
+         attn_weights = None
+
+     return attn_output, attn_weights, past_key_value
+
+ class MllamaTextCrossAttention(nn.Module):
+     """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+     def __init__(
+         self,
+         config: Optional[MllamaTextConfig] = None,
+         layer_idx: Optional[int] = None,
+     ):
+         super().__init__()
+         self.config = config
+         self.num_heads = self.config.num_attention_heads
+         self.num_key_value_heads = self.config.num_key_value_heads
+         self.dropout = config.dropout
+         self.hidden_size = config.hidden_size
+         self.head_dim = config.hidden_size // self.num_heads
+         self.layer_idx = layer_idx
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+         self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+         self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cross_attention_states: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Cache] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+         use_cache: bool = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         return MllamaTextCrossAttention_forward(self, hidden_states, cross_attention_states, past_key_value, attention_mask, output_attentions, use_cache, cache_position)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
+ @torch.compiler.disable(recursive = False)
+ def MllamaTextCrossSdpaAttention_forward(
+     self,
+     hidden_states: torch.Tensor,
+     cross_attention_states: Optional[torch.Tensor] = None,
+     past_key_value: Optional[Cache] = None,
+     attention_mask: Optional[torch.Tensor] = None,
+     output_attentions: bool = False,
+     use_cache: bool = None,
+     cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+     """Input shape: Batch x Time x Channel"""
+     if output_attentions: raise RuntimeError('Unsloth: Not supported')
+
+     bsz, q_len, _ = hidden_states.size()
+     query_states = self.q_proj(hidden_states)
+     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+     query_states = self.q_norm(query_states)
+
+     if cross_attention_states is not None:
+         key_states = self.k_proj(cross_attention_states)
+         value_states = self.v_proj(cross_attention_states)
+         key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+         value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+         if past_key_value is not None:
+             # if we have a new image + new tokens, we only computed key_states on that new image
+             # we still update the cross key states, past_image, new_image. And use it!
+             key_states, value_states = past_key_value.update(
+                 key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+             )
+     elif cache_position[0] != 0:
+         key_states, value_states = (
+             past_key_value.key_cache[self.layer_idx],
+             past_key_value.value_cache[self.layer_idx],
+         )
+     else:
+         raise ValueError(
+             "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
+         )
+
+     key_states = repeat_kv(key_states, self.num_key_value_groups)
+     value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+     key_states = self.k_norm(key_states)
+
+     # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+     # Reference: https://github.com/pytorch/pytorch/issues/112577.
+     if query_states.device.type == "cuda" and attention_mask is not None:
+         query_states = query_states.contiguous()
+         key_states = key_states.contiguous()
+         value_states = value_states.contiguous()
+
+     # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+     # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+     is_causal = True if attention_mask is None and q_len > 1 else False
+
+     attn_output = torch.nn.functional.scaled_dot_product_attention(
+         query_states,
+         key_states,
+         value_states,
+         attn_mask=attention_mask,
+         dropout_p=self.dropout if self.training else 0.0,
+         is_causal=is_causal,
+     )
+
+     attn_output = attn_output.transpose(1, 2).contiguous()
+     attn_output = attn_output.reshape(bsz, q_len, -1)
+     attn_output = self.o_proj(attn_output)
+
+     return attn_output, None, past_key_value
+
+ class MllamaTextCrossSdpaAttention(MllamaTextCrossAttention):
+     """
+     Mllama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
+     `MllamaTextCrossAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
+     SDPA API.
+     """
+
+     # Adapted from MllamaTextCrossAttention.forward
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         cross_attention_states: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Cache] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+         use_cache: bool = None,
+         cache_position: Optional[torch.LongTensor] = None,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+         return MllamaTextCrossSdpaAttention_forward(self, hidden_states, cross_attention_states, past_key_value, attention_mask, output_attentions, use_cache, cache_position)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def rotate_half(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2 :]
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+     """Applies Rotary Position Embedding to the query and key tensors.
+
+     Args:
+         q (`torch.Tensor`): The query tensor.
+         k (`torch.Tensor`): The key tensor.
+         cos (`torch.Tensor`): The cosine part of the rotary embedding.
+         sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
+         unsqueeze_dim (`int`, *optional*, defaults to 1):
+             The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
+             sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
+             that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
+             k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
+             cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
+             the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
+     Returns:
+         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+     """
+     cos = cos.unsqueeze(unsqueeze_dim)
+     sin = sin.unsqueeze(unsqueeze_dim)
+     q_embed = (q * cos) + (rotate_half(q) * sin)
+     k_embed = (k * cos) + (rotate_half(k) * sin)
+     return q_embed, k_embed
+
+
+ @torch.compiler.disable(recursive = False)
+ def MllamaTextSelfAttention_forward(
+     self,
+     hidden_states: torch.Tensor,
+     attention_mask: torch.Tensor,
+     position_embeddings: torch.Tensor,
+     output_attentions: bool = False,
+     use_cache: bool = False,
+     past_key_value=None,
+     cache_position=None,
+     **kwargs,
+ ):
+     bsz, q_len, _ = hidden_states.size()
+
+     query_states = self.q_proj(hidden_states)
+     key_states = self.k_proj(hidden_states)
+     value_states = self.v_proj(hidden_states)
+
+     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+     cos, sin = position_embeddings
+     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+     if past_key_value is not None:
+         # sin and cos are specific to RoPE models; cache_position needed for the static cache
+         cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+     key_states = repeat_kv(key_states, self.num_key_value_groups)
+     value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+     attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+     if attention_mask is not None: # no matter the length, we just slice it
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     # upcast attention to fp32
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+
+     attn_output = attn_output.transpose(1, 2).contiguous()
+     attn_output = attn_output.view(bsz, q_len, -1)
+
+     attn_output = self.o_proj(attn_output)
+
+     if not output_attentions:
+         attn_weights = None
+
+     return attn_output, attn_weights, past_key_value
+
+ class MllamaTextSelfAttention(nn.Module):
+     def __init__(self, config: MllamaTextConfig, layer_idx: int):
+         super().__init__()
+         self.config = config
+         self.num_heads = config.num_attention_heads
+         self.dropout = config.dropout
+         self.hidden_size = config.hidden_size
+         self.num_key_value_heads = config.num_key_value_heads
+         self.head_dim = config.hidden_size // self.num_heads
+         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+         self.rope_theta = config.rope_theta
+         self.layer_idx = layer_idx
+
+         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
+         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
+         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         position_embeddings: torch.Tensor,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         past_key_value=None,
+         cache_position=None,
+         **kwargs,
+     ):
+         return MllamaTextSelfAttention_forward(self, hidden_states, attention_mask, position_embeddings, output_attentions, use_cache, past_key_value, cache_position, **kwargs)
+
+
+ @torch.compiler.disable(recursive = False)
+ def MllamaTextSelfSdpaAttention_forward(
+     self,
+     hidden_states: torch.Tensor,
+     attention_mask: torch.Tensor,
+     position_embeddings: torch.Tensor,
+     output_attentions: bool = False,
+     use_cache: bool = False,
+     past_key_value=None,
+     cache_position=None,
+     **kwargs,
+ ):
+     if output_attentions: raise RuntimeError('Unsloth: Not supported')
+
+     bsz, q_len, _ = hidden_states.size()
+
+     query_states = self.q_proj(hidden_states)
+     key_states = self.k_proj(hidden_states)
+     value_states = self.v_proj(hidden_states)
+
+     query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+     key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+     value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+     cos, sin = position_embeddings
+     query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+     if past_key_value is not None:
+         # sin and cos are specific to RoPE models; cache_position needed for the static cache
+         cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+         key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+     key_states = repeat_kv(key_states, self.num_key_value_groups)
+     value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+     causal_mask = attention_mask
+     if attention_mask is not None:
+         causal_mask = causal_mask[:, :, :, : key_states.shape[-2]]
+
+     # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+     # Reference: https://github.com/pytorch/pytorch/issues/112577.
+     if query_states.device.type == "cuda" and causal_mask is not None:
+         query_states = query_states.contiguous()
+         key_states = key_states.contiguous()
+         value_states = value_states.contiguous()
+
+     # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
+     # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
+     is_causal = True if causal_mask is None and q_len > 1 else False
+
+     attn_output = torch.nn.functional.scaled_dot_product_attention(
+         query_states,
+         key_states,
+         value_states,
+         attn_mask=causal_mask,
+         dropout_p=self.dropout if self.training else 0.0,
+         is_causal=is_causal,
+     )
+
+     attn_output = attn_output.transpose(1, 2).contiguous()
+     attn_output = attn_output.view(bsz, q_len, -1)
+
+     attn_output = self.o_proj(attn_output)
+     return attn_output, None, past_key_value
+
+ class MllamaTextSelfSdpaAttention(MllamaTextSelfAttention):
+     # Adapted from MllamaTextSelfAttention
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         position_embeddings: torch.Tensor,
+         output_attentions: bool = False,
+         use_cache: bool = False,
+         past_key_value=None,
+         cache_position=None,
+         **kwargs,
+     ):
+         return MllamaTextSelfSdpaAttention_forward(self, hidden_states, attention_mask, position_embeddings, output_attentions, use_cache, past_key_value, cache_position, **kwargs)
+
+
+ @torch.compile(fullgraph = False, dynamic = True, options = torch_compile_options)
+ def MllamaTextMLP_forward(self, x):
+     return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+
+ class MllamaTextMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+         # Ignore copy
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, x):
+         return MllamaTextMLP_forward(self, x)
+
+
+ @torch.compile(fullgraph = True, dynamic = True, options = torch_compile_options)
+ @torch.no_grad()
+ def MllamaRotaryEmbedding_forward(self, x, position_ids):
+     if "dynamic" in self.rope_type:
+         self._dynamic_frequency_update(position_ids, device=x.device)
+
+     # Core RoPE block
+     inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+     position_ids_expanded = position_ids[:, None, :].float()
+     # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
+     device_type = x.device.type
+     device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+     with torch.autocast(device_type=device_type, enabled=False):
+         freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+         emb = torch.cat((freqs, freqs), dim=-1)
+         cos = emb.cos()
+         sin = emb.sin()
+
+     # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+     cos = cos * self.attention_scaling
+     sin = sin * self.attention_scaling
+
+     return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+ class MllamaRotaryEmbedding(nn.Module):
+     def __init__(self, config: MllamaTextConfig, device=None):
+         super().__init__()
+         self.rope_type = config.rope_scaling["rope_type"]
+         self.max_seq_len_cached = config.max_position_embeddings
+         self.original_max_seq_len = config.max_position_embeddings
+
+         self.config = config
+         self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+         inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+         self.register_buffer("inv_freq", inv_freq, persistent=False)
+         self.original_inv_freq = self.inv_freq
+
+     def _dynamic_frequency_update(self, position_ids, device):
+         """
+         dynamic RoPE layers should recompute `inv_freq` in the following situations:
+         1 - growing beyond the cached sequence length (allow scaling)
+         2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+         """
+         seq_len = torch.max(position_ids) + 1
+         if seq_len > self.max_seq_len_cached: # growth
+             inv_freq, self.attention_scaling = self.rope_init_fn(
+                 self.config, device, seq_len=seq_len, **self.rope_kwargs
+             )
+             self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
+             self.max_seq_len_cached = seq_len
+
+         if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
+             self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+             self.max_seq_len_cached = self.original_max_seq_len
+
+
+     def forward(self, x, position_ids):
+         return MllamaRotaryEmbedding_forward(self, x, position_ids)
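Note (not part of the diff): repeat_kv's docstring above claims that the expand-and-reshape is equivalent to torch.repeat_interleave along the key/value head dimension. A small standalone sketch that checks this claim; the helper is copied from the cached module, minus the torch.compile decorator.

import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 4, 7, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
assert torch.equal(repeat_kv(kv, 3), kv.repeat_interleave(3, dim=1))
print(repeat_kv(kv, 3).shape)  # torch.Size([2, 12, 7, 16])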