File size: 1,678 Bytes
71026d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import pandas as pd
import torch
from torch import nn
from models.networks.utils import UnormGPS
class Random(nn.Module):
    """Baseline that ignores its input values and emits uniformly random outputs."""

    def __init__(self, num_output):
        """Set up the random baseline.

        Args:
            num_output: number of values generated per sample (size of the
                last dimension of the sampled tensor).
        """
        super().__init__()
        self.num_output = num_output
        # Project helper applied to the sampled values before they are returned.
        # NOTE(review): presumably maps normalized values back to GPS units --
        # confirm against models.networks.utils.
        self.unorm = UnormGPS()

    def forward(self, x):
        """Draw one random prediction per sample in the batch.

        Args:
            x: either a tensor whose first dimension is the batch size, or a
                dict holding such a tensor under the "img" key. Only the batch
                size and the device are read; the values are never used.

        Returns:
            dict with a single key "gps": the result of ``self.unorm`` applied
            to a ``(batch, num_output)`` tensor sampled uniformly from [-1, 1).
        """
        # Accept the batch-dict form as well as a bare tensor so this baseline
        # can be called the same way as models that unpack x["img"] themselves.
        if isinstance(x, dict):
            x = x["img"]
        batch_size = x.shape[0]
        # torch.rand samples from [0, 1); rescale to [-1, 1).
        gps = torch.rand((batch_size, self.num_output), device=x.device) * 2 - 1
        return {"gps": self.unorm(gps)}
class RandomCoords(nn.Module):
    """Baseline that predicts coordinates drawn at random from a fixed list."""

    def __init__(self, coords_path: str):
        """Load the pool of candidate coordinates.

        Args:
            coords_path: path to a CSV file with "longitude" and "latitude"
                columns.

        Raises:
            ValueError: if the CSV file contains no rows, since there would be
                nothing to sample from in ``forward``.
        """
        super().__init__()
        coordinates = pd.read_csv(coords_path)
        # Scale each column by its divisor (180 for longitude, 90 for latitude).
        # NOTE(review): presumably the CSV is in degrees, so this normalizes
        # both columns to [-1, 1] -- confirm against the data files.
        longitudes = coordinates["longitude"].values / 180
        latitudes = coordinates["latitude"].values / 90
        # Drop the DataFrame early; only the two scaled columns are needed.
        del coordinates

        # Project helper applied to the sampled rows before they are returned.
        # NOTE(review): presumably undoes the scaling above -- confirm against
        # models.networks.utils.
        self.unorm = UnormGPS()

        self.N = len(longitudes)
        if self.N == 0:
            raise ValueError(f"No coordinates found in {coords_path!r}")

        # Shape (N, 2); the last dimension is ordered (latitude, longitude).
        # Kept as a plain attribute (not a registered buffer); forward() moves
        # only the sampled rows to the input's device.
        self.coordinates = torch.stack(
            [torch.tensor(latitudes), torch.tensor(longitudes)],
            dim=-1,
        )

    def forward(self, x):
        """Pick one stored coordinate at random for every sample in the batch.

        Args:
            x: either a dict holding a tensor under the "img" key, or that
                tensor itself. Only the batch size (first dimension) and the
                device are read; the values are never used.

        Returns:
            dict with a single key "gps": the result of ``self.unorm`` applied
            to a ``(batch, 2)`` tensor of rows sampled, with replacement, from
            ``self.coordinates`` and moved to the device of the input.
        """
        # Accept a bare tensor as well as the batch-dict form.
        if not isinstance(x, torch.Tensor):
            x = x["img"]
        batch_size = x.shape[0]
        # Independent uniform row indices in [0, N).
        n = torch.randint(0, self.N, (batch_size,))
        return {"gps": self.unorm(self.coordinates[n].to(x.device))}
|