File size: 468 Bytes
71026d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """Head that pairs the incoming scores with a lookup on the top index.

    The module holds no parameters of its own; it stores a callable
    (``id_to_gps``) and applies it to the argmax of its input.
    """

    def __init__(self, id_to_gps):
        """Store the lookup callable.

        Args:
            id_to_gps: Callable invoked with ``x.argmax(dim=-1)``. Its return
                value is unpacked with ``**`` in ``forward``, so it must be a
                mapping. Presumably it maps class ids to GPS coordinates --
                confirm against the caller.
        """
        super().__init__()
        self.id_to_gps = id_to_gps

    def forward(self, x):
        """Run the head on the backbone output.

        Args:
            x: Backbone output. It must support ``.argmax(dim=-1)``.
                NOTE(review): an earlier docstring listed ``dict`` as an
                accepted type, but a plain dict has no ``argmax`` -- confirm
                whether dict inputs are ever passed here.

        Returns:
            dict: ``"label"`` mapped to ``x`` itself (unchanged, not the
            argmax), followed by every entry returned by ``id_to_gps``. An
            entry named ``"label"`` in that mapping replaces ``x``.
        """
        # Index of the largest value along the last dimension.
        top_ids = x.argmax(dim=-1)
        looked_up = self.id_to_gps(top_ids)
        return {"label": x, **looked_up}