import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
class FoundationLayer(nn.Module):
    def __init__(self, model_name: str = "gpt2-xl"):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.transformer = AutoModel.from_pretrained(model_name)
        self.sparse_router = MixtureOfExperts(
            num_experts=128,
            input_size=self.config.hidden_size
        )

    def forward(self, input_ids, attention_mask=None):
        transformer_output = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Route the transformer's hidden states through the sparse expert layer
        routed_output = self.sparse_router(transformer_output.last_hidden_state)
        return routed_output

    def _process_consciousness_emergence(self, hidden_states):
        # Placeholder hook: currently an identity pass-through over the hidden states
        return hidden_states
class MixtureOfExperts(nn.Module):
    def __init__(self, num_experts: int, input_size: int):
        super().__init__()
        self.num_experts = num_experts
        self.gate = nn.Linear(input_size, num_experts)
        self.experts = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=input_size,
                nhead=8,
                batch_first=True  # each routed token is processed as its own length-1 sequence
            ) for _ in range(num_experts)
        ])
    def forward(self, hidden_states):
        batch_size, seq_len, hidden_size = hidden_states.shape

        # Per-token routing probabilities over all experts
        routing_logits = self.gate(hidden_states.view(-1, hidden_size))
        routing_probs = torch.softmax(routing_logits, dim=-1)

        # Keep only the top-2 experts per token and renormalize their weights
        k = 2
        top_k_probs, top_k_indices = torch.topk(routing_probs, k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Flatten to (num_tokens, 1, hidden) so each token is a length-1 sequence
        hidden_states_flat = hidden_states.view(-1, 1, hidden_size)
        final_output = torch.zeros_like(hidden_states_flat)

        # Dispatch tokens to their selected experts and combine the weighted outputs
        for i, expert in enumerate(self.experts):
            # Tokens for which expert i appears among the top-k choices
            token_mask = (top_k_indices == i).any(dim=-1)
            if token_mask.any():
                expert_input = hidden_states_flat[token_mask]
                expert_output = expert(expert_input)
                # Routing weight of expert i for every token (zero where not selected)
                expert_weight = ((top_k_indices == i).float() * top_k_probs).sum(dim=-1)
                expert_weight = expert_weight.unsqueeze(-1).unsqueeze(-1)
                final_output[token_mask] += expert_output * expert_weight[token_mask]

        return final_output.view(batch_size, seq_len, hidden_size)
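

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): it assumes the "gpt2-xl" weights are
    # available locally and that there is enough memory for 128 transformer experts.
    # Any smaller AutoModel checkpoint could be substituted for a quicker smoke test.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2-xl")
    model = FoundationLayer("gpt2-xl")
    model.eval()

    inputs = tokenizer("Mixture-of-experts routing test.", return_tensors="pt")
    with torch.no_grad():
        routed = model(inputs["input_ids"], attention_mask=inputs["attention_mask"])

    # Expected shape: (batch_size, seq_len, hidden_size), e.g. (1, num_tokens, 1600) for gpt2-xl
    print(routed.shape)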