from typing import Any, Type

import torch
import torch.nn as nn

from .resnet3d import r3d_18

# Alias used in the annotations below. ``Type[torch.Tensor]`` would denote the
# tensor *class object* rather than tensor instances, so alias the class itself.
Tensor = torch.Tensor


class CNNResNet3DWithLinearClassifier(nn.Module):
    """3D ResNet-18 backbone with a linear head, exposed through dict I/O.

    Wraps ``_CNNResNet3DWithLinearClassifier`` so that the input tensor is
    passed in under its source-modality name and each output column is
    returned under the name of a target modality.
    """

    def __init__(
        self,
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]],
    ) -> None:
        """Build the model.

        Args:
            src_modalities: Description of the input modalities. Stored on
                the instance as ``self.src_modalities``; not read by
                ``forward``.
            tgt_modalities: Description of the target modalities. One output
                unit is created per key, and the keys' iteration order fixes
                which output column each key receives in ``forward``.
        """
        super().__init__()
        self.core = _CNNResNet3DWithLinearClassifier(len(tgt_modalities))
        self.src_modalities = src_modalities
        self.tgt_modalities = tgt_modalities

    def forward(self, x: dict[str, Tensor]) -> dict[str, Tensor]:
        """Run the model on a single-modality input.

        Args:
            x: Mapping from source-modality name to input tensor. Expected
                to hold exactly one entry. If it holds more, only the first
                entry (in iteration order) is used.

        Returns:
            Mapping from each target-modality name to ``out[:, i]``, the
            i-th column of the core model's output, where ``i`` is the key's
            position in ``tgt_modalities``.

        Raises:
            ValueError: If ``x`` is empty.
        """
        if not x:
            raise ValueError(
                "expected a dict holding one input tensor, got an empty dict"
            )
        # Only the first entry is consumed; callers are expected to pass one.
        inp = next(iter(x.values()))
        out = self.core(inp)
        # NOTE(review): assumes ``out`` is (batch, n_targets) so that column i
        # belongs to the i-th target modality — TODO confirm against r3d_18.
        return {tgt_k: out[:, i] for i, tgt_k in enumerate(self.tgt_modalities)}


class _CNNResNet3DWithLinearClassifier(nn.Module):
    """``r3d_18`` feature extractor followed by a dropout + linear head."""

    def __init__(
        self,
        len_tgt_modalities: int,
        emb_dim: int = 256,
        dropout: float = 0.5,
    ) -> None:
        """Build the backbone and the classifier head.

        Args:
            len_tgt_modalities: Number of output units of the linear head.
            emb_dim: Input width of the linear head. It must equal the
                feature size returned by ``r3d_18()``; the default of 256 is
                the value previously hard-coded here.
            dropout: Dropout probability applied before the linear layer.
        """
        super().__init__()
        # Construction order (cnn, then cls) is kept so parameter
        # initialisation and state_dict keys match earlier checkpoints.
        self.cnn = r3d_18()
        self.cls = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(emb_dim, len_tgt_modalities),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Return classifier outputs: ``forward_cls(forward_emb(x))``."""
        out_emb = self.forward_emb(x)
        out_cls = self.forward_cls(out_emb)
        return out_cls

    def forward_emb(self, x: Tensor) -> Tensor:
        """Return the backbone embedding ``self.cnn(x)``.

        NOTE(review): presumably (batch, emb_dim) — confirm against r3d_18.
        """
        return self.cnn(x)

    def forward_cls(self, out_emb: Tensor) -> Tensor:
        """Apply dropout and the linear head to a backbone embedding."""
        return self.cls(out_emb)