|
import torch.nn as nn
|
|
from models.networks.utils import UnormGPS
|
|
from torch.nn.functional import tanh, sigmoid, softmax
|
|
|
|
|
|
class AuxHead(nn.Module):
    """Turn raw backbone outputs into a GPS prediction plus optional auxiliary outputs.

    Args:
        aux_data: Names of the auxiliary outputs to include in the result.
            Recognised names are ``"land_cover"``, ``"road_index"``,
            ``"drive_side"``, ``"climate"``, ``"soil"`` and ``"dist_sea"``;
            any other name is silently ignored. Defaults to no auxiliary
            outputs.
        use_tanh: If True, apply ``tanh`` to the raw ``"gps"`` output before
            it is passed to ``UnormGPS``.
    """

    # (key, post-processing) pairs for the optional auxiliary outputs, listed
    # in the order they are inserted into the result dict.
    # NOTE(review): softmax is taken over the last dimension. This assumes class
    # scores live there, e.g. (batch, num_classes), which matches PyTorch's old
    # implicit choice for 1-D/2-D inputs. TODO confirm for higher-rank outputs.
    _AUX_POSTPROCESS = (
        ("land_cover", lambda t: softmax(t, dim=-1)),
        ("road_index", lambda t: t),  # passed through unchanged
        ("drive_side", sigmoid),
        ("climate", lambda t: softmax(t, dim=-1)),
        ("soil", lambda t: softmax(t, dim=-1)),
        ("dist_sea", lambda t: t),  # passed through unchanged
    )

    def __init__(self, aux_data=None, use_tanh=False):
        super().__init__()
        # Build a fresh list per instance: a shared mutable default would leak
        # changes between every head constructed without ``aux_data``.
        self.aux_data = [] if aux_data is None else aux_data
        # UnormGPS is defined elsewhere in the project; presumably it maps
        # normalised coordinates back to lat/lon (TODO confirm).
        self.unorm = UnormGPS()
        self.use_tanh = use_tanh

    def forward(self, x):
        """Forward pass of the network.

        Args:
            x: Mapping from output name to the backbone's raw tensor. It must
                contain ``"gps"`` and every recognised name listed in
                ``self.aux_data``.

        Returns:
            dict: ``"gps"`` (optionally tanh-squashed, then passed through
            ``self.unorm``), followed by each requested auxiliary output after
            its post-processing (softmax, sigmoid, or pass-through).

        Raises:
            KeyError: if a required key is missing from ``x``.
        """
        gps = x["gps"]
        # Read gps unconditionally; tanh is an optional squashing step.
        # Previously ``gps`` was unbound here whenever use_tanh was False.
        if self.use_tanh:
            gps = tanh(gps)
        output = {"gps": self.unorm(gps)}

        for key, post in self._AUX_POSTPROCESS:
            if key in self.aux_data:
                output[key] = post(x[key])

        return output
|
|
|