GeminiFan207 committed
Commit d9f3a1b · verified · 1 Parent(s): 18fa92b

Create model.py

Files changed (1)
  1. model.py +123 -0
model.py ADDED
@@ -0,0 +1,123 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoModel, AutoConfig
+
+ class EnhancedMoE(nn.Module):
+     def __init__(self, input_dim, num_experts=12, expert_dim=1024, dropout_rate=0.1):
+         super().__init__()
+         self.num_experts = num_experts
+         # Two-layer feed-forward experts
+         self.experts = nn.ModuleList([
+             nn.Sequential(
+                 nn.Linear(input_dim, expert_dim),
+                 nn.ReLU(),
+                 nn.Dropout(dropout_rate),
+                 nn.Linear(expert_dim, expert_dim)
+             ) for _ in range(num_experts)
+         ])
+         # Gating network: scores each expert from the input
+         self.gating_network = nn.Sequential(
+             nn.Linear(input_dim, expert_dim),
+             nn.ReLU(),
+             nn.Linear(expert_dim, num_experts)
+         )
+         self.layer_norm = nn.LayerNorm(expert_dim)
+
+     def forward(self, x):
+         # Softmax gate weights over experts: (batch, num_experts)
+         gating_scores = F.softmax(self.gating_network(x), dim=-1)
+         # Run every expert and stack: (batch, num_experts, expert_dim)
+         expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
+         # Weighted sum of expert outputs
+         output = torch.sum(gating_scores.unsqueeze(-1) * expert_outputs, dim=1)
+         return self.layer_norm(output)
+
+ class UltraSmarterModel(nn.Module):
+     def __init__(
+         self,
+         text_model_name="bert-base-uncased",
+         image_dim=2048,
+         audio_dim=512,
+         num_classes=None,
+         hidden_dim=1024
+     ):
+         super().__init__()
+
+         # Text processing
+         self.text_config = AutoConfig.from_pretrained(text_model_name)
+         self.text_encoder = AutoModel.from_pretrained(text_model_name)
+         # Project text features (e.g. 768 for BERT-base) to the shared hidden_dim
+         self.text_projection = nn.Linear(self.text_config.hidden_size, hidden_dim)
+
+         # Mixture-of-experts encoders for the image and audio modalities
+         self.image_expert = EnhancedMoE(image_dim, expert_dim=hidden_dim)
+         self.audio_expert = EnhancedMoE(audio_dim, expert_dim=hidden_dim)
+
+         # Cross-attention over the three modality tokens
+         self.cross_attention = nn.MultiheadAttention(
+             embed_dim=hidden_dim,
+             num_heads=8,
+             batch_first=True
+         )
+
+         # Fusion and output
+         fused_dim = hidden_dim * 3  # text + image + audio
+         self.fusion_layer = nn.Sequential(
+             nn.Linear(fused_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Dropout(0.1)
+         )
+
+         # Flexible output layer (classification or regression)
+         self.output_dim = num_classes if num_classes else hidden_dim
+         self.output_layer = nn.Linear(hidden_dim, self.output_dim)
+
+         # Normalization and regularization for the fused representation
+         self.layer_norm = nn.LayerNorm(hidden_dim)
+         self.dropout = nn.Dropout(0.1)
+
+     def forward(self, text_input, image_input, audio_input):
+         # Text features from the [CLS] token, projected to hidden_dim
+         text_features = self.text_encoder(**text_input).last_hidden_state[:, 0, :]
+         text_features = self.text_projection(text_features)
+         text_features = self.dropout(F.relu(text_features))
+
+         # Image and audio features through their mixture-of-experts encoders
+         image_features = self.image_expert(image_input)
+         audio_features = self.audio_expert(audio_input)
+
+         # Reshape each modality to (batch_size, seq_len=1, hidden_dim)
+         text_features = text_features.unsqueeze(1)
+         image_features = image_features.unsqueeze(1)
+         audio_features = audio_features.unsqueeze(1)
+
+         # Attention across the three modality tokens
+         modality_features = torch.cat([text_features, image_features, audio_features], dim=1)
+         attn_output, _ = self.cross_attention(
+             modality_features, modality_features, modality_features
+         )
+
+         # Flatten the three attended tokens and fuse
+         fused_features = attn_output.reshape(attn_output.size(0), -1)
+         fused_features = self.fusion_layer(fused_features)
+         fused_features = self.layer_norm(fused_features)
+
+         # Final output
+         output = self.output_layer(fused_features)
+
+         # Return class probabilities when configured for classification
+         if self.output_dim > 1:
+             return F.softmax(output, dim=-1)
+         return output
+
+ # Example usage
+ if __name__ == "__main__":
+     batch_size = 4
+     model = UltraSmarterModel(num_classes=10)  # 10-class classification
+
+     # Sample inputs
+     text_input = {
+         "input_ids": torch.randint(0, 1000, (batch_size, 128)),
+         "attention_mask": torch.ones(batch_size, 128)
+     }
+     image_input = torch.randn(batch_size, 2048)
+     audio_input = torch.randn(batch_size, 512)
+
+     # Forward pass
+     output = model(text_input, image_input, audio_input)
+     print(f"Output shape: {output.shape}")  # Expected: [batch_size, 10]
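
A note on training: because forward() applies softmax when num_classes is set, the model returns probabilities rather than raw logits, so a training loop would pair it with nn.NLLLoss on the log of the output instead of nn.CrossEntropyLoss. Below is a minimal sketch of one optimization step, assuming the same random dummy inputs as the example above; the labels tensor and the AdamW settings are illustrative only, not part of the commit.

import torch
import torch.nn as nn
from model import UltraSmarterModel  # the file added in this commit

model = UltraSmarterModel(num_classes=10)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.NLLLoss()  # expects log-probabilities

batch_size = 4
text_input = {
    "input_ids": torch.randint(0, 1000, (batch_size, 128)),
    "attention_mask": torch.ones(batch_size, 128),
}
image_input = torch.randn(batch_size, 2048)
audio_input = torch.randn(batch_size, 512)
labels = torch.randint(0, 10, (batch_size,))  # made-up targets for illustration

probs = model(text_input, image_input, audio_input)  # softmax probabilities
loss = criterion(torch.log(probs + 1e-9), labels)    # convert to log-probs for NLLLoss
loss.backward()
optimizer.step()
optimizer.zero_grad()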