Spaces:
Sleeping
Sleeping
"""Credit: Note the following vae model is modified from https://github.com/AntixK/PyTorch-VAE""" | |
import torch | |
from torch.nn import functional as F | |
from torch import nn | |
from abc import abstractmethod | |
from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple, Optional | |
from ding.utils.type_helper import Tensor | |
class VanillaVAE(nn.Module):
    """
    Overview:
        Implementation of a vanilla variational autoencoder for action reconstruction. The encoder fuses an \
        action embedding and an observation embedding by element-wise product and maps the result to the mean \
        and log-variance of a diagonal Gaussian over latent actions. The decoder fuses a latent action with the \
        observation embedding in the same way and outputs a reconstructed action (squashed by ``tanh``) together \
        with a predicted observation residual.
    Interfaces:
        ``__init__``, ``encode``, ``decode``, ``decode_with_obs``, ``reparameterize``, \
        ``forward``, ``loss_function`` .
    """

    def __init__(
            self,
            action_shape: int,
            obs_shape: int,
            latent_size: int,
            hidden_dims: Optional[List[int]] = None,
            **kwargs
    ) -> None:
        """
        Overview:
            Build the encoder and decoder networks.
        Arguments:
            - action_shape (:obj:`int`): Dimension of the action vector.
            - obs_shape (:obj:`int`): Dimension of the observation vector.
            - latent_size (:obj:`int`): Dimension of the latent action.
            - hidden_dims (:obj:`Optional[List[int]]`): Hidden layer widths; at least two entries are required. \
                Defaults to ``[256, 256]`` when ``None``. The list is copied, so later mutation by the caller \
                does not affect this module.
            - kwargs: Accepted for config compatibility; not used.
        """
        super(VanillaVAE, self).__init__()
        # Copy so that neither a shared default nor the caller's list is aliased by the module.
        hidden_dims = [256, 256] if hidden_dims is None else list(hidden_dims)
        self.action_shape = action_shape
        self.obs_shape = obs_shape
        self.latent_size = latent_size
        self.hidden_dims = hidden_dims
        # Width of the observation embedding. The decoder multiplies its latent features element-wise with this
        # embedding, so both sides must share this width.
        obs_embed_dim = hidden_dims[0]

        # Build Encoder
        self.encode_action_head = nn.Sequential(nn.Linear(self.action_shape, hidden_dims[0]), nn.ReLU())
        self.encode_obs_head = nn.Sequential(nn.Linear(self.obs_shape, obs_embed_dim), nn.ReLU())
        self.encode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[1]), nn.ReLU())
        self.encode_mu_head = nn.Linear(hidden_dims[1], latent_size)
        self.encode_logvar_head = nn.Linear(hidden_dims[1], latent_size)

        # Build Decoder
        # The latent projection must match the observation embedding width because the two are multiplied
        # element-wise. When hidden_dims[0] == hidden_dims[-1] (e.g. the default [256, 256]) these layers have
        # exactly the same shapes as a projection to hidden_dims[-1].
        self.decode_action_head = nn.Sequential(nn.Linear(latent_size, obs_embed_dim), nn.ReLU())
        self.decode_common = nn.Sequential(nn.Linear(obs_embed_dim, hidden_dims[-2]), nn.ReLU())
        # TODO(pu): tanh
        self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[-2], self.action_shape), nn.Tanh())
        # residual prediction
        self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[-2], hidden_dims[-2]), nn.ReLU())
        self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[-2], self.obs_shape)

        # Not read or written by any method of this class; kept as a public attribute in case external code uses it.
        self.obs_encoding = None

    def encode(self, input: Dict[str, Tensor]) -> Dict[str, Any]:
        """
        Overview:
            Encodes the input by passing through the encoder network and returns the latent codes.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) and \
                `action` (:obj:`torch.Tensor`), representing the observation and agent's action respectively.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`), \
                ``log_var`` (:obj:`torch.Tensor`) and ``obs_encoding`` (:obj:`torch.Tensor`) \
                representing latent codes.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
            - action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``.
        """
        action_encoding = self.encode_action_head(input['action'])
        # TODO(pu): consider using a different network to encode the observation.
        obs_encoding = self.encode_obs_head(input['obs'])
        # Fuse the two embeddings by element-wise product. TODO(pu): what about add, cat?
        fused = obs_encoding * action_encoding
        result = self.encode_common(fused)
        # Split the shared features into the mean and log-variance of the latent Gaussian.
        mu = self.encode_mu_head(result)
        log_var = self.encode_logvar_head(result)
        return {'mu': mu, 'log_var': log_var, 'obs_encoding': obs_encoding}

    def _decode_features(self, latent: Tensor, obs_encoding: Tensor) -> Dict[str, Tensor]:
        """
        Overview:
            Shared decoder body. Fuses a (bounded) latent action with an observation embedding and produces the \
            reconstructed action and the predicted observation residual. No ``tanh`` is applied to ``latent`` here.
        """
        action_decoding = self.decode_action_head(latent)
        action_obs_decoding = self.decode_common(action_decoding * obs_encoding)
        reconstruction_action = self.decode_reconst_action_head(action_obs_decoding)
        predition_residual = self.decode_prediction_head_layer2(self.decode_prediction_head_layer1(action_obs_decoding))
        # NOTE: the 'predition_residual' key spelling is kept unchanged because forward() reads it by this name.
        return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual}

    def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Maps the given latent action and obs_encoding onto the original action space. ``z`` is squashed by \
            ``tanh`` before decoding because it is not bounded (e.g. a sample from ``reparameterize``).
        Arguments:
            - z (:obj:`torch.Tensor`): the sampled latent action
            - obs_encoding (:obj:`torch.Tensor`): observation encoding, as returned by ``encode``
        Returns:
            - outputs (:obj:`Dict`): Dict with the decoder outputs listed below.
        ReturnsKeys:
            - reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by the VAE.
            - predition_residual (:obj:`torch.Tensor`): the observation residual predicted by the VAE.
        Shapes:
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``
        """
        return self._decode_features(torch.tanh(z), obs_encoding)

    def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]:
        """
        Overview:
            Maps the given latent action and obs onto the original action space. \
            Uses ``self.encode_obs_head(obs)`` to get the obs_encoding. Unlike ``decode``, no ``tanh`` is \
            applied to ``z``.
        Arguments:
            - z (:obj:`torch.Tensor`): the latent action, expected to be already bounded
            - obs (:obj:`torch.Tensor`): observation
        Returns:
            - outputs (:obj:`Dict`): Dict with the decoder outputs listed below.
        ReturnsKeys:
            - reconstruction_action (:obj:`torch.Tensor`): the action reconstructed by the VAE.
            - predition_residual (:obj:`torch.Tensor`): the observation residual predicted by the VAE.
        Shapes:
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``obs_shape``
        """
        obs_encoding = self.encode_obs_head(obs)
        # TODO(pu): here z is already bounded, z is produced by td3 policy, it has been operated by tanh
        return self._decode_features(z, obs_encoding)

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Overview:
            Reparameterization trick: draw a sample from N(mu, exp(logvar)) as ``mu + eps * std`` with \
            eps ~ N(0, I), so that the sample is differentiable w.r.t. ``mu`` and ``logvar``.
        Arguments:
            - mu (:obj:`torch.Tensor`): Mean of the latent Gaussian.
            - logvar (:obj:`torch.Tensor`): Log-variance of the latent Gaussian.
        Returns:
            - z (:obj:`torch.Tensor`): A sample with the same shape as ``mu``.
        Shapes:
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
            - logvar (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
        """
        std = torch.exp(0.5 * logvar)  # exp(0.5 * log(var)) == sqrt(var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Dict[str, Tensor], **kwargs) -> dict:
        """
        Overview:
            Encode the input, sample ``z`` via ``reparameterize`` from ``mu`` and ``log_var``, then decode ``z`` \
            together with ``obs_encoding``.
        Arguments:
            - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) \
                and `action` (:obj:`torch.Tensor`), representing the observation \
                and agent's action respectively.
        Returns:
            - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
                (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
                ``input`` (the ``input`` dict itself), ``mu`` (:obj:`torch.Tensor`), \
                ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, \
                where B is batch size and O is ``observation dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
        """
        encode_output = self.encode(input)
        z = self.reparameterize(encode_output['mu'], encode_output['log_var'])
        decode_output = self.decode(z, encode_output['obs_encoding'])
        return {
            'recons_action': decode_output['reconstruction_action'],
            'prediction_residual': decode_output['predition_residual'],
            'input': input,
            'mu': encode_output['mu'],
            'log_var': encode_output['log_var'],
            'z': z
        }

    def loss_function(self, args: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
        """
        Overview:
            Computes the VAE loss: action reconstruction MSE + ``kld_weight`` * KL divergence to N(0, I) \
            + ``predict_weight`` * residual prediction MSE.
        Arguments:
            - args (:obj:`Dict[str, Tensor]`): Dict containing keywords ``recons_action``, ``prediction_residual``, \
                ``original_action``, ``mu``, ``log_var`` and ``true_residual``.
            - kwargs (:obj:`Dict`): Dict containing keywords ``kld_weight`` and ``predict_weight`` (both required).
        Returns:
            - outputs (:obj:`Dict[str, Tensor]`): Dict containing different ``loss`` results, including ``loss``, \
                ``reconstruction_loss``, ``kld_loss``, ``predict_loss``.
        Shapes:
            - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size \
                and A is ``action dim``.
            - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size \
                and O is ``observation dim``.
            - original_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
            - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
            - true_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
        """
        recons_action = args['recons_action']
        prediction_residual = args['prediction_residual']
        original_action = args['original_action']
        mu = args['mu']
        log_var = args['log_var']
        true_residual = args['true_residual']

        kld_weight = kwargs['kld_weight']
        predict_weight = kwargs['predict_weight']

        recons_loss = F.mse_loss(recons_action, original_action)
        # Closed-form KL(N(mu, exp(log_var)) || N(0, I)), summed over latent dims (dim=1), averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
        predict_loss = F.mse_loss(prediction_residual, true_residual)

        loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss
        return {'loss': loss, 'reconstruction_loss': recons_loss, 'kld_loss': kld_loss, 'predict_loss': predict_loss}