|
import torch |
|
import torch.nn as nn |
|
|
|
from src.core import register |
|
|
|
|
|
__all__ = ['Classification', 'ClassHead'] |
|
|
|
|
|
@register |
|
class Classification(nn.Module):
    """Chain a backbone with an optional head.

    When no head is supplied, ``forward`` returns whatever the backbone
    produced, untouched.
    """

    # Constructor arguments that the registry in ``src.core`` is expected to
    # populate; the exact injection semantics are defined there.
    __inject__ = ['backbone', 'head']

    def __init__(self, backbone: nn.Module, head: nn.Module = None):
        """Store the two stages.

        Args:
            backbone: Module applied first to the input.
            head: Optional module applied to the backbone output; ``None``
                disables it.
        """
        super().__init__()
        # Module-valued attributes become registered submodules, so their
        # parameters show up in .parameters() / state_dict().
        self.backbone, self.head = backbone, head

    def forward(self, x):
        """Return ``head(backbone(x))``, or ``backbone(x)`` when head is None."""
        features = self.backbone(x)
        if self.head is None:
            return features
        return self.head(features)
|
|
|
|
|
@register |
|
class ClassHead(nn.Module):
    """Global-average-pool a feature map and project it to class logits.

    Args:
        hidden_dim: Channel count of the selected feature map; this is the
            ``in_features`` of the final linear projection.
        num_classes: Number of logits produced per sample.
        feature_index: Which element to use when ``forward`` receives a
            list/tuple of feature maps. Defaults to ``0`` (the first element),
            which matches the previously hard-coded behavior; e.g. ``-1``
            selects the last element instead.
    """

    def __init__(self, hidden_dim, num_classes, feature_index: int = 0):
        super().__init__()
        self.feature_index = feature_index
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is a single tensor or a list/tuple of tensors; in the latter
        case only ``x[self.feature_index]`` is used. The selected tensor is
        expected to be batch-first ``(N, C, H, W)`` with ``C == hidden_dim``.
        """
        if isinstance(x, (list, tuple)):
            x = x[self.feature_index]
        x = self.pool(x)               # (N, C, H, W) -> (N, C, 1, 1)
        x = x.reshape(x.shape[0], -1)  # -> (N, C)
        return self.proj(x)
|
|