Spaces:
Sleeping
Sleeping
File size: 1,076 Bytes
fbebf66 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from dataclasses import dataclass
from typing import Dict, Any, List
import torch
import torch.nn as nn
@dataclass
class PerceptionState:
    """Bundle of four tensors plus a mapping of attention weights.

    This is a plain dataclass: fields are stored exactly as passed in and no
    validation, copying, or device/dtype conversion is performed.
    """

    # NOTE(review): shapes and dtypes of the tensors below are not fixed by
    # this file -- confirm against whatever code constructs this state.
    visual_data: torch.Tensor
    audio_data: torch.Tensor
    text_data: torch.Tensor
    context_vector: torch.Tensor
    # NOTE(review): presumably keyed by modality name -- verify with producer.
    attention_weights: Dict[str, float]
class MultiModalEncoder(nn.Module):
    """Run per-modality encoders over an input dict and fuse their outputs.

    The module owns one encoder each for the ``'visual'``, ``'audio'`` and
    ``'text'`` entries of the input mapping, plus a fusion layer that combines
    the three encoder outputs.
    """

    def __init__(self):
        """Instantiate the three modality encoders and the fusion layer."""
        super().__init__()
        # NOTE(review): VisualProcessor, AudioProcessor, TextProcessor and
        # ModalityFusion are neither defined nor imported in the visible part
        # of this file -- confirm where they come from.
        self.visual_encoder = VisualProcessor()
        self.audio_encoder = AudioProcessor()
        self.text_encoder = TextProcessor()
        self.fusion_layer = ModalityFusion()

    def forward(self, inputs: Dict[str, torch.Tensor]) -> PerceptionState:
        """Encode each modality, fuse the results, and build the state.

        Args:
            inputs: Mapping read with the keys ``'visual'``, ``'audio'`` and
                ``'text'``. A missing key is looked up with ``dict.get`` and
                therefore reaches the corresponding encoder as ``None``.

        Returns:
            Whatever ``self._create_perception_state`` returns for the fused
            representation (annotated as ``PerceptionState``).
        """
        # ``.get`` rather than ``[]``: absent modalities are forwarded as
        # None instead of raising KeyError here.
        visual_features = self.visual_encoder(inputs.get('visual'))
        audio_features = self.audio_encoder(inputs.get('audio'))
        text_features = self.text_encoder(inputs.get('text'))
        fused_representation = self.fusion_layer(
            visual_features,
            audio_features,
            text_features,
        )
        # NOTE(review): _create_perception_state is not defined in the visible
        # source; unless a subclass or mixin supplies it, this line raises
        # AttributeError -- confirm.
        return self._create_perception_state(fused_representation)