Spaces:
Sleeping
Sleeping
from typing import Union | |
import numpy as np | |
import torch | |
class DiscreteSupport(object):
    """
    Overview:
        Evenly spaced discrete support ``min, min + delta, ..., max`` (both ends inclusive), used for
        categorical representations of scalars.
    Arguments:
        - min (:obj:`int`): The first support point.
        - max (:obj:`int`): The last support point; must be strictly greater than ``min``.
        - delta (:obj:`float`): Spacing between neighbouring support points; must be positive.
    Attributes:
        - min, max, delta: The constructor arguments, stored unchanged.
        - range (:obj:`np.ndarray`): The support points. Integer dtype when ``min`` and ``delta`` are ints.
        - size / set_size (:obj:`int`): Number of support points (two names for the same value).
    """

    def __init__(self, min: int, max: int, delta: float = 1.) -> None:
        assert min < max
        assert delta > 0
        self.min = min
        self.max = max
        # Number of points from ``min`` to ``max`` inclusive. The tiny tolerance absorbs floating-point
        # error when (max - min) is an exact multiple of ``delta`` (e.g. 2 / 0.1).
        num_points = int(np.floor((max - min) / delta + 1e-9)) + 1
        # Built as ``min + delta * k`` instead of ``np.arange(min, max + 1, delta)``: the latter runs
        # past ``max`` whenever ``delta != 1``. For delta == 1 the values and dtype are unchanged.
        self.range = min + delta * np.arange(num_points)
        self.size = len(self.range)
        self.set_size = len(self.range)
        self.delta = delta
def scalar_transform(x: torch.Tensor, epsilon: float = 0.001, delta: float = 1.) -> torch.Tensor:
    """
    Overview:
        Squash raw values with the invertible h(.) function from https://arxiv.org/pdf/1805.11593.pdf:
        ``h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x``, with ``x`` first divided by ``delta``.
    Arguments:
        - x (:obj:`torch.Tensor`): Raw (unscaled) values.
        - epsilon (:obj:`float`): Coefficient of the small linear term that keeps h(.) invertible.
        - delta (:obj:`float`): Divisor applied to ``x`` before squashing.
    Returns:
        - output (:obj:`torch.Tensor`): The transformed values, same shape as ``x``.
    Reference:
        - MuZero: Appendix F: Network Architecture
        - https://arxiv.org/pdf/1805.11593.pdf (Page-11) Appendix A : Proposition A.2
    """
    if delta != 1:
        # General case: squash x / delta.
        magnitude = torch.sqrt(torch.abs(x / delta) + 1) - 1
        return torch.sign(x) * magnitude + epsilon * x / delta
    # Unit delta: skip the redundant divisions.
    magnitude = torch.sqrt(torch.abs(x) + 1) - 1
    return torch.sign(x) * magnitude + epsilon * x
def inverse_scalar_transform(
        logits: torch.Tensor,
        support_size: int,
        epsilon: float = 0.001,
        categorical_distribution: bool = True
) -> torch.Tensor:
    """
    Overview:
        Map a scaled value, or categorical logits over the scaled-value support, back to the original
        scale with the inverse function h^(-1)(.) from https://arxiv.org/pdf/1805.11593.pdf.
    Arguments:
        - logits (:obj:`torch.Tensor`): When ``categorical_distribution`` is True, logits over the integer
          support ``[-support_size, support_size]``; softmax is taken over dim 1 and the expectation is
          computed over that support (presumably shape (batch, 2 * support_size + 1) -- confirm with callers).
          Otherwise the tensor is used directly as the scaled value.
        - support_size (:obj:`int`): Half-width of the categorical support.
        - epsilon (:obj:`float`): Must match the epsilon used by the forward transform h(.).
        - categorical_distribution (:obj:`bool`): Whether ``logits`` is a categorical representation.
    Returns:
        - output (:obj:`torch.Tensor`): Values in the original scale.
    Reference:
        - MuZero Appendix F: Network Architecture.
        - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
    """
    if not categorical_distribution:
        value = logits
    else:
        support = DiscreteSupport(-support_size, support_size, delta=1)
        probs = torch.softmax(logits, dim=1)
        support_row = torch.from_numpy(support.range).unsqueeze(0).to(device=probs.device)
        # Expected scaled value under the predicted categorical distribution.
        value = (support_row * probs).sum(1, keepdim=True)

    # Closed-form h^(-1)(.) of the h(.) function.
    root = torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon))
    output = torch.sign(value) * (((root - 1) / (2 * epsilon)) ** 2 - 1)
    # Zeroing entries with |output| < epsilon is intentionally skipped to save time.
    return output
class InverseScalarTransform:
    """
    Overview:
        Callable that transforms the scaled value or its categorical representation back to the original
        value, i.e. the h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. The categorical
        support tensor is built once in ``__init__`` and reused on every call.
    Reference:
        - MuZero Appendix F: Network Architecture.
        - https://arxiv.org/pdf/1805.11593.pdf Appendix A: Proposition A.2
    """

    def __init__(
            self,
            support_size: int,
            device: Union[str, torch.device] = 'cpu',
            categorical_distribution: bool = True
    ) -> None:
        """
        Arguments:
            - support_size (:obj:`int`): The support covers the integers ``[-support_size, support_size]``.
            - device (:obj:`Union[str, torch.device]`): Device holding the support tensor; presumably it
              must match the device of the logits passed to ``__call__`` -- confirm with callers.
            - categorical_distribution (:obj:`bool`): Whether inputs are categorical logits (True) or
              already scalar scaled values (False).
        """
        scalar_support = DiscreteSupport(-support_size, support_size, delta=1)
        self.value_support = torch.from_numpy(scalar_support.range).unsqueeze(0)
        self.value_support = self.value_support.to(device)
        self.categorical_distribution = categorical_distribution

    def __call__(self, logits: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor:
        """
        Arguments:
            - logits (:obj:`torch.Tensor`): Categorical logits (softmax over dim 1) or scaled values.
            - epsilon (:obj:`float`): Must match the epsilon used by the forward transform h(.).
        Returns:
            - output (:obj:`torch.Tensor`): Values in the original scale.
        """
        if self.categorical_distribution:
            value_probs = torch.softmax(logits, dim=1)
            if value_probs.requires_grad:
                # Softmax's backward pass reuses its own output. Modifying it in place would make
                # ``backward()`` raise, so use the out-of-place product while autograd is recording.
                value_probs = value_probs * self.value_support
            else:
                # No graph is being recorded: the cheaper in-place product is safe.
                value_probs.mul_(self.value_support)
            value = value_probs.sum(1, keepdim=True)
        else:
            value = logits
        tmp = ((torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon))
        # t * t is faster than t ** 2
        output = torch.sign(value) * (tmp * tmp - 1)
        return output
def visit_count_temperature(
        manual_temperature_decay: bool, fixed_temperature_value: float,
        threshold_training_steps_for_final_lr_temperature: int, trained_steps: int
) -> float:
    """
    Overview:
        Pick the visit-count temperature for the current training step.
    Arguments:
        - manual_temperature_decay (:obj:`bool`): If False, ``fixed_temperature_value`` is always returned.
        - fixed_temperature_value (:obj:`float`): Temperature used when manual decay is disabled.
        - threshold_training_steps_for_final_lr_temperature (:obj:`int`): Step count the decay schedule
          is expressed against.
        - trained_steps (:obj:`int`): Number of training steps performed so far.
    Returns:
        - temperature (:obj:`float`): 1.0 before 50% of the threshold, 0.5 before 75% of it, 0.25 after
          that (manual decay), otherwise ``fixed_temperature_value``.
    """
    if not manual_temperature_decay:
        return fixed_temperature_value
    # (fraction of threshold, temperature used while below that fraction), checked in order.
    schedule = ((0.5, 1.0), (0.75, 0.5))
    for fraction, temperature in schedule:
        if trained_steps < fraction * threshold_training_steps_for_final_lr_temperature:
            return temperature
    return 0.25
def phi_transform(discrete_support: 'DiscreteSupport', x: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        We then apply a transformation ``phi`` to the scalar in order to obtain equivalent categorical representations.
        After this transformation, each scalar is represented as the linear combination of its two adjacent supports.
    Arguments:
        - discrete_support (:obj:`DiscreteSupport`): Provides ``min``, ``max``, ``set_size`` and ``delta``.
          Assumes ``min`` is a multiple of ``delta`` so that bin indices are integral -- confirm for new supports.
        - x (:obj:`torch.Tensor`): Scalars of any shape. Values are clamped to ``[min, max]``; the caller's
          tensor is not modified.
    Returns:
        - target (:obj:`torch.Tensor`): Shape ``(*x.shape, set_size)``, same dtype/device as ``x``. Each
          scalar's weight is split between its two neighbouring support points.
    Reference:
        - MuZero paper Appendix F: Network Architecture.
    """
    lower = discrete_support.min
    upper = discrete_support.max
    set_size = discrete_support.set_size
    delta = discrete_support.delta

    # Out-of-place clamp so the caller's tensor is left untouched.
    x = x.clamp(lower, upper)
    # Measure positions in units of ``delta`` so floor/ceil land on neighbouring support points
    # (x / 1 is exact, so the delta == 1 case is unchanged).
    scaled = x / delta
    x_low = scaled.floor()
    x_high = scaled.ceil()
    p_high = scaled - x_low
    p_low = 1 - p_high

    # Match x's dtype/device so the scatter sources below have the same dtype as the target.
    target = x.new_zeros((*x.shape, set_size))
    # Position of the first support point, in units of delta.
    offset = lower / delta
    x_high_idx = (x_high - offset).long().unsqueeze(-1)
    x_low_idx = (x_low - offset).long().unsqueeze(-1)
    # High first, then low: if x sits exactly on a support point both indices coincide and the
    # low write (p_low == 1) must be the one that survives.
    target.scatter_(-1, x_high_idx, p_high.unsqueeze(-1))
    target.scatter_(-1, x_low_idx, p_low.unsqueeze(-1))
    return target
def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Overview:
        Cross-entropy between ``prediction`` logits and a ``target`` distribution, summed over dim 1.
    Arguments:
        - prediction (:obj:`torch.Tensor`): Unnormalized logits; log-softmax is taken over dim 1.
        - target (:obj:`torch.Tensor`): Target weights, broadcastable against ``prediction``. Presumably
          already normalized probabilities; not checked here.
    Returns:
        - loss (:obj:`torch.Tensor`): Loss with dim 1 summed out (no reduction over the other dims).
    """
    log_probs = torch.log_softmax(prediction, dim=1)
    weighted = target * log_probs
    return -weighted.sum(1)