Question Answering
Transformers
English
Chinese
multimodal
vqa
text
audio
Eval Results
Inference Endpoints
zeroMN committed
Commit 2df5d25 · verified · 1 Parent(s): a65107a

Update model.py

Files changed (1)
  1. model.py +224 -212
model.py CHANGED
@@ -1,212 +1,224 @@
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import os
- # Config class definition
- class Config:
-     def __init__(self):
-         # Model architecture parameters
-         self.hidden_size = 768
-         self.num_attention_heads = 12
-         self.num_hidden_layers = 12
-         self.intermediate_size = 3072
-         self.hidden_dropout_prob = 0.1
-         self.attention_probs_dropout_prob = 0.1
-
-         # Image-related
-         self.image_size = 224
-         self.image_channels = 3
-         self.patch_size = 16
-
-         # Text-related
-         self.max_position_embeddings = 512
-         self.vocab_size = 30522
-         self.type_vocab_size = 2
-
-         # Audio-related
-         self.audio_sample_rate = 16000
-         self.audio_frame_size = 1024
-         self.audio_hop_size = 512
-
-         # Task-related
-         self.enable_vqa = True
-         self.enable_caption = True
-         self.enable_retrieval = True
-         self.enable_asr = True  # speech recognition
-         self.enable_realtime_asr = True  # real-time speech recognition
-
-         # Training-related
-         self.batch_size = 32
-         self.learning_rate = 1e-4
-         self.weight_decay = 0.01
-         self.warmup_steps = 10000
-         self.max_steps = 100000
-
- # Model component class definitions
- class ImageEncoder(nn.Module):
-     def __init__(self, config):
-         super(ImageEncoder, self).__init__()
-         self.config = config
-         self.encoder_layer = nn.Sequential(
-             nn.Conv2d(3, 64, kernel_size=3),
-             nn.ReLU(),
-             nn.MaxPool2d(2, 2),
-             nn.Flatten(),
-             nn.Linear(64 * 111 * 111, config.hidden_size)
-         )
-
-     def forward(self, image):
-         image_features = self.encoder_layer(image)
-         return image_features
-
- class TextEncoder(nn.Module):
-     def __init__(self, config):
-         super(TextEncoder, self).__init__()
-         self.config = config
-         self.transformer_layer = nn.TransformerEncoderLayer(
-             d_model=config.hidden_size,
-             nhead=config.num_attention_heads,
-             batch_first=True
-         )
-         self.transformer_encoder = nn.TransformerEncoder(
-             self.transformer_layer,
-             num_layers=config.num_hidden_layers
-         )
-
-     def forward(self, text):
-         text_features = self.transformer_encoder(text).mean(dim=1)
-         return text_features
-
- class AudioEncoder(nn.Module):
-     def __init__(self, config):
-         super(AudioEncoder, self).__init__()
-         self.config = config
-         self.encoder_layer = nn.Sequential(
-             nn.Linear(config.audio_sample_rate, config.hidden_size),
-             nn.ReLU(),
-             nn.Linear(config.hidden_size, config.hidden_size)
-         )
-
-     def forward(self, audio):
-         audio_features = self.encoder_layer(audio)
-         return audio_features
-
- class FusionLayer(nn.Module):
-     def __init__(self, config):
-         super(FusionLayer, self).__init__()
-         self.config = config
-         self.fusion_layer = nn.Linear(config.hidden_size * 3, config.hidden_size)
-
-     def forward(self, image_features, text_features, audio_features):
-         fused_features = torch.cat((image_features, text_features, audio_features), dim=1)
-         fused_features = self.fusion_layer(fused_features)
-         return fused_features
-
- class VQALayer(nn.Module):
-     def __init__(self, config):
-         super(VQALayer, self).__init__()
-         self.config = config
-         self.vqa_layer = nn.Linear(config.hidden_size, config.vocab_size)
-
-     def forward(self, fused_features):
-         vqa_output = self.vqa_layer(fused_features)
-         return vqa_output
-
- class CaptionLayer(nn.Module):
-     def __init__(self, config):
-         super(CaptionLayer, self).__init__()
-         self.config = config
-         self.caption_layer = nn.Linear(config.hidden_size, config.vocab_size)
-
-     def forward(self, fused_features):
-         caption_output = self.caption_layer(fused_features)
-         return caption_output
-
- class RetrievalLayer(nn.Module):
-     def __init__(self, config):
-         super(RetrievalLayer, self).__init__()
-         self.config = config
-         self.retrieval_layer = nn.Linear(config.hidden_size, config.vocab_size)
-
-     def forward(self, fused_features):
-         retrieval_output = self.retrieval_layer(fused_features)
-         return retrieval_output
-
- class ASRLayer(nn.Module):
-     def __init__(self, config):
-         super(ASRLayer, self).__init__()
-         self.config = config
-         self.asr_layer = nn.Linear(config.hidden_size, config.vocab_size)
-
-     def forward(self, fused_features):
-         asr_output = self.asr_layer(fused_features)
-         return asr_output
-
- class RealtimeASRLayer(nn.Module):
-     def __init__(self, config):
-         super(RealtimeASRLayer, self).__init__()
-         self.config = config
-         self.realtime_asr_layer = nn.Linear(config.hidden_size, config.vocab_size)
-
-     def forward(self, fused_features):
-         realtime_asr_output = self.realtime_asr_layer(fused_features)
-         return realtime_asr_output
-
- # Main model definition
- class AutoModel(nn.Module):
-     def __init__(self, config):
-         super(AutoModel, self).__init__()
-         self.config = config
-         self.image_encoder = ImageEncoder(config)
-         self.text_encoder = TextEncoder(config)
-         self.audio_encoder = AudioEncoder(config)
-         self.fusion_layer = FusionLayer(config)
-         self.vqa_layer = VQALayer(config)
-         self.caption_layer = CaptionLayer(config)
-         self.retrieval_layer = RetrievalLayer(config)
-         self.asr_layer = ASRLayer(config)
-         self.realtime_asr_layer = RealtimeASRLayer(config)
-
-     def forward(self, image, text, audio):
-         image_features = self.image_encoder(image)
-         text_features = self.text_encoder(text)
-         audio_features = self.audio_encoder(audio)
-         fused_features = self.fusion_layer(image_features, text_features, audio_features)
-         vqa_output = self.vqa_layer(fused_features)
-         caption_output = self.caption_layer(fused_features)
-         retrieval_output = self.retrieval_layer(fused_features)
-         asr_output = self.asr_layer(fused_features)
-         realtime_asr_output = self.realtime_asr_layer(fused_features)
-         return vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output
-
- # Test code
- config = Config()
- model = AutoModel(config)
- image = torch.randn(1, 3, 224, 224)
- text = torch.randn(1, config.max_position_embeddings, config.hidden_size)
- audio = torch.randn(1, config.audio_sample_rate)
- vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output = model(image, text, audio)
-
- # Print results
- print("VQA output shape:", vqa_output.shape)
- print("Caption output shape:", caption_output.shape)
- print("Retrieval output shape:", retrieval_output.shape)
- print("ASR output shape:", asr_output.shape)
- print("Realtime ASR output shape:", realtime_asr_output.shape)
-
- # Print total parameter count
- total_params = sum(p.numel() for p in model.parameters())
- print(f"\nTotal parameters: {total_params}")
-
- # Define the save path
- save_dir = "./"  # current directory
- os.makedirs(save_dir, exist_ok=True)
- save_path = os.path.join(save_dir, "AutoModel.pth")
-
- # Save model weights
- torch.save(model.state_dict(), save_path)
- print(f"Model weights saved to: {save_path}")
-
-
-
-
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import os
+ # Config class definition
+ class Config:
+     def __init__(self):
+         # Model architecture parameters
+         self.hidden_size = 768
+         self.num_attention_heads = 12
+         self.num_hidden_layers = 12
+         self.intermediate_size = 3072
+         self.hidden_dropout_prob = 0.1
+         self.attention_probs_dropout_prob = 0.1
+
+         # Image-related
+         self.image_size = 224
+         self.image_channels = 3
+         self.patch_size = 16
+
+         # Text-related
+         self.max_position_embeddings = 512
+         self.vocab_size = 30522
+         self.type_vocab_size = 2
+
+         # Audio-related
+         self.audio_sample_rate = 16000
+         self.audio_frame_size = 1024
+         self.audio_hop_size = 512
+
+         # Task-related
+         self.enable_vqa = True
+         self.enable_caption = True
+         self.enable_retrieval = True
+         self.enable_asr = True  # speech recognition
+         self.enable_realtime_asr = True  # real-time speech recognition
+
+         # Training-related
+         self.batch_size = 32
+         self.learning_rate = 1e-4
+         self.weight_decay = 0.01
+         self.warmup_steps = 10000
+         self.max_steps = 100000
+
+ # Model component class definitions
+ class ImageEncoder(nn.Module):
+     def __init__(self, config):
+         super(ImageEncoder, self).__init__()
+         self.config = config
+         self.encoder_layer = nn.Sequential(
+             nn.Conv2d(3, 64, kernel_size=3),
+             nn.ReLU(),
+             nn.MaxPool2d(2, 2),
+             nn.Flatten(),
+             nn.Linear(64 * 111 * 111, config.hidden_size)
+         )
+
+     def forward(self, image):
+         image_features = self.encoder_layer(image)
+         return image_features
+
+ class TextEncoder(nn.Module):
+     def __init__(self, config):
+         super(TextEncoder, self).__init__()
+         self.config = config
+         self.transformer_layer = nn.TransformerEncoderLayer(
+             d_model=config.hidden_size,
+             nhead=config.num_attention_heads,
+             batch_first=True
+         )
+         self.transformer_encoder = nn.TransformerEncoder(
+             self.transformer_layer,
+             num_layers=config.num_hidden_layers
+         )
+
+     def forward(self, text):
+         text_features = self.transformer_encoder(text).mean(dim=1)
+         return text_features
+
+ class AudioEncoder(nn.Module):
+     def __init__(self, config):
+         super(AudioEncoder, self).__init__()
+         self.config = config
+         self.encoder_layer = nn.Sequential(
+             nn.Linear(config.audio_sample_rate, config.hidden_size),
+             nn.ReLU(),
+             nn.Linear(config.hidden_size, config.hidden_size)
+         )
+
+     def forward(self, audio):
+         audio_features = self.encoder_layer(audio)
+         return audio_features
+
+ class FusionLayer(nn.Module):
+     def __init__(self, config):
+         super(FusionLayer, self).__init__()
+         self.config = config
+         self.fusion_layer = nn.Linear(config.hidden_size * 3, config.hidden_size)
+
+     def forward(self, image_features, text_features, audio_features):
+         fused_features = torch.cat((image_features, text_features, audio_features), dim=1)
+         fused_features = self.fusion_layer(fused_features)
+         return fused_features
+
+ class VQALayer(nn.Module):
+     def __init__(self, config):
+         super(VQALayer, self).__init__()
+         self.config = config
+         self.vqa_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+     def forward(self, fused_features):
+         vqa_output = self.vqa_layer(fused_features)
+         return vqa_output
+
+ class CaptionLayer(nn.Module):
+     def __init__(self, config):
+         super(CaptionLayer, self).__init__()
+         self.config = config
+         self.caption_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+     def forward(self, fused_features):
+         caption_output = self.caption_layer(fused_features)
+         return caption_output
+
+ class RetrievalLayer(nn.Module):
+     def __init__(self, config):
+         super(RetrievalLayer, self).__init__()
+         self.config = config
+         self.retrieval_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+     def forward(self, fused_features):
+         retrieval_output = self.retrieval_layer(fused_features)
+         return retrieval_output
+
+ class ASRLayer(nn.Module):
+     def __init__(self, config):
+         super(ASRLayer, self).__init__()
+         self.config = config
+         self.asr_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+     def forward(self, fused_features):
+         asr_output = self.asr_layer(fused_features)
+         return asr_output
+
+ class RealtimeASRLayer(nn.Module):
+     def __init__(self, config):
+         super(RealtimeASRLayer, self).__init__()
+         self.config = config
+         self.realtime_asr_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+     def forward(self, fused_features):
+         realtime_asr_output = self.realtime_asr_layer(fused_features)
+         return realtime_asr_output
+
+ class TextOutputLayer(nn.Module):
+     def __init__(self, config):
+         super(TextOutputLayer, self).__init__()
+         self.config = config
+         self.text_output_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+     def forward(self, fused_features):
+         text_output = self.text_output_layer(fused_features)
+         return text_output
+
+ # Main model definition
+ class AutoModel(nn.Module):
+     def __init__(self, config):
+         super(AutoModel, self).__init__()
+         self.config = config
+         self.image_encoder = ImageEncoder(config)
+         self.text_encoder = TextEncoder(config)
+         self.audio_encoder = AudioEncoder(config)
+         self.fusion_layer = FusionLayer(config)
+         self.vqa_layer = VQALayer(config)
+         self.caption_layer = CaptionLayer(config)
+         self.retrieval_layer = RetrievalLayer(config)
+         self.asr_layer = ASRLayer(config)
+         self.realtime_asr_layer = RealtimeASRLayer(config)
+         self.text_output_layer = TextOutputLayer(config)
+
+     def forward(self, image, text, audio):
+         image_features = self.image_encoder(image)
+         text_features = self.text_encoder(text)
+         audio_features = self.audio_encoder(audio)
+         fused_features = self.fusion_layer(image_features, text_features, audio_features)
+         vqa_output = self.vqa_layer(fused_features)
+         caption_output = self.caption_layer(fused_features)
+         retrieval_output = self.retrieval_layer(fused_features)
+         asr_output = self.asr_layer(fused_features)
+         realtime_asr_output = self.realtime_asr_layer(fused_features)
+         text_output = self.text_output_layer(fused_features)
+         return vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output, text_output
+
+ # Test code
+ config = Config()
+ model = AutoModel(config)
+ image = torch.randn(1, 3, 224, 224)
+ text = torch.randn(1, config.max_position_embeddings, config.hidden_size)
+ audio = torch.randn(1, config.audio_sample_rate)
+ vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output, text_output = model(image, text, audio)
+
+ # Print results
+ print("VQA output shape:", vqa_output.shape)
+ print("Caption output shape:", caption_output.shape)
+ print("Retrieval output shape:", retrieval_output.shape)
+ print("ASR output shape:", asr_output.shape)
+ print("Realtime ASR output shape:", realtime_asr_output.shape)
+ print("Text output shape:", text_output.shape)
+
+ # Print total parameter count
+ total_params = sum(p.numel() for p in model.parameters())
+ print(f"\nTotal parameters: {total_params}")
+
+
+
+ # Save model weights
+ save_path = "save.pth"
+ torch.save(model.state_dict(), save_path)
+ print(f"Model weights saved to: {save_path}")
+
+
+
+
+
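
The hard-coded flatten size in ImageEncoder follows from the input resolution: a 3x3 convolution with no padding maps 224 to 222, and the 2x2 max-pool halves that to 111, hence nn.Linear(64 * 111 * 111, ...). A quick sanity check (a sketch, not part of this commit) that reproduces just the conv/pool stem:

import torch
import torch.nn as nn

# Same stem as ImageEncoder.encoder_layer, up to the Flatten.
stem = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3),  # 224 -> 222 (valid convolution, stride 1)
    nn.ReLU(),
    nn.MaxPool2d(2, 2),               # 222 -> 111
)
out = stem(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 64, 111, 111]) -> flatten size 64 * 111 * 111 = 788544

Note that this makes the encoder brittle to any other input resolution; nn.AdaptiveAvgPool2d or nn.LazyLinear are common ways to decouple the linear layer from the input size.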
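
To reuse the weights written by the updated script, one can rebuild the model and load the state dict back. A minimal sketch, assuming this model.py is importable from the working directory; be aware that importing it re-runs the test-and-save block at import time, since the file has no if __name__ == "__main__": guard:

import torch
from model import AutoModel, Config  # side effect: executes model.py's test code

config = Config()
model = AutoModel(config)
model.load_state_dict(torch.load("save.pth"))  # path written by the updated model.py
model.eval()

with torch.no_grad():
    image = torch.randn(1, 3, 224, 224)
    text = torch.randn(1, config.max_position_embeddings, config.hidden_size)
    audio = torch.randn(1, config.audio_sample_rate)
    outputs = model(image, text, audio)  # (vqa, caption, retrieval, asr, realtime_asr, text)

Each of the six heads projects the fused (1, 768) feature to the vocabulary, so every output has shape (1, 30522).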