|
import torch
|
|
import numpy as np
|
|
from torch import nn
|
|
|
|
|
|
class NormGPS(nn.Module):
    """Map (latitude, longitude) radians onto the [-1, 1] range.

    The trailing dimension of the input holds the pair (latitude, longitude):
    latitude is divided by pi/2 and longitude by pi, so the ranges
    [-pi/2, pi/2] and [-pi, pi] both become [-1, 1]. Values outside those
    ranges are scaled the same way; no clipping is applied.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Normalize latitude/longitude radians to [-1, 1].

        Args:
            x: Tensor whose last dimension holds (latitude, longitude) in
                radians (it must broadcast against a ``(1, 2)`` divisor).

        Returns:
            ``x`` divided element-wise by ``[[pi / 2, pi]]``. Because the
            divisor has shape ``(1, 2)``, a 1-D input of shape ``(2,)`` comes
            back with shape ``(1, 2)``. Floating-point inputs keep their dtype;
            integer inputs are promoted to the default float dtype.
        """
        # Build the divisor in x's own floating dtype (and directly on x's
        # device) instead of as a float32 constant: a float32 pi loses
        # precision for float64 inputs (normalizing exactly pi gave
        # 0.99999997 rather than 1.0) and upcast float16/bfloat16 inputs.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        scale = torch.tensor([np.pi * 0.5, np.pi], dtype=dtype, device=x.device)
        return x / scale.unsqueeze(0)
|
|
|
|
|
|
class UnormGPS(nn.Module):
    """Map values in [-1, 1] back to (latitude, longitude) radians.

    Inverse of ``NormGPS`` up to clamping: the input is first clamped to
    [-1, 1], then latitude (trailing index 0) is multiplied by pi/2 and
    longitude (trailing index 1) by pi, so outputs lie in
    [-pi/2, pi/2] x [-pi, pi].
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Unnormalize [-1, 1] values to latitude/longitude radians.

        Args:
            x: Tensor whose last dimension holds normalized
                (latitude, longitude) values (it must broadcast against a
                ``(1, 2)`` multiplier).

        Returns:
            ``clamp(x, -1, 1)`` multiplied element-wise by ``[[pi / 2, pi]]``.
            Because the multiplier has shape ``(1, 2)``, a 1-D input of shape
            ``(2,)`` comes back with shape ``(1, 2)``. Floating-point inputs
            keep their dtype; integer inputs are promoted to the default
            float dtype.
        """
        # Clamp first so out-of-range inputs still land inside the valid
        # latitude/longitude ranges.
        x = torch.clamp(x, -1, 1)
        # Build the multiplier in x's own floating dtype (and directly on x's
        # device) instead of as a float32 constant: a float32 pi loses
        # precision for float64 inputs (1.0 mapped to 3.14159274, not pi)
        # and upcast float16/bfloat16 inputs.
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        scale = torch.tensor([np.pi * 0.5, np.pi], dtype=dtype, device=x.device)
        return x * scale.unsqueeze(0)
|
|
|