# Modified from https://github.com/robertdvdk/part_detection/blob/main/nets.py
import torch
from torch import Tensor
from timm.models import create_model
from torchvision.models import get_model
from layers.independent_mlp import IndependentMLPs


# Baseline model: a modified ResNet with reduced downsampling for a spatially larger
# feature tensor in the last layer
class IndividualLandmarkResNet(torch.nn.Module):
    def __init__(self, init_model: torch.nn.Module, num_landmarks: int = 8,
                 num_classes: int = 200, sl_channels: int = 1024, fl_channels: int = 2048,
                 use_torchvision_model: bool = False, part_dropout: float = 0.3,
                 modulation_type: str = "original", modulation_orth: bool = False, gumbel_softmax: bool = False,
                 gumbel_softmax_temperature: float = 1.0, gumbel_softmax_hard: bool = False,
                 classifier_type: str = "linear", noise_variance: float = 0.0) -> None:
        super().__init__()
        self.num_landmarks = num_landmarks
        self.num_classes = num_classes
        self.noise_variance = noise_variance

        self.conv1 = init_model.conv1
        self.bn1 = init_model.bn1
        if use_torchvision_model:
            self.act1 = init_model.relu
        else:
            self.act1 = init_model.act1
        self.maxpool = init_model.maxpool
        self.layer1 = init_model.layer1
        self.layer2 = init_model.layer2
        self.layer3 = init_model.layer3
        self.layer4 = init_model.layer4

        self.feature_dim = sl_channels + fl_channels
        self.fc_landmarks = torch.nn.Conv2d(self.feature_dim, num_landmarks + 1, 1, bias=False)
        self.gumbel_softmax = gumbel_softmax
        self.gumbel_softmax_temperature = gumbel_softmax_temperature
        self.gumbel_softmax_hard = gumbel_softmax_hard

        self.modulation_type = modulation_type
        if modulation_type == "layer_norm":
            self.modulation = torch.nn.LayerNorm([self.feature_dim, self.num_landmarks + 1])
        elif modulation_type == "original":
            self.modulation = torch.nn.Parameter(torch.ones(1, self.feature_dim, self.num_landmarks + 1))
        elif modulation_type == "parallel_mlp":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=True)
        elif modulation_type == "parallel_mlp_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=True, bias=False)
        elif modulation_type == "parallel_mlp_no_act":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=True)
        elif modulation_type == "parallel_mlp_no_act_no_bias":
            self.modulation = IndependentMLPs(part_dim=self.num_landmarks + 1, latent_dim=self.feature_dim,
                                              num_lin_layers=1, act_layer=False, bias=False)
        elif modulation_type == "none":
            self.modulation = torch.nn.Identity()
        else:
            raise ValueError("modulation_type not implemented")
        self.modulation_orth = modulation_orth

        self.dropout_full_landmarks = torch.nn.Dropout1d(part_dropout)

        self.classifier_type = classifier_type
        if classifier_type == "independent_mlp":
            self.fc_class_landmarks = IndependentMLPs(part_dim=self.num_landmarks, latent_dim=self.feature_dim,
                                                      num_lin_layers=1, act_layer=False, out_dim=num_classes,
                                                      bias=False, stack_dim=1)
        elif classifier_type == "linear":
            self.fc_class_landmarks = torch.nn.Linear(in_features=self.feature_dim, out_features=num_classes,
                                                      bias=False)
        else:
            raise ValueError("classifier_type not implemented")

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        # Pretrained ResNet part of the model
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act1(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        l3 = self.layer3(x)
        x = self.layer4(l3)
        x = torch.nn.functional.interpolate(x, size=(l3.shape[-2], l3.shape[-1]), mode='bilinear',
                                            align_corners=False)
        x = torch.cat((x, l3), dim=1)

        # Compute per-landmark attention maps from squared distances to the fc_landmarks kernels:
        # (b - a)^2 = b^2 - 2ab + a^2, with b = backbone feature vectors and a = 1x1 convolution kernels
        # (an illustrative standalone check of this identity, _check_squared_distance_identity, is sketched below the class)
        batch_size = x.shape[0]
        ab = self.fc_landmarks(x)
        b_sq = x.pow(2).sum(1, keepdim=True)
        b_sq = b_sq.expand(-1, self.num_landmarks + 1, -1, -1).contiguous()
        a_sq = self.fc_landmarks.weight.pow(2).sum(1).unsqueeze(1).expand(-1, batch_size, x.shape[-2],
                                                                          x.shape[-1]).contiguous()
        a_sq = a_sq.permute(1, 0, 2, 3).contiguous()
        dist = b_sq - 2 * ab + a_sq
        maps = -dist

        # Softmax so that the attention maps for each pixel add up to 1
        if self.gumbel_softmax:
            maps = torch.nn.functional.gumbel_softmax(maps, dim=1, tau=self.gumbel_softmax_temperature,
                                                      hard=self.gumbel_softmax_hard)  # [B, num_landmarks + 1, H, W]
        else:
            maps = torch.nn.functional.softmax(maps, dim=1)  # [B, num_landmarks + 1, H, W]

        # Use maps to get weighted average features per landmark
        all_features = (maps.unsqueeze(1) * x.unsqueeze(2)).mean(-1).mean(-1).contiguous()
        if self.noise_variance > 0.0:
            all_features += torch.randn_like(all_features) * x.std().detach() * self.noise_variance

        # Modulate the features
        if self.modulation_type == "original":
            all_features_mod = all_features * self.modulation
        else:
            all_features_mod = self.modulation(all_features)

        # Classification based on the landmark features
        scores = self.fc_class_landmarks(
            self.dropout_full_landmarks(all_features_mod[..., :-1].permute(0, 2, 1).contiguous())
        ).permute(0, 2, 1).contiguous()
        if self.modulation_orth:
            return all_features_mod, maps, scores, dist
        else:
            return all_features, maps, scores, dist
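

# Illustrative sanity-check sketch (not part of the original model code): it verifies that the
# squared-distance trick used in IndividualLandmarkResNet.forward, i.e.
# ||b - a||^2 = b^2 - 2ab + a^2 with the 1x1 convolution providing the ab term,
# matches an explicit computation of the per-pixel distances. The helper name, channel count,
# and spatial size below are arbitrary example values, not anything defined by the original repo.
def _check_squared_distance_identity(channels: int = 16, num_landmarks: int = 8) -> bool:
    conv = torch.nn.Conv2d(channels, num_landmarks + 1, 1, bias=False)
    x = torch.randn(2, channels, 7, 7)
    ab = conv(x)
    b_sq = x.pow(2).sum(1, keepdim=True).expand(-1, num_landmarks + 1, -1, -1)
    a_sq = conv.weight.pow(2).sum(1).unsqueeze(1).expand(-1, x.shape[0], x.shape[-2],
                                                         x.shape[-1]).permute(1, 0, 2, 3)
    dist = b_sq - 2 * ab + a_sq
    # Explicit squared distance of every feature vector to every kernel, for comparison
    w = conv.weight.squeeze(-1).squeeze(-1)        # [num_landmarks + 1, channels]
    feats = x.permute(0, 2, 3, 1)                  # [B, H, W, channels]
    explicit = ((feats.unsqueeze(-2) - w) ** 2).sum(-1).permute(0, 3, 1, 2)
    return torch.allclose(dist, explicit, atol=1e-5)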


def pdisconet_resnet_torchvision_bb(backbone, num_cls=200, k=8, **kwargs):
    base_model = get_model(backbone)
    # use_torchvision_model=True so the activation is taken from init_model.relu
    # (torchvision ResNets have no act1 attribute)
    return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls,
                                    use_torchvision_model=True, modulation_type="original")


def pdisconet_resnet_timm_bb(backbone, num_cls=200, k=8, output_stride=32, **kwargs):
    base_model = create_model(backbone, pretrained=True, output_stride=output_stride)
    return IndividualLandmarkResNet(base_model, num_landmarks=k, num_classes=num_cls,
                                    modulation_type="original")
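

# Minimal usage sketch (illustrative, not part of the original file). It assumes timm is
# installed, the "resnet50" ImageNet weights can be downloaded, and 3x224x224 inputs;
# with those assumptions layer3/layer4 give 1024/2048 channels, matching the defaults above.
if __name__ == "__main__":
    model = pdisconet_resnet_timm_bb("resnet50", num_cls=200, k=8)
    model.eval()
    dummy = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        all_features, maps, scores, dist = model(dummy)
    # all_features: [B, feature_dim, num_landmarks + 1] per-part feature vectors
    # maps:         [B, num_landmarks + 1, H, W] soft part-assignment maps
    # scores:       [B, num_classes, num_landmarks] per-part classification logits
    # dist:         [B, num_landmarks + 1, H, W] squared distances to the fc_landmarks kernels
    print(all_features.shape, maps.shape, scores.shape, dist.shape)
    print("distance identity holds:", _check_squared_distance_identity())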