# nmed2024 / adrd / nn / cnn_resnet3d_with_linear_classifier.py
# xf3227's picture
# ok
# 6fc43ab
import torch
import torch.nn as nn
from typing import Any, Type
# Alias used in the annotations below for tensor *instances*.
# `Type[torch.Tensor]` would denote the class object itself, which is not what
# the annotated parameters/returns hold, and cannot be used with isinstance().
Tensor = torch.Tensor
from .resnet3d import r3d_18
class CNNResNet3DWithLinearClassifier(nn.Module):
    """Dictionary-in / dictionary-out wrapper around the core classifier.

    The wrapped ``_CNNResNet3DWithLinearClassifier`` is constructed with
    ``len(tgt_modalities)``. ``forward`` feeds it a single tensor taken from
    the input dictionary and re-labels the columns of its output with the
    keys of ``tgt_modalities``.
    """

    def __init__(
        self,
        src_modalities: dict[str, dict[str, Any]],
        tgt_modalities: dict[str, dict[str, Any]],
    ) -> None:
        """Store the modality descriptions and build the core network.

        Args:
            src_modalities: description of the input modalities; kept on the
                instance but not otherwise used by this class.
            tgt_modalities: description of the targets; its length is passed
                to the core network and its keys name the outputs of
                ``forward``.
        """
        super().__init__()
        self.src_modalities = src_modalities
        self.tgt_modalities = tgt_modalities
        n_targets = len(tgt_modalities)
        self.core = _CNNResNet3DWithLinearClassifier(n_targets)

    def forward(
        self,
        x: dict[str, Tensor],
    ) -> dict[str, Tensor]:
        """Run the core network on the first entry of ``x``.

        ``x`` is expected to be a singleton dictionary; only the value under
        its first key is used.

        Returns:
            A dictionary with one entry per key of ``tgt_modalities``, where
            the i-th key maps to column ``i`` (``[:, i]``) of the core output.
        """
        first_key = list(x)[0]
        scores = self.core(x[first_key])
        per_target = {}
        for col, tgt_k in enumerate(self.tgt_modalities):
            per_target[tgt_k] = scores[:, col]
        return per_target
class _CNNResNet3DWithLinearClassifier(nn.Module):
    """``r3d_18`` backbone followed by a dropout + linear head."""

    def __init__(
        self,
        len_tgt_modalities: int,
    ) -> None:
        """Build the backbone and the classification head.

        Args:
            len_tgt_modalities: number of output features of the linear layer.
        """
        super().__init__()
        # Backbone is created before the head so that submodule registration
        # order (and therefore state_dict key order) is `cnn`, then `cls`.
        self.cnn = r3d_18()
        head_layers = [
            nn.Dropout(p=0.5),
            # The head takes 256 input features, so the backbone output is
            # assumed to be 256-wide -- TODO confirm against r3d_18.
            nn.Linear(in_features=256, out_features=len_tgt_modalities),
        ]
        self.cls = nn.Sequential(*head_layers)

    def forward(self, x: Tensor) -> Tensor:
        """Return the head output for ``x`` (embedding step, then head)."""
        return self.forward_cls(self.forward_emb(x))

    def forward_emb(self, x: Tensor) -> Tensor:
        """Return the backbone output for ``x``."""
        embedding = self.cnn(x)
        return embedding

    def forward_cls(self, out_emb: Tensor) -> Tensor:
        """Apply the dropout + linear head to an embedding."""
        scores = self.cls(out_emb)
        return scores