import math import numpy as np import re import torch import torch.nn as nn import torch.nn.functional as F from pytorchvideo.models.x3d import create_x3d_stem from timm.models.vision_transformer import VisionTransformer from timm.models.swin_transformer_v2 import SwinTransformerV2 from . import backbones from . import segmentation from .pooling import create_pool2d_layer, create_pool3d_layer from .sequence import Transformer, DualTransformer, DualTransformerV2 from .tools import change_initial_stride, change_num_input_channels class Net2D(nn.Module): def __init__(self, backbone, pretrained, num_classes, dropout, pool, in_channels=3, change_stride=None, feature_reduction=None, multisample_dropout=False, load_pretrained_backbone=None, freeze_backbone=False, backbone_params={}, pool_layer_params={}): super().__init__() self.backbone, dim_feats = backbones.create_backbone(name=backbone, pretrained=pretrained, **backbone_params) if isinstance(pool, str): self.pool_layer = create_pool2d_layer(name=pool, **pool_layer_params) else: self.pool_layer = nn.Identity() if pool == "catavgmax": dim_feats *= 2 self.msdo = multisample_dropout if in_channels != 3: self.backbone = change_num_input_channels(self.backbone, in_channels) if change_stride: self.backbone = change_initial_stride(self.backbone, tuple(change_stride), in_channels) self.dropout = nn.Dropout(p=dropout) if isinstance(feature_reduction, int): # Use 1D grouped convolution to reduce # of parameters groups = math.gcd(dim_feats, feature_reduction) self.feature_reduction = nn.Conv1d(dim_feats, feature_reduction, groups=groups, kernel_size=1, stride=1, bias=False) dim_feats = feature_reduction self.classifier = nn.Linear(dim_feats, num_classes) if load_pretrained_backbone: # Assumes that model has a `backbone` attribute # Note: if you want to load the entire pretrained model, this is done via the # builder.build_model function print(f"Loading pretrained backbone from {load_pretrained_backbone} ...") weights = torch.load(load_pretrained_backbone, map_location=lambda storage, loc: storage)['state_dict'] weights = {re.sub(r'^model.', '', k) : v for k,v in weights.items()} # Get feature_reduction, if present feat_reduce_weight = {re.sub(r"^feature_reduction.", "", k): v for k, v in weights.items() if "feature_reduction" in k} # Get backbone only weights = {re.sub(r'^backbone.', '', k) : v for k,v in weights.items() if 'backbone' in k} self.backbone.load_state_dict(weights) if len(feat_reduce_weight) > 0: print("Also loading feature reduction layer ...") self.feature_reduction.load_state_dict(feat_reduce_weight) if freeze_backbone: print("Freezing backbone ...") for param in self.backbone.parameters(): param.requires_grad = False def extract_features(self, x): features = self.backbone(x) features = self.pool_layer(features) if isinstance(self.backbone, VisionTransformer): features = features[:, self.backbone.num_prefix_tokens:].mean(dim=1) if isinstance(self.backbone, SwinTransformerV2): features = features.mean(dim=1) if hasattr(self, "feature_reduction"): features = self.feature_reduction(features.unsqueeze(-1)).squeeze(-1) return features def forward(self, x): features = self.extract_features(x) if self.msdo: x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0) else: x = self.classifier(self.dropout(features)) # Important nuance: # For binary classification, the model returns a tensor of shape (N,) # Otherwise, (N,C) return x[:, 0] if self.classifier.out_features == 1 else x class SeqNet2D(Net2D): def forward(self, x): # x.shape = (N, C, Z, H, W) features = torch.stack([self.extract_features(x[:, :, _]) for _ in range(x.size(2))], dim=2) features = features.max(2)[0] if self.msdo: x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0) else: x = self.classifier(self.dropout(features)) # Important nuance: # For binary classification, the model returns a tensor of shape (N,) # Otherwise, (N,C) return x[:, 0] if self.classifier.out_features == 1 else x class TDCNN(nn.Module): def __init__(self, cnn_params, transformer_params, freeze_cnn=False, freeze_transformer=False): super().__init__() self.cnn = Net2D(**cnn_params) del self.cnn.dropout del self.cnn.classifier self.transformer = Transformer(**transformer_params) if freeze_cnn: for param in self.cnn.parameters(): param.requires_grad = False if freeze_transformer: for param in self.transformer.parameters(): param.requires_grad = False def extract_features(self, x): N, C, Z, H, W = x.size() assert N == 1, "For feature extraction, batch size must be 1" features = self.cnn.extract_features(x.squeeze(0).transpose(0, 1)).unsqueeze(0) # features.shape = (1, Z, dim_feats) return self.transformer.extract_features((features, torch.ones((features.size(0), features.size(1))).to(features.device))) def forward(self, x): # BCZHW features = torch.stack([self.cnn.extract_features(x[:, :, i]) for i in range(x.size(2))], dim=1) # B, seq_len, dim_feat return self.transformer((features, torch.ones((features.size(0), features.size(1))).to(features.device))) class Net2DWith3DStem(Net2D): def __init__(self, *args, **kwargs): stem_out_channels = kwargs.pop("stem_out_channels", 24) load_pretrained_stem = kwargs.pop("load_pretrained_stem", None) conv_kernel_size = tuple(kwargs.pop("conv_kernel_size", (5, 3, 3))) conv_stride = tuple(kwargs.pop("conv_stride", (1, 2, 2))) in_channels = kwargs.pop("in_channels", 3) kwargs["in_channels"] = stem_out_channels super().__init__(*args, **kwargs) self.stem_layer = create_x3d_stem(in_channels=in_channels, out_channels=stem_out_channels, conv_kernel_size=conv_kernel_size, conv_stride=conv_stride) if kwargs["pretrained"]: from pytorchvideo.models.hub import x3d_l self.stem_layer.load_state_dict(x3d_l(pretrained=True).blocks[0].state_dict()) if load_pretrained_stem: import re print(f" Loading pretrained stem from {load_pretrained_stem} ...") weights = torch.load(load_pretrained_stem, map_location=lambda storage, loc: storage)['state_dict'] stem_weights = {k.replace("model.backbone.blocks.0.", ""): v for k, v in weights.items() if "backbone.blocks.0" in k} self.stem_layer.load_state_dict(stem_weights) def forward(self, x): x = self.stem_layer(x) x = x.mean(3) features = self.extract_features(x) if self.msdo: x = torch.mean(torch.stack([self.classifier(self.dropout(features)) for _ in range(5)]), dim=0) else: x = self.classifier(self.dropout(features)) # Important nuance: # For binary classification, the model returns a tensor of shape (N,) # Otherwise, (N,C) return x[:, 0] if self.classifier.out_features == 1 else x class Net3D(Net2D): def __init__(self, *args, **kwargs): z_strides = kwargs.pop("z_strides", [1,1,1,1,1]) super().__init__(*args, **kwargs) self.pool_layer = create_pool3d_layer(name=kwargs["pool"], **kwargs.pop("pool_layer_params", {})) class NetSegment2D(nn.Module): """ For now, this class essentially servers as a wrapper for the segmentation model which is mostly defined in the segmentation submodule, similar to the original segmentation_models.pytorch. It may be worth refactoring it in the future, such that you define this as a general class, then select your choice of encoder and decoder. The encoder is pretty much the same across all the segmentation models currently implemented (DeepLabV3+, FPN, Unet). """ def __init__(self, architecture, encoder_name, encoder_params, decoder_params, num_classes, dropout, in_channels, load_pretrained_encoder=None, freeze_encoder=False, deep_supervision=False, pool_layer_params={}, aux_head_params={}): super().__init__() self.segmentation_model = getattr(segmentation, architecture)( encoder_name=encoder_name, encoder_params=encoder_params, dropout=dropout, classes=num_classes, deep_supervision=deep_supervision, in_channels=in_channels, **decoder_params ) if load_pretrained_encoder: # Assumes that model has a `encoder` attribute # Note: if you want to load the entire pretrained model, this is done via the # builder.build_model function print(f"Loading pretrained encoder from {load_pretrained_encoder} ...") weights = torch.load(load_pretrained_encoder, map_location=lambda storage, loc: storage)['state_dict'] weights = {re.sub(r'^model.segmentation_model', '', k) : v for k,v in weights.items()} # Get encoder only weights = {re.sub(r'^encoder.', '', k) : v for k,v in weights.items() if 'backbone' in k} self.segmentation_model.encoder.load_state_dict(weights) if freeze_encoder: print("Freezing encoder ...") for param in self.segmentation_model.encoder.parameters(): param.requires_grad = False def forward(self, x): return self.segmentation_model(x) class NetSegment3D(NetSegment2D): pass