from typing import Optional, Tuple, Union, Dict

import torch
import torch.nn as nn

from ding.utils import MODEL_REGISTRY, SequenceType
from ding.torch_utils.network.transformer import Attention
from ding.torch_utils.network.nn_module import fc_block, build_normalization
from ..common import FCEncoder, ConvEncoder


class PCTransformer(nn.Module):
    """
    Overview:
        The transformer block used in the neural networks of procedure cloning (PC) algorithms.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
        self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int,
        feedforward_hidden: int, n_feedforward: int
    ) -> None:
        """
        Overview:
            Initialize the procedure cloning transformer model according to the corresponding input arguments.
        Arguments:
            - cnn_hidden (:obj:`int`): The last channel dimension of the CNN encoder, such as 32.
            - att_hidden (:obj:`int`): The dimension of attention blocks, such as 32.
            - att_heads (:obj:`int`): The number of heads in attention blocks, such as 4.
            - drop_p (:obj:`float`): The dropout rate of attention, such as 0.5.
            - max_T (:obj:`int`): The sequence length of procedure cloning, such as 4.
            - n_att (:obj:`int`): The number of attention layers, such as 4.
            - feedforward_hidden (:obj:`int`): The dimension of feedforward layers, such as 32.
            - n_feedforward (:obj:`int`): The number of feedforward layers, such as 4.
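        Examples:
            >>> # Construction sketch; the argument values simply follow the examples listed above.
            >>> model = PCTransformer(32, 32, 4, 0.5, 4, 4, 32, 4)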
""" | |
        super().__init__()
        self.n_att = n_att
        self.n_feedforward = n_feedforward

        # Use ModuleList (instead of a plain Python list) so that sub-modules are registered,
        # and create a distinct LayerNorm per layer rather than repeating one shared instance.
        self.attention_layer = nn.ModuleList()
        self.norm_layer = nn.ModuleList([nn.LayerNorm(att_hidden) for _ in range(n_att)])
        self.attention_layer.append(Attention(cnn_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
        for i in range(n_att - 1):
            self.attention_layer.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
        self.att_drop = nn.Dropout(drop_p)

        self.fc_blocks = nn.ModuleList()
        self.fc_blocks.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU()))
        for i in range(n_feedforward - 1):
            self.fc_blocks.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
        self.norm_layer.extend([nn.LayerNorm(feedforward_hidden) for _ in range(n_feedforward)])

        # Causal (lower-triangular) attention mask: each position only attends to itself and earlier positions.
        # Registered as a buffer so it follows the module across devices.
        mask = torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T)
        self.register_buffer('mask', mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            The unique execution (forward) method of PCTransformer.
        Arguments:
            - x (:obj:`torch.Tensor`): Sequential data of several hidden states, with shape \
                ``(B, T, cnn_hidden)``.
        Returns:
            - output (:obj:`torch.Tensor`): The output tensor, with shape ``(B, T, feedforward_hidden)``.
        Examples:
            >>> model = PCTransformer(128, 128, 8, 0, 16, 2, 128, 2)
            >>> h = torch.randn((2, 16, 128))
            >>> h = model(h)
            >>> assert h.shape == torch.Size([2, 16, 128])
        """
        for i in range(self.n_att):
            x = self.att_drop(self.attention_layer[i](x, self.mask))
            x = self.norm_layer[i](x)
        for i in range(self.n_feedforward):
            x = self.fc_blocks[i](x)
            x = self.norm_layer[i + self.n_att](x)
        return x


class ProcedureCloningMCTS(nn.Module):
    """
    Overview:
        The neural network of algorithms related to procedure cloning (PC).
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            obs_shape: SequenceType,
            action_dim: int,
            cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256],
            cnn_activation: nn.Module = nn.ReLU(),
            cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3],
            cnn_stride: SequenceType = [1, 1, 1, 1, 1],
            cnn_padding: SequenceType = [1, 1, 1, 1, 1],
            mlp_hidden_list: SequenceType = [256, 256],
            mlp_activation: nn.Module = nn.ReLU(),
            att_heads: int = 8,
            att_hidden: int = 128,
            n_att: int = 4,
            n_feedforward: int = 2,
            feedforward_hidden: int = 256,
            drop_p: float = 0.5,
            max_T: int = 17
    ) -> None:
""" | |
Overview: | |
Initialize the MCTS procedure cloning model according to corresponding input arguments. | |
Arguments: | |
- obs_shape (:obj:`SequenceType`): Observation space shape, such as [4, 84, 84]. | |
- action_dim (:obj:`int`): Action space shape, such as 6. | |
- cnn_hidden_list (:obj:`SequenceType`): The cnn channel dims for each block, such as\ | |
[128, 128, 256, 256, 256]. | |
- cnn_activation (:obj:`nn.Module`): The activation function for cnn blocks, such as ``nn.ReLU()``. | |
- cnn_kernel_size (:obj:`SequenceType`): The kernel size for each cnn block, such as [3, 3, 3, 3, 3]. | |
- cnn_stride (:obj:`SequenceType`): The stride for each cnn block, such as [1, 1, 1, 1, 1]. | |
- cnn_padding (:obj:`SequenceType`): The padding for each cnn block, such as [1, 1, 1, 1, 1]. | |
- mlp_hidden_list (:obj:`SequenceType`): The last dim for this must match the last dim of \ | |
``cnn_hidden_list``, such as [256, 256]. | |
- mlp_activation (:obj:`nn.Module`): The activation function for mlp layers, such as ``nn.ReLU()``. | |
- att_heads (:obj:`int`): The number of attention heads in transformer, such as 8. | |
- att_hidden (:obj:`int`): The number of attention dimension in transformer, such as 128. | |
- n_att (:obj:`int`): The number of attention blocks in transformer, such as 4. | |
- n_feedforward (:obj:`int`): The number of feedforward layers in transformer, such as 2. | |
- drop_p (:obj:`float`): The drop out rate of attention, such as 0.5. | |
- max_T (:obj:`int`): The sequence length of procedure cloning, such as 17. | |
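        Examples:
            >>> # Construction sketch; the shapes mirror the example in ``forward`` below.
            >>> model = ProcedureCloningMCTS(obs_shape=(3, 64, 64), action_dim=9)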
""" | |
        super().__init__()
        # Conv encoder for state/goal observations and MLP encoder for actions.
        self.embed_state = ConvEncoder(
            obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding
        )
        self.embed_action = FCEncoder(action_dim, mlp_hidden_list, activation=mlp_activation)
        self.cnn_hidden_list = cnn_hidden_list

        assert cnn_hidden_list[-1] == mlp_hidden_list[-1]

        # All attention, feedforward and normalization layers are built inside the transformer block.
        self.transformer = PCTransformer(
            cnn_hidden_list[-1], att_hidden, att_heads, drop_p, max_T, n_att, feedforward_hidden, n_feedforward
        )

        self.predict_goal = nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1])
        self.predict_action = nn.Linear(cnn_hidden_list[-1], action_dim)

    def forward(self, states: torch.Tensor, goals: torch.Tensor,
                actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            ProcedureCloningMCTS forward computation graph: take the states, goals and actions tensors as input, \
            and compute the predicted goals and actions.
        Arguments:
            - states (:obj:`torch.Tensor`): The observation at the current time.
            - goals (:obj:`torch.Tensor`): The target observation after a period.
            - actions (:obj:`torch.Tensor`): The actions executed during the period.
        Returns:
            - outputs (:obj:`Tuple[torch.Tensor, torch.Tensor]`): The predicted goals and actions.
        Examples:
            >>> inputs = { \
                'states': torch.randn(2, 3, 64, 64), \
                'goals': torch.randn(2, 3, 64, 64), \
                'actions': torch.randn(2, 15, 9) \
            }
            >>> model = ProcedureCloningMCTS(obs_shape=(3, 64, 64), action_dim=9)
            >>> goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])
            >>> assert goal_preds.shape == (2, 256)
            >>> assert action_preds.shape == (2, 16, 9)
        """
        B, T, _ = actions.shape
        # Encode states and goals to (B, 1, h_dim) each, and actions to (B, T, h_dim).
        state_embeddings = self.embed_state(states).reshape(B, 1, self.cnn_hidden_list[-1])
        goal_embeddings = self.embed_state(goals).reshape(B, 1, self.cnn_hidden_list[-1])
        actions_embeddings = self.embed_action(actions)

        # Token layout along the sequence dimension: [state, goal, action_1, ..., action_T].
        h = torch.cat((state_embeddings, goal_embeddings, actions_embeddings), dim=1)
        h = self.transformer(h)
        h = h.reshape(B, T + 2, self.cnn_hidden_list[-1])

        # The first token predicts the goal embedding; the remaining T + 1 tokens predict actions.
        goal_preds = self.predict_goal(h[:, 0, :])
        action_preds = self.predict_action(h[:, 1:, :])
        return goal_preds, action_preds


class BFSConvEncoder(nn.Module):
    """
    Overview:
        The ``BFSConvEncoder`` used to encode raw 3-dim observations, outputting a feature map with the \
        same height and width as the input.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            obs_shape: SequenceType,
            hidden_size_list: SequenceType = [32, 64, 64, 128],
            activation: Optional[nn.Module] = nn.ReLU(),
            kernel_size: SequenceType = [8, 4, 3],
            stride: SequenceType = [4, 2, 1],
            padding: Optional[SequenceType] = None,
    ) -> None:
        """
        Overview:
            Init the ``BFSConvEncoder`` according to the provided arguments.
        Arguments:
            - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``.
            - hidden_size_list (:obj:`SequenceType`): Sequence of ``hidden_size`` (output channels) of the \
                subsequent conv layers.
            - activation (:obj:`nn.Module`): Type of activation used after each conv layer. \
                Default is ``nn.ReLU()``.
            - kernel_size (:obj:`SequenceType`): Sequence of ``kernel_size`` of subsequent conv layers.
            - stride (:obj:`SequenceType`): Sequence of ``stride`` of subsequent conv layers.
            - padding (:obj:`SequenceType`): Padding added to all four sides of the input for each conv layer. \
                See ``nn.Conv2d`` for more details. Default is ``None``.
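        Examples:
            >>> # Construction sketch; the values mirror the example in ``forward`` below.
            >>> model = BFSConvEncoder([3, 16, 16], [32, 32, 4], kernel_size=[3, 3, 3], stride=[1, 1, 1], \
                padding=[1, 1, 1])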
""" | |
        super(BFSConvEncoder, self).__init__()
        self.obs_shape = obs_shape
        self.act = activation
        self.hidden_size_list = hidden_size_list
        if padding is None:
            padding = [0 for _ in range(len(kernel_size))]

        layers = []
        input_size = obs_shape[0]  # in_channel
        for i in range(len(kernel_size)):
            layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i]))
            layers.append(self.act)
            input_size = hidden_size_list[i]
        # Drop the trailing activation so the encoder outputs the raw (pre-activation) feature map.
        layers = layers[:-1]
        self.main = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return the output feature map of the env observation.
        Arguments:
            - x (:obj:`torch.Tensor`): Env raw observation.
        Returns:
            - outputs (:obj:`torch.Tensor`): Output embedding tensor.
        Examples:
            >>> model = BFSConvEncoder([3, 16, 16], [32, 32, 4], kernel_size=[3, 3, 3], stride=[1, 1, 1]\
                , padding=[1, 1, 1])
            >>> inputs = torch.randn(3, 16, 16).unsqueeze(0)
            >>> outputs = model(inputs)
            >>> assert outputs.shape == torch.Size([1, 4, 16, 16])
        """
        return self.main(x)


class ProcedureCloningBFS(nn.Module):
    """
    Overview:
        The neural network introduced in procedure cloning (PC) to process 3-dim observations. \
        Given an input, this model will perform several 3x3 convolutions and output a feature map with \
        the same height and width as the input. The channel number of the output will be ``action_shape + 1``.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            obs_shape: SequenceType,
            action_shape: int,
            encoder_hidden_size_list: SequenceType = [128, 128, 256, 256],
    ) -> None:
        """
        Overview:
            Init the ``ProcedureCloningBFS`` model according to the provided arguments.
        Arguments:
            - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``, \
                such as [4, 84, 84].
            - action_shape (:obj:`int`): Action space shape, such as 6.
            - encoder_hidden_size_list (:obj:`SequenceType`): The conv channel dims for each block, \
                such as [128, 128, 256, 256].
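        Examples:
            >>> # Construction sketch; the values mirror the example in ``forward`` below.
            >>> model = ProcedureCloningBFS([3, 16, 16], 4)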
""" | |
        super().__init__()
        num_layers = len(encoder_hidden_size_list)
        kernel_sizes = (3, ) * (num_layers + 1)
        stride_sizes = (1, ) * (num_layers + 1)
        padding_sizes = (1, ) * (num_layers + 1)
        # The final conv outputs action_shape + 1 channels. Build a new list rather than appending
        # in place, which would mutate the (mutable) default argument across instantiations.
        encoder_hidden_size_list = list(encoder_hidden_size_list) + [action_shape + 1]

        self._encoder = BFSConvEncoder(
            obs_shape=obs_shape,
            hidden_size_list=encoder_hidden_size_list,
            kernel_size=kernel_sizes,
            stride=stride_sizes,
            padding=padding_sizes,
        )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            The computation graph. Given a 3-dim observation, this function will return a tensor with the same \
            height and width. The channel number of the output will be ``action_shape + 1``.
        Arguments:
            - x (:obj:`torch.Tensor`): The input observation tensor data, in channel-last format ``(B, H, W, C)``.
        Returns:
            - outputs (:obj:`Dict`): The output dict of the model's forward computation graph, \
                only containing a single key ``logit``.
        Examples:
            >>> model = ProcedureCloningBFS([3, 16, 16], 4)
            >>> inputs = torch.randn(16, 16, 3).unsqueeze(0)
            >>> outputs = model(inputs)
            >>> assert outputs['logit'].shape == torch.Size([1, 16, 16, 5])
        """
        # Convert channel-last input (B, H, W, C) to channel-first for the conv encoder,
        # then convert the logits back to channel-last (B, H, W, action_shape + 1).
        x = x.permute(0, 3, 1, 2)
        x = self._encoder(x)
        return {'logit': x.permute(0, 2, 3, 1)}
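

if __name__ == "__main__":
    # Lightweight sanity-check sketch: it re-runs the shape checks from the docstring examples
    # above on randomly generated inputs, assuming CPU execution and default arguments.
    mcts_model = ProcedureCloningMCTS(obs_shape=(3, 64, 64), action_dim=9)
    goal_preds, action_preds = mcts_model(
        torch.randn(2, 3, 64, 64), torch.randn(2, 3, 64, 64), torch.randn(2, 15, 9)
    )
    assert goal_preds.shape == (2, 256)
    assert action_preds.shape == (2, 16, 9)

    bfs_model = ProcedureCloningBFS([3, 16, 16], 4)
    outputs = bfs_model(torch.randn(1, 16, 16, 3))
    assert outputs['logit'].shape == torch.Size([1, 16, 16, 5])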