import torch
import torch.nn as nn
from transformers import AutoModel, AutoConfig
class FoundationLayer(nn.Module):
    def __init__(self, model_name: str = "gpt2-xl"):
        super().__init__()
        self.config = AutoConfig.from_pretrained(model_name)
        self.transformer = AutoModel.from_pretrained(model_name)
        self.sparse_router = MixtureOfExperts(
            num_experts=128,
            input_size=self.config.hidden_size
        )

    def forward(self, input_ids, attention_mask=None):
        transformer_output = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Route the transformer's final hidden states through the sparse expert layer
        routed_output = self.sparse_router(transformer_output.last_hidden_state)
        return routed_output

    def _process_consciousness_emergence(self, hidden_states):
        # Placeholder hook: currently an identity mapping; extend as needed
        return hidden_states
class MixtureOfExperts(nn.Module):
    def __init__(self, num_experts: int, input_size: int):
        super().__init__()
        self.num_experts = num_experts
        self.gate = nn.Linear(input_size, num_experts)
        self.experts = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=input_size,
                nhead=8,
                batch_first=True  # each routed token is its own batch element, not part of one long sequence
            ) for _ in range(num_experts)
        ])
    def forward(self, hidden_states):
        batch_size, seq_len, hidden_size = hidden_states.shape

        # Per-token routing probabilities over the experts
        routing_logits = self.gate(hidden_states.view(-1, hidden_size))
        routing_probs = torch.softmax(routing_logits, dim=-1)

        # Select the top-k experts per token and renormalize their weights
        k = 2  # top-2 routing
        top_k_probs, top_k_indices = torch.topk(routing_probs, k, dim=-1)
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # Flatten tokens so each one is a length-1 sequence for the experts
        hidden_states_flat = hidden_states.view(-1, 1, hidden_size)
        final_output = torch.zeros_like(hidden_states_flat)

        # Dispatch each token to its selected experts
        for i, expert in enumerate(self.experts):
            token_mask = (top_k_indices == i).any(dim=-1)
            if token_mask.sum() > 0:
                # Only process tokens routed to this expert
                expert_input = hidden_states_flat[token_mask]
                expert_output = expert(expert_input)
                # Gate weight for this expert (zero in the unselected top-k slots)
                expert_weight = ((top_k_indices == i).float() * top_k_probs).sum(dim=-1, keepdim=True)
                expert_weight = expert_weight.unsqueeze(-1)  # (batch * seq_len, 1, 1)
                final_output[token_mask] += expert_output * expert_weight[token_mask]

        return final_output.view(batch_size, seq_len, hidden_size)
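

# Minimal smoke test (an assumption, not part of the original file): it exercises
# the MixtureOfExperts routing path on a small random batch so it runs without
# downloading gpt2-xl; the expert count and sizes below are illustrative only.
if __name__ == "__main__":
    torch.manual_seed(0)
    moe = MixtureOfExperts(num_experts=4, input_size=64)
    dummy_hidden = torch.randn(2, 5, 64)  # (batch, seq_len, hidden_size)
    with torch.no_grad():
        routed = moe(dummy_hidden)
    print(routed.shape)  # expected: torch.Size([2, 5, 64])
    # FoundationLayer("gpt2-xl") would be called the same way on tokenized
    # input_ids, but it downloads the full model, so it is left out here.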