"""Multi-modal perception components: a perception-state container and a fusing encoder."""
from dataclasses import dataclass
from typing import Any, Dict, List

import torch
import torch.nn as nn
@dataclass
class PerceptionState:
    """Snapshot of one multi-modal perception pass.

    Holds per-modality feature tensors plus a fused context vector and
    per-modality attention weights. The `@dataclass` decorator was missing
    in the original even though `dataclass` is imported at the top of the
    file — without it the class had only annotations and no generated
    `__init__`, `__repr__`, or `__eq__`.
    """

    visual_data: torch.Tensor    # visual feature tensor (shape set by the producing encoder)
    audio_data: torch.Tensor     # audio feature tensor
    text_data: torch.Tensor      # text feature tensor
    context_vector: torch.Tensor  # fused cross-modal representation
    # presumably maps modality name -> attention weight used during fusion — TODO confirm
    attention_weights: Dict[str, float]
class MultiModalEncoder(nn.Module):
    """Encodes visual, audio, and text inputs and fuses them into a PerceptionState.

    Each modality is handled by its own encoder module; the three feature
    tensors are then combined by a fusion layer. (Original source had been
    flattened by extraction — indentation and trailing-artifact cleanup only,
    logic unchanged.)
    """

    def __init__(self):
        super().__init__()
        # Per-modality encoders — these classes are defined elsewhere in the project.
        self.visual_encoder = VisualProcessor()
        self.audio_encoder = AudioProcessor()
        self.text_encoder = TextProcessor()
        # Combines the three modality feature tensors into one representation.
        self.fusion_layer = ModalityFusion()

    def forward(self, inputs: Dict[str, torch.Tensor]) -> PerceptionState:
        """Encode each modality and fuse the results.

        Args:
            inputs: mapping expected to carry 'visual', 'audio', and 'text'
                tensors. NOTE(review): `dict.get` yields None for a missing
                key, so each encoder must tolerate None input — confirm
                against the encoder implementations.

        Returns:
            A PerceptionState built from the fused representation by
            `_create_perception_state` (defined elsewhere in this class —
            not visible in this chunk).
        """
        visual_features = self.visual_encoder(inputs.get('visual'))
        audio_features = self.audio_encoder(inputs.get('audio'))
        text_features = self.text_encoder(inputs.get('text'))
        fused_representation = self.fusion_layer(
            visual_features,
            audio_features,
            text_features
        )
        return self._create_perception_state(fused_representation)