import torch import torch.nn as nn from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast from transformers.modeling_utils import PreTrainedModel, PretrainedConfig from torch.nn import CrossEntropyLoss from typing import Optional, Tuple, Union, List import torch.nn.functional as F import math ACT2FN = { "relu": F.relu, "silu": F.silu, "gelu": F.gelu, "tanh": torch.tanh, "sigmoid": torch.sigmoid, } class RasphiDecoderLayer(nn.Module): def __init__(self, config: RasphiConfig, layer_idx: int): super().__init__() self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.reasoning_hidden_size = config.reasoning_hidden_size self.content_hidden_size = config.content_hidden_size # Attention layers self.reasoning_self_attn = RasphiAttention(config, self.reasoning_hidden_size, layer_idx) self.content_self_attn = RasphiAttention(config, self.content_hidden_size, layer_idx) # MoE layers self.reasoning_moe = RasphiSparseMoeBlock(config, is_reasoning=True) self.content_moe = RasphiSparseMoeBlock(config, is_reasoning=False) # Layer norms self.reasoning_input_layernorm = nn.LayerNorm(self.reasoning_hidden_size, eps=config.rms_norm_eps) self.reasoning_post_attention_layernorm = nn.LayerNorm(self.reasoning_hidden_size, eps=config.rms_norm_eps) self.content_input_layernorm = nn.LayerNorm(self.content_hidden_size, eps=config.rms_norm_eps) self.content_post_attention_layernorm = nn.LayerNorm(self.content_hidden_size, eps=config.rms_norm_eps) # Stream interaction self.stream_interaction = config.stream_interaction if self.stream_interaction in ["attention", "both"]: self.reasoning_to_content_attn = RasphiAttention(config, self.content_hidden_size, layer_idx) self.content_to_reasoning_attn = RasphiAttention(config, self.reasoning_hidden_size, layer_idx) if self.stream_interaction in ["mlp", "both"]: self.reasoning_to_content_mlp = nn.Linear(self.reasoning_hidden_size, self.content_hidden_size) self.content_to_reasoning_mlp = nn.Linear(self.content_hidden_size, self.reasoning_hidden_size) def forward( self, reasoning_hidden_states: torch.Tensor, content_hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, ) -> Tuple[torch.FloatTensor, ...]: # Self Attention for both streams reasoning_residual = reasoning_hidden_states content_residual = content_hidden_states reasoning_hidden_states = self.reasoning_input_layernorm(reasoning_hidden_states) content_hidden_states = self.content_input_layernorm(content_hidden_states) reasoning_self_attn_output, reasoning_self_attn_weights, reasoning_present_key_value = self.reasoning_self_attn( hidden_states=reasoning_hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value[0] if past_key_value is not None else None, output_attentions=output_attentions, use_cache=use_cache, ) content_self_attn_output, content_self_attn_weights, content_present_key_value = self.content_self_attn( hidden_states=content_hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value[1] if past_key_value is not None else None, output_attentions=output_attentions, use_cache=use_cache, ) reasoning_hidden_states = reasoning_residual + reasoning_self_attn_output content_hidden_states = content_residual + content_self_attn_output # Stream Interaction if self.stream_interaction in ["attention", "both"]: reasoning_to_content, _, _ = self.reasoning_to_content_attn( hidden_states=content_hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=None, output_attentions=False, use_cache=False, key_value_states=reasoning_hidden_states, ) content_to_reasoning, _, _ = self.content_to_reasoning_attn( hidden_states=reasoning_hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=None, output_attentions=False, use_cache=False, key_value_states=content_hidden_states, ) reasoning_hidden_states = reasoning_hidden_states + content_to_reasoning content_hidden_states = content_hidden_states + reasoning_to_content if self.stream_interaction in ["mlp", "both"]: reasoning_to_content = self.reasoning_to_content_mlp(reasoning_hidden_states) content_to_reasoning = self.content_to_reasoning_mlp(content_hidden_states) reasoning_hidden_states = reasoning_hidden_states + content_to_reasoning content_hidden_states = content_hidden_states + reasoning_to_content # MoE for both streams reasoning_residual = reasoning_hidden_states content_residual = content_hidden_states reasoning_hidden_states = self.reasoning_post_attention_layernorm(reasoning_hidden_states) content_hidden_states = self.content_post_attention_layernorm(content_hidden_states) reasoning_moe_output, reasoning_router_logits = self.reasoning_moe(reasoning_hidden_states) content_moe_output, content_router_logits = self.content_moe(content_hidden_states) reasoning_hidden_states = reasoning_residual + reasoning_moe_output content_hidden_states = content_residual + content_moe_output outputs = (reasoning_hidden_states, content_hidden_states) if use_cache: outputs += ((reasoning_present_key_value, content_present_key_value),) if output_attentions: outputs += (reasoning_self_attn_weights, content_self_attn_weights) if output_router_logits: outputs += (reasoning_router_logits, content_router_logits) return outputs class RasphiModel(PreTrainedModel): config_class = RasphiConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["RasphiDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True def __init__(self, config: RasphiConfig): super().__init__(config) self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.reasoning_embed_tokens = nn.Embedding(config.vocab_size, config.reasoning_hidden_size, self.padding_idx) self.content_embed_tokens = nn.Embedding(config.vocab_size, config.content_hidden_size, self.padding_idx) self.layers = nn.ModuleList([RasphiDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]) self.reasoning_norm = nn.LayerNorm(config.reasoning_hidden_size, eps=config.rms_norm_eps, elementwise_affine=True) self.content_norm = nn.LayerNorm(config.content_hidden_size, eps=config.rms_norm_eps, elementwise_affine=True) self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return (self.reasoning_embed_tokens, self.content_embed_tokens) def set_input_embeddings(self, value): self.reasoning_embed_tokens = value[0] self.content_embed_tokens = value[1] def forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MoeModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_router_logits = ( output_router_logits if output_router_logits is not None else self.config.output_router_logits ) output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: reasoning_inputs_embeds = self.reasoning_embed_tokens(input_ids) content_inputs_embeds = self.content_embed_tokens(input_ids) else: reasoning_inputs_embeds = inputs_embeds[:, :, :self.config.reasoning_hidden_size] content_inputs_embeds = inputs_embeds[:, :, self.config.reasoning_hidden_size:] reasoning_hidden_states = reasoning_inputs_embeds content_hidden_states = content_inputs_embeds # decoder layers all_reasoning_hidden_states = () if output_hidden_states else None all_content_hidden_states = () if output_hidden_states else None all_reasoning_self_attns = () if output_attentions else None all_content_self_attns = () if output_attentions else None all_reasoning_router_logits = () if output_router_logits else None all_content_router_logits = () if output_router_logits else None next_decoder_cache = None for decoder_layer in self.layers: if output_hidden_states: all_reasoning_hidden_states += (reasoning_hidden_states,) all_content_hidden_states += (content_hidden_states,) layer_outputs = decoder_layer( reasoning_hidden_states, content_hidden_states, attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, ) reasoning_hidden_states = layer_outputs[0] content_hidden_states = layer_outputs[1] if use_cache: next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_reasoning_self_attns += (layer_outputs[2],) all_content_self_attns += (layer_outputs[3],) if output_router_logits: all_reasoning_router_logits += (layer_outputs[-2],) all_content_router_logits += (layer_outputs[-1],) reasoning_hidden_states = self.reasoning_norm(reasoning_hidden_states) content_hidden_states = self.content_norm(content_hidden_states) # add hidden states from the last decoder layer if output_hidden_states: all_reasoning_hidden_states += (reasoning_hidden_states,) all_content_hidden_states += (content_hidden_states,) next_cache = None if use_cache: next_cache = next_decoder_cache if not return_dict: return tuple( v for v in [reasoning_hidden_states, content_hidden_states, next_cache, all_reasoning_hidden_states, all_content_hidden_states, all_reasoning_self_attns, all_content_self_attns, all_reasoning_router_logits, all_content_router_logits] if v is not None ) return MoeModelOutputWithPast( last_hidden_state=(reasoning_hidden_states, content_hidden_states), past_key_values=next_cache, hidden_states=(all_reasoning_hidden_states, all_content_hidden_states), attentions=(all_reasoning_self_attns, all_content_self_attns), router_logits=(all_reasoning_router_logits, all_content_router_logits), ) class RasphiSparseMoeBlock(nn.Module): def __init__(self, config: RasphiConfig, is_reasoning: bool): super().__init__() self.hidden_dim = config.reasoning_hidden_size if is_reasoning else config.content_hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_reasoning_experts if is_reasoning else config.num_content_experts self.top_k = config.num_experts_per_tok # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.experts = nn.ModuleList([RasphiBlockSparseTop2MLP(config, is_reasoning) for _ in range(self.num_experts)]) # Jitter parameters self.router_jitter_noise = config.router_jitter_noise self.input_jitter_noise = config.input_jitter_noise def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape if self.training and self.input_jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.input_jitter_noise, 1.0 + self.input_jitter_noise) hidden_states = hidden_states.view(-1, hidden_dim) router_logits = self.gate(hidden_states) routing_weights, selected_experts = sparsemixer( router_logits, top_k=self.top_k, jitter_eps=self.router_jitter_noise, training=self.training, ) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) # One hot encode the selected experts to create an expert mask expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) # Loop over all available experts in the model and perform the computation on each expert for expert_idx in range(self.num_experts): expert_layer = self.experts[expert_idx] idx, top_x = torch.where(expert_mask[expert_idx]) if top_x.shape[0] == 0: continue # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x.tolist()].reshape(-1, hidden_dim) current_hidden_states = expert_layer(current_state) * routing_weights[top_x.tolist(), idx.tolist(), None] # Add the expert output to the final hidden states final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits class RasphiBlockSparseTop2MLP(nn.Module): def __init__(self, config: RasphiConfig, is_reasoning: bool): super().__init__() self.ffn_dim = config.intermediate_size self.hidden_dim = config.reasoning_hidden_size if is_reasoning else config.content_hidden_size self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) self.w3 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states): current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) current_hidden_states = self.w2(current_hidden_states) return current_hidden_states class RasphiPreTrainedModel(PreTrainedModel): config_class = RasphiConfig base_model_prefix = "rasphi" supports_gradient_checkpointing = True _no_split_modules = ["RasphiDecoderLayer"] def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class RasphiForCausalLM(RasphiPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = RasphiModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.content_hidden_size, config.vocab_size, bias=config.lm_head_bias) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_content_experts # We use content experts for language modeling self.num_experts_per_tok = config.num_experts_per_tok # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): return self.model.get_input_embeddings()[1] # Return content embeddings def set_input_embeddings(self, value): self.model.set_input_embeddings((self.model.get_input_embeddings()[0], value)) def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past_key_values: input_ids = input_ids[:, -1].unsqueeze(-1) if token_type_ids is not None: token_type_ids = token_type_ids[:, -1].unsqueeze(-1) attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: position_ids = position_ids[:, -1].unsqueeze(-1) else: position_ids = None return { "input_ids": input_ids, "past_key_values": past_key_values, "use_cache": kwargs.get("use_cache"), "position_ids": position_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids, } def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, output_router_logits: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: return_dict = return_dict if return_dict is not None else self.config.use_return_dict outputs = self.model( input_ids, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, output_hidden_states=output_hidden_states, output_router_logits=output_router_logits, return_dict=return_dict, ) hidden_states = outputs[0] content_hidden_states = hidden_states[1] # Use content stream for language modeling logits = self.lm_head(content_hidden_states) loss = None if labels is not None: # Shift so that tokens < n predict n shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() # Flatten the tokens loss_fct = CrossEntropyLoss() loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) aux_loss = None if output_router_logits: aux_loss = load_balancing_loss_func( outputs.router_logits[1] if return_dict else outputs[-1][1], # Use content stream router logits self.num_experts, self.num_experts_per_tok, attention_mask, ) if labels is not None: loss += self.router_aux_loss_coef * aux_loss if not return_dict: output = (logits,) + outputs[1:] return ((loss,) + output) if loss is not None else output return MoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, ) @staticmethod def _reorder_cache(past, beam_idx): return tuple( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) for layer_past in past ) #—Model > Rasphi changes start—# class RasphiAttention(nn.Module): def __init__(self, config: RasphiConfig, hidden_size: int, layer_idx: Optional[int] = None): super().__init__() self.config = config self.layer_idx = layer_idx self.hidden_size = hidden_size self.num_heads = config.num_attention_heads self.head_dim = hidden_size // self.num_heads self.num_key_value_heads = config.num_key_value_heads self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" f" and `num_heads`: {self.num_heads})." ) self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias) self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias) if getattr(config, 'rope_scaling', None) is None: self.rotary_emb = RasphiMoERotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, base=self.rope_theta, ) else: scaling_type = self.config.rope_scaling["type"] if scaling_type == "linear": self.rotary_emb = LinearScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=self.config.rope_scaling["factor"], base=self.rope_theta, ) elif scaling_type == "dynamic": self.rotary_emb = DynamicNTKScalingRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=self.config.rope_scaling["factor"], base=self.rope_theta, ) else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, key_value_states: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) if key_value_states is None: # self-attention key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) else: # cross-attention key_states = self.k_proj(key_value_states) value_states = self.v_proj(key_value_states) kv_len = key_value_states.size(1) query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) past_key_value = (key_states, value_states) if use_cache else None # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) if attention_mask is not None: attn_weights = attn_weights + attention_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) attn_output = self.o_proj(attn_output) if not output_attentions: attn_weights = None return attn_output, attn_weights, past_key_value class mp(torch.autograd.Function): @staticmethod def forward( ctx, scores: torch.Tensor, multiplier: torch.Tensor, selected_experts: torch.Tensor, masked_gates: torch.Tensor, mask_for_one: torch.Tensor, ): ctx.save_for_backward(multiplier, selected_experts, masked_gates) return multiplier * mask_for_one @staticmethod def backward( ctx, grad_at_output: torch.Tensor, ): multiplier, selected_experts, masked_gates = ctx.saved_tensors grad_at_output = grad_at_output * multiplier grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1) grad_at_scores_expaned.scatter_add_( dim=-1, index=selected_experts, src=grad_at_output, ) return ( grad_at_scores_expaned, None, None, None, None, ) def sparsemixer(scores, top_k, jitter_eps, training): assert top_k == 2 ################ first expert ################ with torch.no_grad(): # compute mask for sparsity mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) factor = scores.abs().clamp(min=mask_logits_threshold) mask_logits_threshold = ( (mask_logits_threshold - scores) / factor ) > (2 * jitter_eps) # apply mask masked_gates = scores.masked_fill(mask_logits_threshold, float('-inf')) if training: selected_experts = ( masked_gates - torch.empty_like(masked_gates, memory_format=torch.legacy_contiguous_format).exponential_().log() ).max(dim=-1)[1].unsqueeze(-1) # gumbel sampling, more robust than than the multinomial method else: selected_experts = max_ind # compute scores for gradients masked_gates = torch.softmax(masked_gates, dim=-1) multiplier_o = masked_gates.gather(dim=-1, index=selected_experts) if training: # compute midpoint mask max_scores, max_ind = masked_gates.max(dim=-1, keepdim=True) mask_for_one = torch.logical_or( selected_experts == max_ind, torch.rand_like(max_scores) > 0.75 # Heun's third-order method: f(x) - f(0) = .25 f'(x) + .75 f'(x/3.) ) # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5 mask_for_one = torch.add(0.3333, mask_for_one, alpha=0.6667).type_as(masked_gates) multiplier = mp.apply( scores, multiplier_o, selected_experts, masked_gates, mask_for_one, ) else: multiplier = multiplier_o # masked out first expert masked_scores = torch.scatter( scores, -1, selected_experts, float('-inf'), ) with torch.no_grad(): # compute mask for sparsity mask_logits_threshold, max_ind = masked_scores.max(dim=-1, keepdim=True) factor = scores.abs().clamp(min=mask_logits_threshold) mask_logits_threshold = ( (mask_logits_threshold - scores) / factor ) > (2 * jitter_eps) # apply mask masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, float('-inf')) if training: selected_experts_top2 = ( masked_gates_top2 - torch.empty_like(masked_gates_top2, memory_format=torch.legacy_contiguous_format).exponential_().log() ).max(dim=-1)[1].unsqueeze(-1) # gumbel sampling, more robust than than the multinomial method else: selected_experts_top2 = max_ind # compute scores for gradients masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) multiplier_top2_o = masked_gates_top2.gather(dim=-1, index=selected_experts_top2) if training: # compute midpoint mask max_scores, max_ind = masked_gates_top2.max(dim=-1, keepdim=True) mask_for_one_top2 = torch.logical_or( selected_experts_top2 == max_ind, torch.rand_like(max_scores).uniform_() > 0.75 # Heun's third-order method: f(x) - f(0) = .25 f'(x) + .75 f'(x/3.) ) # 1 -> 1.0 & 0 -> 1./3: lambda x: (x + 0.5) / 1.5 mask_for_one_top2 = torch.add(0.3333, mask_for_one_top2, alpha=0.6667).type_as(masked_gates_top2) multiplier_top2 = mp.apply( scores, multiplier_top2_o, selected_experts_top2, masked_gates_top2, mask_for_one_top2, ) else: multiplier_top2 = multiplier_top2_o multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) selected_experts = torch.concat((selected_experts, selected_experts_top2), dim=-1) return ( multiplier, selected_experts, ) def load_balancing_loss_func( gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None ) -> float: if gate_logits is None or not isinstance(gate_logits, tuple): return 0 if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) routing_weights = F.softmax(concatenated_gate_logits, dim=-1) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) expert_mask = F.one_hot(selected_experts, num_experts).permute(2, 1, 0) if attention_mask is None: # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.mean(expert_mask.float(), dim=0) # Compute the average probability of routing to these experts router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) .reshape(-1, top_k, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( expert_attention_mask, dim=0 ) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) .reshape(-1, num_experts) .to(compute_device) ) # Compute the average probability of routing to these experts router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( router_per_expert_attention_mask, dim=0 ) overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) return overall_loss * num_experts class RasphiMoERotaryEmbedding(nn.Module): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): super().__init__() inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) self.register_buffer("inv_freq", inv_freq) self.max_seq_len_cached = max_position_embeddings t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype) freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) def forward(self, x, seq_len=None): if seq_len > self.max_seq_len_cached: self._set_cos_sin_cache(seq_len, device=x.device, dtype=x.dtype) return ( self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype), self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype), ) class LinearScalingRotaryEmbedding(RasphiMoERotaryEmbedding): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len t = torch.arange(self.max_seq_len_cached, device=device, dtype=dtype) t = t / self.scaling_factor freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1).to(device) self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) class DynamicNTKScalingRotaryEmbedding(RasphiMoERotaryEmbedding): def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): self.scaling_factor = scaling_factor super().__init__(dim, max_position_embeddings, base, device) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len if seq_len > self.max_seq_len_cached: base = self.base * ((self.scaling_factor * seq_len / self.max_seq_len_cached) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq) t = torch.arange(self.max_seq_len_cached, device=device, dtype=dtype) freqs = torch.einsum("i,j->ij", t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1).to(device) self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) def apply_rotary_pos_emb(q, k, cos, sin, position_ids): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) from transformers import AutoModelForCausalLM AutoModelForCausalLM.register("rasphi", RasphiForCausalLM)