File size: 3,719 Bytes
db6ee6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#  -------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  -------------------------------------------------------------------------------------------

from typing import Callable, Optional

import torch
import torch.nn as nn


class MLP(nn.Module):
    """
    Fully connected layers to map between image embeddings and projection space where pairs of images are compared.

    If ``hidden_dim`` is given the model is ``projection -> batch norm -> ReLU -> projection``. If ``hidden_dim``
    is ``None`` the model is a single projection from ``input_dim`` to ``output_dim``. In both cases the projection
    type follows ``use_1x1_convs``: ``nn.Linear`` (inputs with features in the last dimension) or ``nn.Conv2d``
    with ``kernel_size=1`` (inputs of shape [batch, channels, height, width]).

    :param input_dim: Input embedding feature size
    :param output_dim: Output projection size
    :param hidden_dim: Hidden layer size in MLP. If None, a single projection layer without hidden layer is used.
    :param use_1x1_convs: Use 1x1 conv kernels instead of 2D linear transformations for speed and memory efficiency.
    """

    def __init__(self,
                 input_dim: int,
                 output_dim: int,
                 hidden_dim: Optional[int] = None,
                 use_1x1_convs: bool = False) -> None:
        super().__init__()

        self.output_dim = output_dim
        self.input_dim = input_dim

        projection_layer: Callable
        normalisation_layer: Callable
        if use_1x1_convs:
            projection_layer, normalisation_layer = nn.Conv2d, nn.BatchNorm2d
            in_key, out_key = 'in_channels', 'out_channels'
            extra_args = {'kernel_size': 1}
        else:
            projection_layer, normalisation_layer = nn.Linear, nn.BatchNorm1d
            in_key, out_key = 'in_features', 'out_features'
            extra_args = {}

        def make_projection(in_dim: int, out_dim: int, bias: bool) -> nn.Module:
            """Build one projection of the selected type (linear or 1x1 conv)."""
            return projection_layer(**{in_key: in_dim, out_key: out_dim, 'bias': bias}, **extra_args)

        self.model: nn.Module
        if hidden_dim is not None:
            self.model = nn.Sequential(
                # No bias on the first projection: the batch norm that follows has its own learnable shift.
                make_projection(input_dim, hidden_dim, bias=False),
                normalisation_layer(hidden_dim),
                nn.ReLU(inplace=True),
                make_projection(hidden_dim, output_dim, bias=True))
        else:
            # Single-layer head. Honour `use_1x1_convs` here as well, otherwise a plain nn.Linear would be
            # applied to the last (width) dimension of a [batch, channels, height, width] feature map.
            self.model = make_projection(input_dim, output_dim, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the multi-layer perceptron.

        :param x: Input features; the feature dimension is last for the linear variant, and is the channel
            dimension (dim 1) for the 1x1 conv variant.
        :return: Projected features with ``output_dim`` in place of ``input_dim``.
        """
        return self.model(x)


class MultiTaskModel(nn.Module):
    """Torch module for multi-task classification heads. We create a separate classification head
    for each task and perform a forward pass on each head independently in forward(). Classification
    heads are instances of `MLP` and are registered as submodules named ``fc_0``, ``fc_1``, ...

    :param input_dim: Number of dimensions of the input feature map.
    :param classifier_hidden_dim: Number of dimensions of hidden features in the MLP. If None, each head
        is built without a hidden layer.
    :param num_classes: Number of output classes per task.
    :param num_tasks: Number of classification tasks or heads required.
    """

    def __init__(self, input_dim: int, classifier_hidden_dim: Optional[int], num_classes: int, num_tasks: int):

        super().__init__()

        self.num_classes = num_classes
        self.num_tasks = num_tasks

        # Heads are attached as individual attributes so that their names (`fc_<task>`) appear in the state dict.
        for task in range(num_tasks):
            setattr(self, f"fc_{task}", MLP(input_dim, output_dim=num_classes, hidden_dim=classifier_hidden_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run every task head on the same input and collect the logits.

        :param x: Input features; the first dimension is the batch dimension.
        :return: Tensor of logits of shape [batch_size, num_classes, num_tasks], with the dtype and device of ``x``.
            Note that the task dimension is last.
        """
        batch_size = x.shape[0]
        out = torch.zeros((batch_size, self.num_classes, self.num_tasks), dtype=x.dtype, device=x.device)
        for task in range(self.num_tasks):
            classifier = getattr(self, f"fc_{task}")
            # Each head returns [batch_size, num_classes], written into its own slice of the last dimension.
            out[:, :, task] = classifier(x)
        return out