from dataclasses import dataclass
from typing import Dict
import torch
import torch.nn as nn

@dataclass
class PerceptionState:
    """Per-modality features plus the fused context and an attention summary."""
    visual_data: torch.Tensor
    audio_data: torch.Tensor
    text_data: torch.Tensor
    context_vector: torch.Tensor
    attention_weights: Dict[str, float]
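
# The processor and fusion classes referenced below are not defined in the
# original snippet. These stand-ins are a minimal, hypothetical sketch (one
# linear projection per modality, softmax attention over modalities for
# fusion) chosen only so the example runs end to end; real visual/audio/text
# processors would replace them.
class ModalityEncoder(nn.Module):
    def __init__(self, input_dim: int = 512, hidden_dim: int = 256):
        super().__init__()
        self.proj = nn.Linear(input_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.proj(x))

# Aliases preserve the original class names; swap in real encoders here.
VisualProcessor = AudioProcessor = TextProcessor = ModalityEncoder

class ModalityFusion(nn.Module):
    def __init__(self, hidden_dim: int = 256):
        super().__init__()
        self.score = nn.Linear(hidden_dim, 1)

    def forward(self, visual, audio, text):
        # Stack features as (batch, 3, hidden), score each modality, and
        # take the attention-weighted sum as the fused context vector.
        stacked = torch.stack([visual, audio, text], dim=1)
        weights = torch.softmax(self.score(stacked).squeeze(-1), dim=1)
        fused = (weights.unsqueeze(-1) * stacked).sum(dim=1)
        return fused, weights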

class MultiModalEncoder(nn.Module):
    """Encodes visual, audio, and text inputs into a fused PerceptionState."""
    def __init__(self):
        super().__init__()
        self.visual_encoder = VisualProcessor()
        self.audio_encoder = AudioProcessor()
        self.text_encoder = TextProcessor()
        self.fusion_layer = ModalityFusion()
        
    def forward(self, inputs: Dict[str, torch.Tensor]) -> PerceptionState:
        # Assumes all three modalities are present; the original
        # inputs.get(...) would silently pass None into an encoder when a
        # key is missing, so index directly instead.
        visual_features = self.visual_encoder(inputs['visual'])
        audio_features = self.audio_encoder(inputs['audio'])
        text_features = self.text_encoder(inputs['text'])

        # Fusion yields the combined context vector plus per-modality
        # attention weights of shape (batch, 3).
        fused_representation, weights = self.fusion_layer(
            visual_features,
            audio_features,
            text_features
        )

        return self._create_perception_state(
            visual_features, audio_features, text_features,
            fused_representation, weights
        )

    def _create_perception_state(self, visual, audio, text, fused,
                                 weights) -> PerceptionState:
        # Average attention weights over the batch so each modality gets
        # one scalar, matching PerceptionState's Dict[str, float] field.
        w = weights.mean(dim=0)
        return PerceptionState(
            visual_data=visual,
            audio_data=audio,
            text_data=text,
            context_vector=fused,
            attention_weights={'visual': w[0].item(),
                               'audio': w[1].item(),
                               'text': w[2].item()},
        )
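
# Quick usage sketch with assumed shapes: the 512-dim random inputs match
# the hypothetical stub encoders above, not any real feature extractor.
if __name__ == "__main__":
    encoder = MultiModalEncoder()
    batch = {
        'visual': torch.randn(4, 512),
        'audio': torch.randn(4, 512),
        'text': torch.randn(4, 512),
    }
    state = encoder(batch)
    print(state.context_vector.shape)  # torch.Size([4, 256])
    print(state.attention_weights)     # e.g. {'visual': 0.34, ...}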