|
import torch
|
|
from models.networks.utils import UnormGPS
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
|
|
|
|
class IdToGPS(nn.Module):
    """Look up GPS coordinates for integer ids (e.g. country or city ids).

    The lookup table is a tensor loaded from disk with ``torch.load``; the
    coordinates for id ``i`` are ``table[i]``.
    """

    def __init__(self, id_to_gps: str):
        """Load the id -> GPS lookup table.

        Args:
            id_to_gps: Path to a file written with ``torch.save`` that holds
                the lookup tensor. If the path contains ``"quadtree"``, the
                fourth-from-last ``"_"``-separated token of the path is
                removed before loading, e.g. ``.../quadtree_X_a_b_c.pt`` is
                loaded from ``.../quadtree_a_b_c.pt``.
        """
        super().__init__()
        if "quadtree" in id_to_gps:
            # Drop the 4th-from-last "_"-separated token; presumably a
            # setting the table itself does not depend on — TODO confirm.
            parts = id_to_gps.split("_")
            id_to_gps = "_".join(parts[:-4] + parts[-3:])
        # NOTE(review): presumably a (num_ids, 2) float tensor of
        # coordinates — confirm against the code that writes this file.
        self.id_to_gps = torch.load(id_to_gps)

    def forward(self, x):
        """Map ids to GPS coordinates.

        Args:
            x: Either a tensor of integer ids, or a batch dict whose
                ``"label"`` entry holds that tensor (other keys are ignored).

        Returns:
            dict: ``{"gps": self.id_to_gps[labels]}``, i.e. the table rows
            gathered for each id via tensor indexing.
        """
        labels = x["label"] if isinstance(x, dict) else x
        # Keep the table on the ids' device. ``Tensor.to`` returns the same
        # tensor when no move is needed, so after the first call this is cheap.
        self.id_to_gps = self.id_to_gps.to(labels.device)
        return {"gps": self.id_to_gps[labels]}
|
|
|