import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig


class EnhancedMoE(nn.Module):
    """Mixture-of-experts block: a softmax gate mixes the outputs of parallel MLP experts."""

    def __init__(self, input_dim, num_experts=12, expert_dim=1024, dropout_rate=0.1):
        super(EnhancedMoE, self).__init__()
        self.num_experts = num_experts

        # Each expert is a two-layer MLP mapping input_dim -> expert_dim.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, expert_dim),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
                nn.Linear(expert_dim, expert_dim)
            ) for _ in range(num_experts)
        ])

        # Gating network produces one mixing weight per expert.
        self.gating_network = nn.Sequential(
            nn.Linear(input_dim, expert_dim),
            nn.ReLU(),
            nn.Linear(expert_dim, num_experts)
        )
        self.layer_norm = nn.LayerNorm(expert_dim)

    def forward(self, x):
        # (batch, num_experts) gating weights.
        gating_scores = F.softmax(self.gating_network(x), dim=-1)
        # (batch, num_experts, expert_dim) stacked expert outputs.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        # Weighted sum over the expert dimension -> (batch, expert_dim).
        output = torch.sum(gating_scores.unsqueeze(-1) * expert_outputs, dim=1)
        return self.layer_norm(output)


class UltraSmarterModel(nn.Module):
    """Multimodal model that fuses text, image, and audio features with cross-attention."""

    def __init__(
        self,
        text_model_name="bert-base-uncased",
        image_dim=2048,
        audio_dim=512,
        num_classes=None,
        hidden_dim=1024
    ):
        super(UltraSmarterModel, self).__init__()

        # Pretrained text encoder.
        self.text_config = AutoConfig.from_pretrained(text_model_name)
        self.text_encoder = AutoModel.from_pretrained(text_model_name)
        # Project the encoder's hidden size (e.g. 768 for BERT-base) to hidden_dim
        # so all modalities share the same feature dimension before attention.
        self.text_projection = nn.Linear(self.text_config.hidden_size, hidden_dim)

        # Mixture-of-experts encoders for the image and audio features.
        self.image_expert = EnhancedMoE(image_dim, expert_dim=hidden_dim)
        self.audio_expert = EnhancedMoE(audio_dim, expert_dim=hidden_dim)

        # Cross-modal self-attention over the three modality tokens.
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=8,
            batch_first=True
        )

        # Fuse the concatenated modality features back down to hidden_dim.
        fused_dim = hidden_dim * 3
        self.fusion_layer = nn.Sequential(
            nn.Linear(fused_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1)
        )

        # Classification head when num_classes is given; otherwise emit an embedding.
        self.num_classes = num_classes
        self.output_dim = num_classes if num_classes is not None else hidden_dim
        self.output_layer = nn.Linear(hidden_dim, self.output_dim)

        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, text_input, image_input, audio_input):
        # Text: take the [CLS] token embedding and project it to hidden_dim.
        text_features = self.text_encoder(**text_input).last_hidden_state[:, 0, :]
        text_features = self.text_projection(text_features)
        text_features = self.dropout(F.relu(text_features))

        # Image and audio: mixture-of-experts encoders -> (batch, hidden_dim).
        image_features = self.image_expert(image_input)
        audio_features = self.audio_expert(audio_input)

        # Treat each modality as one token: (batch, 3, hidden_dim).
        text_features = text_features.unsqueeze(1)
        image_features = image_features.unsqueeze(1)
        audio_features = audio_features.unsqueeze(1)
        modality_features = torch.cat([text_features, image_features, audio_features], dim=1)

        # Cross-modal self-attention across the three modality tokens.
        attn_output, _ = self.cross_attention(
            modality_features, modality_features, modality_features
        )

        # Flatten the three attended tokens and fuse them down to hidden_dim.
        fused_features = attn_output.reshape(attn_output.size(0), -1)
        fused_features = self.fusion_layer(fused_features)
        fused_features = self.layer_norm(fused_features)

        output = self.output_layer(fused_features)

        # In classification mode return class probabilities; otherwise return the embedding.
        if self.num_classes is not None:
            return F.softmax(output, dim=-1)
        return output


if __name__ == "__main__":
    batch_size = 4
    model = UltraSmarterModel(num_classes=10)

    # Dummy inputs: tokenized text, pooled image features, pooled audio features.
    text_input = {
        "input_ids": torch.randint(0, 1000, (batch_size, 128)),
        "attention_mask": torch.ones(batch_size, 128, dtype=torch.long)
    }
    image_input = torch.randn(batch_size, 2048)
    audio_input = torch.randn(batch_size, 512)

    output = model(text_input, image_input, audio_input)
    print(f"Output shape: {output.shape}")