Spaces:
Sleeping
Sleeping
import torch | |
import treetensor.torch as ttorch | |
from torch.distributions import Normal, Independent | |
class ArgmaxSampler:
    '''
    Overview:
        Greedy (deterministic) sampler: picks the position of the largest score along the last dimension.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Reduce the last dimension of ``logit`` to the index of its maximum entry.
        Arguments:
            - logit (:obj:`torch.Tensor`): Unnormalized scores; the last dimension is the one reduced.
        Returns:
            - action (:obj:`torch.Tensor`): Integer (int64) indices, shaped like ``logit`` without its last dimension.
        '''
        return torch.argmax(logit, dim=-1)
class MultinomialSampler:
    '''
    Overview:
        Stochastic sampler: draws an index from the categorical distribution defined by the logits.
    '''

    def __call__(self, logit: torch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Build a categorical distribution over the last dimension of ``logit`` and draw one sample from it.
        Arguments:
            - logit (:obj:`torch.Tensor`): Unnormalized log-probabilities; the last dimension holds the categories.
        Returns:
            - action (:obj:`torch.Tensor`): Sampled category indices, shaped like ``logit`` without its last dimension.
        '''
        # Categorical normalizes the logits internally, so no explicit softmax is required here.
        return torch.distributions.Categorical(logits=logit).sample()
class MuSampler:
    '''
    Overview:
        Deterministic sampler that hands back the ``mu`` field of its input unchanged
        (presumably the mean of a Gaussian head, as consumed by the Normal-based samplers in this module).
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Read the ``mu`` entry from the structured input and use it directly as the action.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor that must expose a ``mu`` attribute.
        Returns:
            - action (:obj:`torch.Tensor`): The ``mu`` value taken from ``logit``.
        '''
        mean_action = logit.mu
        return mean_action
class ReparameterizationSampler:
    '''
    Overview:
        Differentiable stochastic sampler: draws from the Gaussian parameterized by ``mu`` and ``sigma``
        with the reparameterization trick, so the sample stays connected to the autograd graph.
    '''

    def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
        '''
        Overview:
            Draw a reparameterized sample from ``Normal(mu, sigma)``, treating the last dimension as the event.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor exposing ``mu`` (mean) and ``sigma`` (standard deviation).
        Returns:
            - action (:obj:`torch.Tensor`): The sampled value, with the same shape as the broadcast of ``mu`` and ``sigma``.
        '''
        # Independent(..., 1) moves the last batch axis into the event shape (a diagonal Gaussian);
        # rsample() draws mu + sigma * eps so gradients flow back to mu and sigma.
        return Independent(Normal(logit.mu, logit.sigma), 1).rsample()
class HybridStochasticSampler:
    '''
    Overview:
        Stochastic sampler for hybrid (discrete + continuous) actions: samples the discrete ``action_type``
        from a categorical distribution and draws the continuous ``action_args`` with a reparameterized Gaussian.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Sample both components of a hybrid action and pack them into one tree tensor.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor with ``action_type`` (categorical logits) and
              ``action_args`` (exposing ``mu`` and ``sigma``).
        Returns:
            - action (:obj:`ttorch.Tensor`): Tree tensor with keys ``action_type`` (sampled indices) and
              ``action_args`` (reparameterized Gaussian sample).
        '''
        # Sample the discrete part first, then the continuous part: the order fixes how the global RNG is consumed.
        type_logits = logit.action_type
        sampled_type = torch.distributions.Categorical(logits=type_logits).sample()

        args_params = logit.action_args
        gaussian = Independent(Normal(args_params.mu, args_params.sigma), 1)
        sampled_args = gaussian.rsample()

        packed = {'action_type': sampled_type, 'action_args': sampled_args}
        return ttorch.as_tensor(packed)
class HybridDeterminsticSampler:
    '''
    Overview:
        Deterministic sampler for hybrid (discrete + continuous) actions: takes the argmax of the
        ``action_type`` logits and the ``mu`` of ``action_args``.
    .. note::
        The class name keeps its historical misspelling ("Determinstic") for backward compatibility.
        The correctly spelled name ``HybridDeterministicSampler`` is provided as an alias of this class.
    '''

    def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
        '''
        Overview:
            Return the greedy action type and the mean action arguments, packed into one tree tensor.
        Arguments:
            - logit (:obj:`ttorch.Tensor`): Tree tensor with ``action_type`` (scores over the last dimension)
              and ``action_args`` (exposing ``mu``).
        Returns:
            - action (:obj:`ttorch.Tensor`): Tree tensor with keys ``action_type`` (argmax indices over the
              last dimension) and ``action_args`` (the ``mu`` value).
        '''
        action_type = logit.action_type.argmax(dim=-1)
        action_args = logit.action_args.mu
        return ttorch.as_tensor({
            'action_type': action_type,
            'action_args': action_args,
        })


# Correctly spelled alias; the misspelled original name is kept so existing imports keep working.
HybridDeterministicSampler = HybridDeterminsticSampler