QuietImpostor committed
Commit e187c98
1 Parent(s): 3b22110

Upload code to run Rasphi

Files changed (2):
  1. configuration_rasphi.py +137 -0
  2. modeling_rasphi.py +908 -0
configuration_rasphi.py ADDED
from transformers.configuration_utils import PretrainedConfig


class RasphiConfig(PretrainedConfig):
    model_type = "rasphi"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32064,
        hidden_size=4096,
        intermediate_size=6400,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=1e6,
        rope_scaling=None,
        sliding_window=None,
        attention_dropout=0.0,
        num_experts_per_tok=2,
        num_local_experts=16,
        output_router_logits=False,
        router_aux_loss_coef=0.001,
        router_jitter_noise=0.01,
        input_jitter_noise=0.0,
        attention_bias=False,
        lm_head_bias=False,
        # Rasphi specific configurations
        num_reasoning_experts=8,  # Number of experts dedicated to the reasoning stream
        num_content_experts=8,  # Number of experts dedicated to the content stream
        reasoning_hidden_size=2048,  # Hidden size of the reasoning stream
        content_hidden_size=2048,  # Hidden size of the content stream
        stream_interaction="attention",  # How the two streams interact: "attention", "mlp", or "both"
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.attention_bias = attention_bias
        self.lm_head_bias = lm_head_bias
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.num_experts_per_tok = num_experts_per_tok
        self.num_local_experts = num_local_experts
        self.output_router_logits = output_router_logits
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_jitter_noise = router_jitter_noise
        self.input_jitter_noise = input_jitter_noise
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()

        # Rasphi specific configurations
        self.num_reasoning_experts = num_reasoning_experts
        self.num_content_experts = num_content_experts
        self.reasoning_hidden_size = reasoning_hidden_size
        self.content_hidden_size = content_hidden_size
        self.stream_interaction = stream_interaction

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """
        Validate the `rope_scaling` configuration.
        """
        if self.rope_scaling is None:
            return

        if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 6:
            raise ValueError(
                "`rope_scaling` must be a dictionary with six fields, `type`, `short_factor`, `long_factor`, "
                f"`short_mscale`, `long_mscale` and `original_max_position_embeddings`, got {self.rope_scaling}"
            )
        rope_scaling_type = self.rope_scaling.get("type", None)
        rope_scaling_short_factor = self.rope_scaling.get("short_factor", None)
        rope_scaling_long_factor = self.rope_scaling.get("long_factor", None)
        rope_scaling_short_mscale = self.rope_scaling.get("short_mscale", None)
        rope_scaling_long_mscale = self.rope_scaling.get("long_mscale", None)
        original_max_position_embeddings = self.rope_scaling.get("original_max_position_embeddings", None)
        if rope_scaling_type is None or rope_scaling_type not in ["longrope"]:
            raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}")
        if not (
            isinstance(rope_scaling_short_factor, list)
            and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor)
        ):
            raise ValueError(
                f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}"
            )
        if not len(rope_scaling_short_factor) == self.hidden_size // self.num_attention_heads // 2:
            raise ValueError(
                f"`rope_scaling`'s short_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_short_factor)}"
            )
        if not (
            isinstance(rope_scaling_long_factor, list)
            and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor)
        ):
            raise ValueError(
                f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}"
            )
        if not len(rope_scaling_long_factor) == self.hidden_size // self.num_attention_heads // 2:
            raise ValueError(
                f"`rope_scaling`'s long_factor field must have length {self.hidden_size // self.num_attention_heads // 2}, got {len(rope_scaling_long_factor)}"
            )
        if not isinstance(rope_scaling_short_mscale, (int, float)):
            raise ValueError(
                f"`rope_scaling`'s short_mscale field must be a number, got {rope_scaling_short_mscale}"
            )
        if not isinstance(rope_scaling_long_mscale, (int, float)):
            raise ValueError(
                f"`rope_scaling`'s long_mscale field must be a number, got {rope_scaling_long_mscale}"
            )
        if not isinstance(original_max_position_embeddings, int):
            raise ValueError(
                f"`rope_scaling`'s original_max_position_embeddings field must be an integer, got {original_max_position_embeddings}"
            )
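A minimal sketch of how this configuration might be constructed locally (the flat import path and the overridden values below are assumptions for illustration, not settings from a published checkpoint):

from configuration_rasphi import RasphiConfig

# Two-stream Rasphi config: separate expert pools and hidden sizes for the
# reasoning and content streams, which can interact via attention, MLPs, or both.
config = RasphiConfig(
    num_reasoning_experts=8,
    num_content_experts=8,
    reasoning_hidden_size=2048,
    content_hidden_size=2048,
    stream_interaction="both",
)
print(config.model_type, config.num_experts_per_tok)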
modeling_rasphi.py ADDED
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
from transformers.modeling_utils import PreTrainedModel

try:
    from .configuration_rasphi import RasphiConfig  # when loaded as part of a package (e.g. remote code)
except ImportError:
    from configuration_rasphi import RasphiConfig  # when run from a flat directory

ACT2FN = {
    "relu": F.relu,
    "silu": F.silu,
    "gelu": F.gelu,
    "tanh": torch.tanh,
    "sigmoid": torch.sigmoid,
}

class RasphiDecoderLayer(nn.Module):
    def __init__(self, config: RasphiConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.reasoning_hidden_size = config.reasoning_hidden_size
        self.content_hidden_size = config.content_hidden_size

        # Attention layers
        self.reasoning_self_attn = RasphiAttention(config, self.reasoning_hidden_size, layer_idx)
        self.content_self_attn = RasphiAttention(config, self.content_hidden_size, layer_idx)

        # MoE layers
        self.reasoning_moe = RasphiSparseMoeBlock(config, is_reasoning=True)
        self.content_moe = RasphiSparseMoeBlock(config, is_reasoning=False)

        # Layer norms
        self.reasoning_input_layernorm = nn.LayerNorm(self.reasoning_hidden_size, eps=config.rms_norm_eps)
        self.reasoning_post_attention_layernorm = nn.LayerNorm(self.reasoning_hidden_size, eps=config.rms_norm_eps)
        self.content_input_layernorm = nn.LayerNorm(self.content_hidden_size, eps=config.rms_norm_eps)
        self.content_post_attention_layernorm = nn.LayerNorm(self.content_hidden_size, eps=config.rms_norm_eps)

        # Stream interaction
        self.stream_interaction = config.stream_interaction
        if self.stream_interaction in ["attention", "both"]:
            self.reasoning_to_content_attn = RasphiAttention(config, self.content_hidden_size, layer_idx)
            self.content_to_reasoning_attn = RasphiAttention(config, self.reasoning_hidden_size, layer_idx)
        if self.stream_interaction in ["mlp", "both"]:
            self.reasoning_to_content_mlp = nn.Linear(self.reasoning_hidden_size, self.content_hidden_size)
            self.content_to_reasoning_mlp = nn.Linear(self.content_hidden_size, self.reasoning_hidden_size)

    def forward(
        self,
        reasoning_hidden_states: torch.Tensor,
        content_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        output_router_logits: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, ...]:
        # Self attention for both streams
        reasoning_residual = reasoning_hidden_states
        content_residual = content_hidden_states

        reasoning_hidden_states = self.reasoning_input_layernorm(reasoning_hidden_states)
        content_hidden_states = self.content_input_layernorm(content_hidden_states)

        reasoning_self_attn_output, reasoning_self_attn_weights, reasoning_present_key_value = self.reasoning_self_attn(
            hidden_states=reasoning_hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value[0] if past_key_value is not None else None,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        content_self_attn_output, content_self_attn_weights, content_present_key_value = self.content_self_attn(
            hidden_states=content_hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value[1] if past_key_value is not None else None,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )

        reasoning_hidden_states = reasoning_residual + reasoning_self_attn_output
        content_hidden_states = content_residual + content_self_attn_output

        # Stream interaction
        if self.stream_interaction in ["attention", "both"]:
            reasoning_to_content, _, _ = self.reasoning_to_content_attn(
                hidden_states=content_hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
                key_value_states=reasoning_hidden_states,
            )
            content_to_reasoning, _, _ = self.content_to_reasoning_attn(
                hidden_states=reasoning_hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=None,
                output_attentions=False,
                use_cache=False,
                key_value_states=content_hidden_states,
            )
            reasoning_hidden_states = reasoning_hidden_states + content_to_reasoning
            content_hidden_states = content_hidden_states + reasoning_to_content

        if self.stream_interaction in ["mlp", "both"]:
            reasoning_to_content = self.reasoning_to_content_mlp(reasoning_hidden_states)
            content_to_reasoning = self.content_to_reasoning_mlp(content_hidden_states)
            reasoning_hidden_states = reasoning_hidden_states + content_to_reasoning
            content_hidden_states = content_hidden_states + reasoning_to_content

        # MoE for both streams
        reasoning_residual = reasoning_hidden_states
        content_residual = content_hidden_states

        reasoning_hidden_states = self.reasoning_post_attention_layernorm(reasoning_hidden_states)
        content_hidden_states = self.content_post_attention_layernorm(content_hidden_states)

        reasoning_moe_output, reasoning_router_logits = self.reasoning_moe(reasoning_hidden_states)
        content_moe_output, content_router_logits = self.content_moe(content_hidden_states)

        reasoning_hidden_states = reasoning_residual + reasoning_moe_output
        content_hidden_states = content_residual + content_moe_output

        outputs = (reasoning_hidden_states, content_hidden_states)

        if use_cache:
            outputs += ((reasoning_present_key_value, content_present_key_value),)
        if output_attentions:
            outputs += (reasoning_self_attn_weights, content_self_attn_weights)
        if output_router_logits:
            outputs += (reasoning_router_logits, content_router_logits)

        return outputs

class RasphiModel(PreTrainedModel):
    config_class = RasphiConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["RasphiDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config: RasphiConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.reasoning_embed_tokens = nn.Embedding(config.vocab_size, config.reasoning_hidden_size, self.padding_idx)
        self.content_embed_tokens = nn.Embedding(config.vocab_size, config.content_hidden_size, self.padding_idx)

        self.layers = nn.ModuleList(
            [RasphiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )

        self.reasoning_norm = nn.LayerNorm(config.reasoning_hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)
        self.content_norm = nn.LayerNorm(config.content_hidden_size, eps=config.rms_norm_eps, elementwise_affine=True)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return (self.reasoning_embed_tokens, self.content_embed_tokens)

    def set_input_embeddings(self, value):
        self.reasoning_embed_tokens = value[0]
        self.content_embed_tokens = value[1]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_router_logits = (
            output_router_logits if output_router_logits is not None else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            reasoning_inputs_embeds = self.reasoning_embed_tokens(input_ids)
            content_inputs_embeds = self.content_embed_tokens(input_ids)
        else:
            # Externally provided embeddings are expected to hold the reasoning and content
            # parts concatenated along the last dimension.
            reasoning_inputs_embeds = inputs_embeds[:, :, : self.config.reasoning_hidden_size]
            content_inputs_embeds = inputs_embeds[:, :, self.config.reasoning_hidden_size :]

        reasoning_hidden_states = reasoning_inputs_embeds
        content_hidden_states = content_inputs_embeds

        # decoder layers
        all_reasoning_hidden_states = () if output_hidden_states else None
        all_content_hidden_states = () if output_hidden_states else None
        all_reasoning_self_attns = () if output_attentions else None
        all_content_self_attns = () if output_attentions else None
        all_reasoning_router_logits = () if output_router_logits else None
        all_content_router_logits = () if output_router_logits else None
        next_decoder_cache = () if use_cache else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_reasoning_hidden_states += (reasoning_hidden_states,)
                all_content_hidden_states += (content_hidden_states,)

            layer_outputs = decoder_layer(
                reasoning_hidden_states,
                content_hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values[decoder_layer.layer_idx] if past_key_values is not None else None,
                output_attentions=output_attentions,
                output_router_logits=output_router_logits,
                use_cache=use_cache,
            )

            reasoning_hidden_states = layer_outputs[0]
            content_hidden_states = layer_outputs[1]

            # Optional layer outputs follow the two hidden states in a fixed order:
            # cache (if use_cache), attentions (if output_attentions), router logits (if output_router_logits).
            next_index = 2
            if use_cache:
                next_decoder_cache += (layer_outputs[next_index],)
                next_index += 1

            if output_attentions:
                all_reasoning_self_attns += (layer_outputs[next_index],)
                all_content_self_attns += (layer_outputs[next_index + 1],)

            if output_router_logits:
                all_reasoning_router_logits += (layer_outputs[-2],)
                all_content_router_logits += (layer_outputs[-1],)

        reasoning_hidden_states = self.reasoning_norm(reasoning_hidden_states)
        content_hidden_states = self.content_norm(content_hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_reasoning_hidden_states += (reasoning_hidden_states,)
            all_content_hidden_states += (content_hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(
                v
                for v in [
                    reasoning_hidden_states,
                    content_hidden_states,
                    next_cache,
                    all_reasoning_hidden_states,
                    all_content_hidden_states,
                    all_reasoning_self_attns,
                    all_content_self_attns,
                    all_reasoning_router_logits,
                    all_content_router_logits,
                ]
                if v is not None
            )

        return MoeModelOutputWithPast(
            last_hidden_state=(reasoning_hidden_states, content_hidden_states),
            past_key_values=next_cache,
            hidden_states=(all_reasoning_hidden_states, all_content_hidden_states),
            attentions=(all_reasoning_self_attns, all_content_self_attns),
            router_logits=(all_reasoning_router_logits, all_content_router_logits),
        )

class RasphiSparseMoeBlock(nn.Module):
    def __init__(self, config: RasphiConfig, is_reasoning: bool):
        super().__init__()
        self.hidden_dim = config.reasoning_hidden_size if is_reasoning else config.content_hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_reasoning_experts if is_reasoning else config.num_content_experts
        self.top_k = config.num_experts_per_tok

        # gating
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)

        self.experts = nn.ModuleList([RasphiBlockSparseTop2MLP(config, is_reasoning) for _ in range(self.num_experts)])

        # Jitter parameters
        self.router_jitter_noise = config.router_jitter_noise
        self.input_jitter_noise = config.input_jitter_noise

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.input_jitter_noise > 0:
            hidden_states *= torch.empty_like(hidden_states).uniform_(
                1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise
            )
        hidden_states = hidden_states.view(-1, hidden_dim)

        router_logits = self.gate(hidden_states)

        routing_weights, selected_experts = sparsemixer(
            router_logits,
            top_k=self.top_k,
            jitter_eps=self.router_jitter_noise,
            training=self.training,
        )

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

            if top_x.shape[0] == 0:
                continue

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
            current_state = hidden_states[None, top_x.tolist()].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x.tolist(), idx.tolist(), None]

            # Add the expert output to the final hidden states
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))

        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits


class RasphiBlockSparseTop2MLP(nn.Module):
    def __init__(self, config: RasphiConfig, is_reasoning: bool):
        super().__init__()
        self.ffn_dim = config.intermediate_size
        self.hidden_dim = config.reasoning_hidden_size if is_reasoning else config.content_hidden_size

        self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
        self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
        self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)

        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_states):
        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states)
        current_hidden_states = self.w2(current_hidden_states)
        return current_hidden_states

class RasphiPreTrainedModel(PreTrainedModel):
    config_class = RasphiConfig
    base_model_prefix = "rasphi"
    supports_gradient_checkpointing = True
    _no_split_modules = ["RasphiDecoderLayer"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class RasphiForCausalLM(RasphiPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = RasphiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.content_hidden_size, config.vocab_size, bias=config.lm_head_bias)
        self.router_aux_loss_coef = config.router_aux_loss_coef
        self.num_experts = config.num_content_experts  # We use content experts for language modeling
        self.num_experts_per_tok = config.num_experts_per_tok

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.get_input_embeddings()[1]  # Return content embeddings

    def set_input_embeddings(self, value):
        self.model.set_input_embeddings((self.model.get_input_embeddings()[0], value))

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only keep the last token of input_ids if a past is provided
        if past_key_values:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        content_hidden_states = hidden_states[1]  # Use content stream for language modeling
        logits = self.lm_head(content_hidden_states)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        aux_loss = None
        if output_router_logits:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits[1] if return_dict else outputs[-1][1],  # Use content stream router logits
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
            if labels is not None:
                loss += self.router_aux_loss_coef * aux_loss

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return MoeCausalLMOutputWithPast(
            loss=loss,
            aux_loss=aux_loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            router_logits=outputs.router_logits,
        )

    @staticmethod
    def _reorder_cache(past, beam_idx):
        return tuple(
            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
            for layer_past in past
        )


# --- Model > Rasphi changes start ---
class RasphiAttention(nn.Module):
    def __init__(self, config: RasphiConfig, hidden_size: int, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)

        if getattr(config, "rope_scaling", None) is None:
            self.rotary_emb = RasphiMoERotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
        else:
            scaling_type = self.config.rope_scaling["type"]
            if scaling_type == "linear":
                self.rotary_emb = LinearScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=self.config.rope_scaling["factor"],
                    base=self.rope_theta,
                )
            elif scaling_type == "dynamic":
                self.rotary_emb = DynamicNTKScalingRotaryEmbedding(
                    self.head_dim,
                    max_position_embeddings=self.max_position_embeddings,
                    scaling_factor=self.config.rope_scaling["factor"],
                    base=self.rope_theta,
                )
            else:
                raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        key_value_states: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)

        if key_value_states is None:
            # self-attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)
        else:
            # cross-attention between the two streams (assumes both streams share the same hidden size)
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse cached k, v for self-attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # expects an additive mask broadcastable to (bsz, num_heads, q_len, kv_seq_len)
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class mp(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        scores: torch.Tensor,
        multiplier: torch.Tensor,
        selected_experts: torch.Tensor,
        masked_gates: torch.Tensor,
        mask_for_one: torch.Tensor,
    ):
        ctx.save_for_backward(multiplier, selected_experts, masked_gates)
        return multiplier * mask_for_one

    @staticmethod
    def backward(
        ctx,
        grad_at_output: torch.Tensor,
    ):
        multiplier, selected_experts, masked_gates = ctx.saved_tensors

        grad_at_output = grad_at_output * multiplier

        grad_at_scores_expanded = masked_gates * grad_at_output.mul(-1)
        grad_at_scores_expanded.scatter_add_(
            dim=-1,
            index=selected_experts,
            src=grad_at_output,
        )

        return (
            grad_at_scores_expanded,
            None,
            None,
            None,
            None,
        )


def sparsemixer(scores, top_k, jitter_eps, training):
    assert top_k == 2

    ################ first expert ################

    with torch.no_grad():
        # compute mask for sparsity
        mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True)
        factor = scores.abs().clamp(min=mask_logits_threshold)
        mask_logits_threshold = (
            (mask_logits_threshold - scores) / factor
        ) > (2 * jitter_eps)

    # apply mask
    masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf"))
    if training:
        selected_experts = (
            masked_gates - torch.empty_like(masked_gates, memory_format=torch.legacy_contiguous_format).exponential_().log()
        ).max(dim=-1)[1].unsqueeze(-1)  # gumbel sampling, more robust than the multinomial method
    else:
        selected_experts = max_ind

    # compute scores for gradients
    masked_gates = torch.softmax(masked_gates, dim=-1)
    multiplier_o = masked_gates.gather(dim=-1, index=selected_experts)

    if training:
        # compute midpoint mask
        max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True)
        mask_for_one = torch.logical_or(
            selected_experts == max_ind,
            torch.rand_like(max_scores) > 0.75,  # Heun's third-order method: f(x) - f(0) = .25 f'(x) + .75 f'(x/3.)
        )
        # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
        mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates)

        multiplier = mp.apply(
            scores,
            multiplier_o,
            selected_experts,
            masked_gates,
            mask_for_one,
        )
    else:
        multiplier = multiplier_o

    # mask out the first expert
    masked_scores = torch.scatter(
        scores,
        -1,
        selected_experts,
        float("-inf"),
    )
    with torch.no_grad():
        # compute mask for sparsity
        mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True)
        factor = scores.abs().clamp(min=mask_logits_threshold)
        mask_logits_threshold = (
            (mask_logits_threshold - scores) / factor
        ) > (2 * jitter_eps)

    # apply mask
    masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float("-inf"))
    if training:
        selected_experts_top2 = (
            masked_gates_top2 - torch.empty_like(masked_gates_top2, memory_format=torch.legacy_contiguous_format).exponential_().log()
        ).max(dim=-1)[1].unsqueeze(-1)  # gumbel sampling, more robust than the multinomial method
    else:
        selected_experts_top2 = max_ind
    # compute scores for gradients
    masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1)
    multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2)

    if training:
        # compute midpoint mask
        max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True)
        mask_for_one_top2 = torch.logical_or(
            selected_experts_top2 == max_ind,
            torch.rand_like(max_scores) > 0.75,  # Heun's third-order method: f(x) - f(0) = .25 f'(x) + .75 f'(x/3.)
        )
        # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5
        mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2)

        multiplier_top2 = mp.apply(
            scores,
            multiplier_top2_o,
            selected_experts_top2,
            masked_gates_top2,
            mask_for_one_top2,
        )
    else:
        multiplier_top2 = multiplier_top2_o

    multiplier = torch.concat((multiplier, multiplier_top2), dim=-1)
    selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1)

    return (
        multiplier,
        selected_experts,
    )

def load_balancing_loss_func(
    gate_logits: torch.Tensor, num_experts: Optional[int] = None, top_k=2, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    if gate_logits is None or not isinstance(gate_logits, tuple):
        return 0

    if isinstance(gate_logits, tuple):
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0)

    routing_weights = F.softmax(concatenated_gate_logits, dim=-1)

    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

    expert_mask = F.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)

        # Compute the mask that masks all padding tokens as 0 with the same shape as expert_mask
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )

        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(
            expert_attention_mask, dim=0
        )

        # Compute the mask that masks all padding tokens as 0 with the same shape as tokens_per_expert
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )

        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(
            router_per_expert_attention_mask, dim=0
        )

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts

class RasphiMoERotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq)
        self._set_cos_sin_cache(max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype())

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


class LinearScalingRotaryEmbedding(RasphiMoERotaryEmbedding):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)


class DynamicNTKScalingRotaryEmbedding(RasphiMoERotaryEmbedding):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1).to(device)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


from transformers import AutoConfig, AutoModelForCausalLM

# Register the custom config/model pair with the Auto classes so that
# AutoModelForCausalLM can resolve the "rasphi" model type.
AutoConfig.register("rasphi", RasphiConfig)
AutoModelForCausalLM.register(RasphiConfig, RasphiForCausalLM)
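With the Auto classes registered above, a local smoke test might look like the sketch below; the tiny dimensions and the random initialization are assumptions made purely to exercise the forward pass, and no pretrained Rasphi checkpoint is implied:

import torch
from transformers import AutoModelForCausalLM

from configuration_rasphi import RasphiConfig
import modeling_rasphi  # importing the module runs the register() calls above

# A deliberately small, randomly initialized two-stream model.
config = RasphiConfig(
    num_hidden_layers=2,
    hidden_size=512,
    reasoning_hidden_size=256,
    content_hidden_size=256,
    intermediate_size=1024,
    num_attention_heads=8,
    num_key_value_heads=4,
    max_position_embeddings=2048,
)
model = AutoModelForCausalLM.from_config(config)

input_ids = torch.randint(0, config.vocab_size, (1, 16))
position_ids = torch.arange(16).unsqueeze(0)
with torch.no_grad():
    out = model(input_ids=input_ids, position_ids=position_ids, use_cache=False)
print(out.logits.shape)  # expected: (1, 16, 32064) with the default vocab_size

Note that the logits come from the content stream only; the reasoning stream influences them through the per-layer stream interaction.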