Question Answering
Transformers
English
Chinese
multimodal
vqa
text
audio
Eval Results
Inference Endpoints
zeroMN committed · verified
Commit 7724139 · 1 Parent(s): 2df5d25

Update README.md

Files changed (1)
  1. README.md +220 -62
README.md CHANGED
@@ -100,69 +100,227 @@ The Hugging Face Hub itself cannot run uploaded models automatically, but through `Spaces`
 
 ## Uses
 ```python
-import os
 import torch
-from model import AutoModel, Config
-
-def load_model(model_path, config_path):
-    """
-    Load the model weights and configuration.
-    """
-    # Load the configuration
-    if not os.path.exists(config_path):
-        raise FileNotFoundError(f"Config file not found: {config_path}")
-    print(f"Loading config file: {config_path}")
-    config = Config()
-
-    # Initialize the model
-    model = AutoModel(config)
-
-    # Load the weights
-    if not os.path.exists(model_path):
-        raise FileNotFoundError(f"Model file not found: {model_path}")
-    print(f"Loading model weights: {model_path}")
-    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
-    model.load_state_dict(state_dict)
-    model.eval()
-    print("Model loaded successfully and set to evaluation mode.")
-
-    return model, config
-
-
-def run_inference(model, config):
-    """
-    Run inference with the model.
-    """
-    # Simulated example inputs
-    image = torch.randn(1, 3, 224, 224)  # image input
-    text = torch.randn(1, config.max_position_embeddings, config.hidden_size)  # text input
-    audio = torch.randn(1, config.audio_sample_rate)  # audio input
-
-    # Model inference
-    outputs = model(image, text, audio)
-    vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output = outputs
-
-    # Print results
-    print("\nInference results:")
-    print(f"VQA output shape: {vqa_output.shape}")
-    print(f"Caption output shape: {caption_output.shape}")
-    print(f"Retrieval output shape: {retrieval_output.shape}")
-    print(f"ASR output shape: {asr_output.shape}")
-    print(f"Realtime ASR output shape: {realtime_asr_output.shape}")
-
-if __name__ == "__main__":
-    # File paths
-    model_path = "AutoModel.pth"
-    config_path = "config.json"
-
-    # Load the model
-    try:
-        model, config = load_model(model_path, config_path)
-
-        # Run inference
-        run_inference(model, config)
-    except Exception as e:
-        print(f"Run failed: {e}")
+import torch.nn as nn
+import torch.nn.functional as F
+import os
+
+# Configuration class
+class Config:
+    def __init__(self):
+        # Model architecture parameters
+        self.hidden_size = 768
+        self.num_attention_heads = 12
+        self.num_hidden_layers = 12
+        self.intermediate_size = 3072
+        self.hidden_dropout_prob = 0.1
+        self.attention_probs_dropout_prob = 0.1
+
+        # Image-related
+        self.image_size = 224
+        self.image_channels = 3
+        self.patch_size = 16
+
+        # Text-related
+        self.max_position_embeddings = 512
+        self.vocab_size = 30522
+        self.type_vocab_size = 2
+
+        # Audio-related
+        self.audio_sample_rate = 16000
+        self.audio_frame_size = 1024
+        self.audio_hop_size = 512
+
+        # Task-related
+        self.enable_vqa = True
+        self.enable_caption = True
+        self.enable_retrieval = True
+        self.enable_asr = True            # speech recognition
+        self.enable_realtime_asr = True   # real-time speech recognition
+
+        # Training-related
+        self.batch_size = 32
+        self.learning_rate = 1e-4
+        self.weight_decay = 0.01
+        self.warmup_steps = 10000
+        self.max_steps = 100000
+
+# Model component classes
+class ImageEncoder(nn.Module):
+    def __init__(self, config):
+        super(ImageEncoder, self).__init__()
+        self.config = config
+        # 224x224 input -> conv (no padding) -> 222x222 -> 2x2 max pool -> 111x111
+        self.encoder_layer = nn.Sequential(
+            nn.Conv2d(3, 64, kernel_size=3),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),
+            nn.Flatten(),
+            nn.Linear(64 * 111 * 111, config.hidden_size)
+        )
+
+    def forward(self, image):
+        image_features = self.encoder_layer(image)
+        return image_features
+
+class TextEncoder(nn.Module):
+    def __init__(self, config):
+        super(TextEncoder, self).__init__()
+        self.config = config
+        self.transformer_layer = nn.TransformerEncoderLayer(
+            d_model=config.hidden_size,
+            nhead=config.num_attention_heads,
+            batch_first=True
+        )
+        self.transformer_encoder = nn.TransformerEncoder(
+            self.transformer_layer,
+            num_layers=config.num_hidden_layers
+        )
+
+    def forward(self, text):
+        # Mean-pool over the sequence dimension
+        text_features = self.transformer_encoder(text).mean(dim=1)
+        return text_features
+
+class AudioEncoder(nn.Module):
+    def __init__(self, config):
+        super(AudioEncoder, self).__init__()
+        self.config = config
+        self.encoder_layer = nn.Sequential(
+            nn.Linear(config.audio_sample_rate, config.hidden_size),
+            nn.ReLU(),
+            nn.Linear(config.hidden_size, config.hidden_size)
+        )
+
+    def forward(self, audio):
+        audio_features = self.encoder_layer(audio)
+        return audio_features
+
+class FusionLayer(nn.Module):
+    def __init__(self, config):
+        super(FusionLayer, self).__init__()
+        self.config = config
+        self.fusion_layer = nn.Linear(config.hidden_size * 3, config.hidden_size)
+
+    def forward(self, image_features, text_features, audio_features):
+        fused_features = torch.cat((image_features, text_features, audio_features), dim=1)
+        fused_features = self.fusion_layer(fused_features)
+        return fused_features
+
+class VQALayer(nn.Module):
+    def __init__(self, config):
+        super(VQALayer, self).__init__()
+        self.config = config
+        self.vqa_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+    def forward(self, fused_features):
+        vqa_output = self.vqa_layer(fused_features)
+        return vqa_output
+
+class CaptionLayer(nn.Module):
+    def __init__(self, config):
+        super(CaptionLayer, self).__init__()
+        self.config = config
+        self.caption_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+    def forward(self, fused_features):
+        caption_output = self.caption_layer(fused_features)
+        return caption_output
+
+class RetrievalLayer(nn.Module):
+    def __init__(self, config):
+        super(RetrievalLayer, self).__init__()
+        self.config = config
+        self.retrieval_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+    def forward(self, fused_features):
+        retrieval_output = self.retrieval_layer(fused_features)
+        return retrieval_output
+
+class ASRLayer(nn.Module):
+    def __init__(self, config):
+        super(ASRLayer, self).__init__()
+        self.config = config
+        self.asr_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+    def forward(self, fused_features):
+        asr_output = self.asr_layer(fused_features)
+        return asr_output
+
+class RealtimeASRLayer(nn.Module):
+    def __init__(self, config):
+        super(RealtimeASRLayer, self).__init__()
+        self.config = config
+        self.realtime_asr_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+    def forward(self, fused_features):
+        realtime_asr_output = self.realtime_asr_layer(fused_features)
+        return realtime_asr_output
+
+class TextOutputLayer(nn.Module):
+    def __init__(self, config):
+        super(TextOutputLayer, self).__init__()
+        self.config = config
+        self.text_output_layer = nn.Linear(config.hidden_size, config.vocab_size)
+
+    def forward(self, fused_features):
+        text_output = self.text_output_layer(fused_features)
+        return text_output
+
+# Main model
+class AutoModel(nn.Module):
+    def __init__(self, config):
+        super(AutoModel, self).__init__()
+        self.config = config
+        self.image_encoder = ImageEncoder(config)
+        self.text_encoder = TextEncoder(config)
+        self.audio_encoder = AudioEncoder(config)
+        self.fusion_layer = FusionLayer(config)
+        self.vqa_layer = VQALayer(config)
+        self.caption_layer = CaptionLayer(config)
+        self.retrieval_layer = RetrievalLayer(config)
+        self.asr_layer = ASRLayer(config)
+        self.realtime_asr_layer = RealtimeASRLayer(config)
+        self.text_output_layer = TextOutputLayer(config)
+
+    def forward(self, image, text, audio):
+        image_features = self.image_encoder(image)
+        text_features = self.text_encoder(text)
+        audio_features = self.audio_encoder(audio)
+        fused_features = self.fusion_layer(image_features, text_features, audio_features)
+        vqa_output = self.vqa_layer(fused_features)
+        caption_output = self.caption_layer(fused_features)
+        retrieval_output = self.retrieval_layer(fused_features)
+        asr_output = self.asr_layer(fused_features)
+        realtime_asr_output = self.realtime_asr_layer(fused_features)
+        text_output = self.text_output_layer(fused_features)
+        return vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output, text_output
+
+# Test code
+config = Config()
+model = AutoModel(config)
+image = torch.randn(1, 3, 224, 224)
+text = torch.randn(1, config.max_position_embeddings, config.hidden_size)
+audio = torch.randn(1, config.audio_sample_rate)
+vqa_output, caption_output, retrieval_output, asr_output, realtime_asr_output, text_output = model(image, text, audio)
+
+# Print output shapes
+print("VQA output shape:", vqa_output.shape)
+print("Caption output shape:", caption_output.shape)
+print("Retrieval output shape:", retrieval_output.shape)
+print("ASR output shape:", asr_output.shape)
+print("Realtime ASR output shape:", realtime_asr_output.shape)
+print("Text output shape:", text_output.shape)
+
+# Print total parameter count
+total_params = sum(p.numel() for p in model.parameters())
+print(f"\nTotal parameters: {total_params}")
+
+# Save model weights
+save_path = "save.pth"
+torch.save(model.state_dict(), save_path)
+print(f"Model weights saved to: {save_path}")
+
 ```
 
  ### Direct Use
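
With the classes from the snippet above available (for example, saved to a local module and imported), the saved checkpoint can be reloaded for inference. The following is a minimal sketch, not part of the commit itself; it assumes the `Config` and `AutoModel` definitions and the illustrative `save.pth` filename from the snippet above, and uses random dummy inputs:

```python
import torch

# Assumes Config and AutoModel are defined as in the README snippet above
# (e.g. placed in a local model.py and imported from there).
config = Config()
model = AutoModel(config)

# Reload the weights written by torch.save() in the snippet above.
# "save.pth" is the illustrative filename, not a published checkpoint.
state_dict = torch.load("save.pth", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Dummy inputs matching the shapes the encoders expect: a 224x224 RGB image,
# pre-embedded text of shape (batch, max_position_embeddings, hidden_size),
# and one second of 16 kHz audio.
image = torch.randn(1, 3, config.image_size, config.image_size)
text = torch.randn(1, config.max_position_embeddings, config.hidden_size)
audio = torch.randn(1, config.audio_sample_rate)

with torch.no_grad():
    outputs = model(image, text, audio)

# Six heads: VQA, caption, retrieval, ASR, realtime ASR, text output.
for name, out in zip(
    ["vqa", "caption", "retrieval", "asr", "realtime_asr", "text"], outputs
):
    print(f"{name} output shape: {out.shape}")
```

Note two constraints implied by the architecture: `TextEncoder` consumes pre-embedded float tensors rather than token IDs, and `AudioEncoder`'s first `nn.Linear(config.audio_sample_rate, ...)` fixes the audio input to exactly `audio_sample_rate` samples. Real inputs would therefore need a tokenization/embedding step for text and fixed-length framing for audio.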