Spaces:
Sleeping
Sleeping
# NOTE that this is not my code. | |
# Taken from here: https://github.com/vitchyr/rlkit/blob/master/rlkit/torch/distributions.py | |
import torch | |
from torch.distributions import Distribution, Normal | |
class TanhNormal(Distribution): | |
""" | |
Represent distribution of X where | |
X ~ tanh(Z) | |
Z ~ N(mean, std) | |
Note: this is not very numerically stable. | |
""" | |
def __init__(self, normal_mean, normal_std, epsilon=1e-6): | |
""" | |
:param normal_mean: Mean of the normal distribution | |
:param normal_std: Std of the normal distribution | |
:param epsilon: Numerical stability epsilon when computing log-prob. | |
""" | |
self.normal_mean = normal_mean | |
self.normal_std = normal_std | |
self.normal = Normal(normal_mean, normal_std) | |
self.epsilon = epsilon | |
def sample_n(self, n, return_pre_tanh_value=False): | |
z = self.normal.sample_n(n) | |
if return_pre_tanh_value: | |
return torch.tanh(z), z | |
else: | |
return torch.tanh(z) | |
def log_prob(self, value, pre_tanh_value=None): | |
""" | |
:param value: some value, x | |
:param pre_tanh_value: arctanh(x) | |
:return: | |
""" | |
if pre_tanh_value is None: | |
pre_tanh_value = torch.log( | |
(1+value) / (1-value) | |
) / 2 | |
return self.normal.log_prob(pre_tanh_value) - torch.log( | |
1 - value * value + self.epsilon | |
) | |
def sample(self, return_pretanh_value=False): | |
""" | |
Gradients will and should *not* pass through this operation. | |
See https://github.com/pytorch/pytorch/issues/4620 for discussion. | |
""" | |
z = self.normal.sample().detach() | |
if return_pretanh_value: | |
return torch.tanh(z), z | |
else: | |
return torch.tanh(z) | |
def rsample(self, return_pretanh_value=False): | |
""" | |
Sampling in the reparameterization case. | |
""" | |
z = ( | |
self.normal_mean + | |
self.normal_std * | |
Normal( | |
torch.zeros(self.normal_mean.size()), | |
torch.ones(self.normal_std.size()) | |
).sample() | |
) | |
z.requires_grad_() | |
if return_pretanh_value: | |
return torch.tanh(z), z | |
else: | |
return torch.tanh(z) |