from typing import Optional

import torch
from torch import nn
from torch.distributions.transforms import TanhTransform


class NonegativeParameter(nn.Module):
    """
    Overview:
        This module outputs a non-negative parameter during the forward process. The value is stored \
        in log space, so the output of ``forward`` is always non-negative and remains differentiable.
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True, delta: float = 1e-8):
        """
        Overview:
            Initialize the NonegativeParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of the generated parameter. If set to \
                ``None``, the default value is 0.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
            - delta (:obj:`float`): The small offset added to ``data`` before taking the logarithm, for \
                numerical stability.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        self.delta = delta
        self.log_data = nn.Parameter(torch.log(data + delta), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the non-negative parameter during the forward process.
        Returns:
            parameter (:obj:`torch.Tensor`): The generated parameter.
        """
        return torch.exp(self.log_data)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the non-negative parameter.
        Arguments:
            - data (:obj:`torch.Tensor`): The new value of the non-negative parameter.
        """
        self.log_data = nn.Parameter(torch.log(data + self.delta), requires_grad=self.log_data.requires_grad)
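

# The function below is an illustrative usage sketch, not part of the original module: the target value,
# learning rate, and step count are arbitrary assumptions. It shows that a NonegativeParameter can be
# optimized like any other nn.Module parameter while its forward output stays non-negative, because the
# value lives in log space and is recovered with torch.exp.
def _demo_nonnegative_parameter() -> torch.Tensor:
    """
    Overview:
        Fit a NonegativeParameter towards a positive target with plain gradient descent.
    """
    sigma = NonegativeParameter(torch.tensor([0.5]))
    optimizer = torch.optim.SGD(sigma.parameters(), lr=0.1)
    for _ in range(200):
        optimizer.zero_grad()
        loss = (sigma() - 2.0).pow(2).sum()
        loss.backward()
        optimizer.step()
    # The constraint holds throughout optimization: exp(log_data) is never negative.
    assert torch.all(sigma() >= 0)
    return sigma()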


class TanhParameter(nn.Module):
    """
    Overview:
        This module outputs a parameter constrained to the interval (-1, 1) during the forward process, \
        obtained by applying a tanh transform to an unconstrained underlying parameter.
    Interfaces:
        ``__init__``, ``forward``, ``set_data``.
    """

    def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True):
        """
        Overview:
            Initialize the TanhParameter object using the given arguments.
        Arguments:
            - data (:obj:`Optional[torch.Tensor]`): The initial value of the generated parameter. If set to \
                ``None``, the default value is 0.
            - requires_grad (:obj:`bool`): Whether this parameter requires grad.
        """
        super().__init__()
        if data is None:
            data = torch.zeros(1)
        self.transform = TanhTransform(cache_size=1)
        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=requires_grad)

    def forward(self) -> torch.Tensor:
        """
        Overview:
            Output the tanh parameter during the forward process.
        Returns:
            parameter (:obj:`torch.Tensor`): The generated parameter.
        """
        return self.transform(self.data_inv)

    def set_data(self, data: torch.Tensor) -> None:
        """
        Overview:
            Set the value of the tanh parameter.
        Arguments:
            - data (:obj:`torch.Tensor`): The new value of the tanh parameter.
        """
        self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=self.data_inv.requires_grad)
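

# The function below is an illustrative usage sketch, not part of the original module: the initial and
# replacement values are arbitrary assumptions. It shows that the output of a TanhParameter always lies
# in (-1, 1), that gradients flow to the underlying unconstrained parameter, and that ``set_data``
# rebuilds that parameter from the inverse tanh of the new value.
def _demo_tanh_parameter() -> torch.Tensor:
    """
    Overview:
        Run a single forward/backward pass through a TanhParameter and then reset its value.
    """
    rho = TanhParameter(torch.tensor([0.5]))
    value = rho()
    assert torch.all(value.abs() < 1)
    # Gradients reach the unconstrained parameter through the tanh transform.
    value.sum().backward()
    assert rho.data_inv.grad is not None
    rho.set_data(torch.tensor([-0.9]))
    return rho()


if __name__ == "__main__":
    # Minimal smoke test for the illustrative demos above.
    print("NonegativeParameter demo:", _demo_nonnegative_parameter())
    print("TanhParameter demo:", _demo_tanh_parameter())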