# Scraped from a Hugging Face Space file view.
# Space status at scrape time: Runtime error
# File size: 1,162 bytes; commit c5ca37a
import torch
def safe_log(z, eps=1e-7):
    """Return ``log(z + eps)`` elementwise.

    Adding a small positive ``eps`` keeps the result finite where ``z`` is
    exactly zero. Entries of ``z`` below ``-eps`` still produce NaN.

    Args:
        z: input tensor (presumably non-negative, e.g. probabilities --
            confirm against callers).
        eps: offset added before the log. Defaults to 1e-7, the value this
            helper has always used.

    Returns:
        Tensor with the same shape as ``z``.
    """
    return torch.log(z + eps)
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of the operation
    ``value.exp().sum(dim, keepdim).log()``.

    The maximum along ``dim`` is subtracted before exponentiating so large
    inputs do not overflow. Infinite maxima are replaced by 0 for the shift:
    otherwise ``inf - inf`` would give NaN. With this guard an all-``-inf``
    slice yields ``-inf`` and a slice containing ``+inf`` yields ``+inf``.

    Args:
        value: input tensor.
        dim: int dimension to reduce over. If None, every element is reduced
            and a 0-dim tensor is returned (``keepdim`` is ignored then).
        keepdim: whether the reduced dimension is retained. Any truthy or
            falsy value is accepted.

    Returns:
        Tensor holding the log-sum-exp over ``dim``.
    """
    if dim is None:
        # Reduce over all elements by flattening to a single axis.
        value = value.reshape(-1)
        dim, keepdim = 0, False

    m, _ = torch.max(value, dim=dim, keepdim=True)
    # Shift by the max only where it is finite; an infinite shift would turn
    # value - m into NaN (inf - inf).
    shift = torch.where(torch.isfinite(m), m, torch.zeros_like(m))
    out = shift + torch.log(torch.sum(torch.exp(value - shift), dim=dim, keepdim=True))
    # Truthiness (not `is False`) so that keepdim=0 behaves like keepdim=False.
    return out if keepdim else out.squeeze(dim)
def generate_grid(zmin, zmax, dz, device, ndim=2):
    """Generate a regular 1- or 2-dimensional grid of points.

    The axis is ``torch.arange(zmin, zmax, dz)``: ``zmin`` is included,
    ``zmax`` is excluded, and the number of points is
    ``k = ceil((zmax - zmin) / dz)``.

    Args:
        zmin: start of the axis (inclusive).
        zmax: end of the axis (exclusive).
        dz: spacing between consecutive axis points.
        device: device on which the returned tensor is created.
        ndim: 2 for a 2-D grid, 1 for a 1-D grid.

    Returns:
        ndim == 2: tuple ``(grid, k)``. ``grid`` has shape (k*k, 2) and lists
            every pair (x_i, x_j) with the first coordinate varying slowest.
        ndim == 1: a tensor of shape (k, 1) only; ``k`` is NOT returned.

    Raises:
        ValueError: if ``ndim`` is neither 1 nor 2.
    """
    if ndim not in (1, 2):
        raise ValueError(f"ndim must be 1 or 2, got {ndim!r}")

    # Build the axis directly on the target device instead of copying from CPU.
    x = torch.arange(zmin, zmax, dz, device=device)
    if ndim == 1:
        return x.unsqueeze(1)

    k = x.size(0)
    x1 = x.unsqueeze(1).repeat(1, k).view(-1)  # each value repeated k times
    x2 = x.repeat(k)                           # whole axis tiled k times
    return torch.stack((x1, x2), dim=-1), k