""" | |
this extremely minimal Decision Transformer model is based on | |
the following causal transformer (GPT) implementation: | |
Misha Laskin's tweet: | |
https://twitter.com/MishaLaskin/status/1481767788775628801?cxt=HHwWgoCzmYD9pZApAAAA | |
and its corresponding notebook: | |
https://colab.research.google.com/drive/1NUBqyboDcGte5qAJKOl8gaJC28V_73Iv?usp=sharing | |
** the above colab notebook has a bug while applying masked_fill | |
which is fixed in the following code | |
""" | |
import math
from typing import Union, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ding.utils import SequenceType


class MaskedCausalAttention(nn.Module):
    """
    Overview:
        The implementation of masked causal attention in the decision transformer. The input of this module is a \
        sequence of several tokens. The hidden embedding calculated for the i-th token only depends on the input \
        tokens at positions 0 to i, which is enforced by applying a lower-triangular mask to the attention map. \
        Thus, this module is called masked causal attention.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
        """
        Overview:
            Initialize the MaskedCausalAttention Model according to input arguments.
        Arguments:
            - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
            - max_T (:obj:`int`): The max context length of the attention, such as 6.
            - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
            - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
        """
        super().__init__()
        self.n_heads = n_heads
        self.max_T = max_T

        self.q_net = nn.Linear(h_dim, h_dim)
        self.k_net = nn.Linear(h_dim, h_dim)
        self.v_net = nn.Linear(h_dim, h_dim)
        self.proj_net = nn.Linear(h_dim, h_dim)

        self.att_drop = nn.Dropout(drop_p)
        self.proj_drop = nn.Dropout(drop_p)

        ones = torch.ones((max_T, max_T))
        mask = torch.tril(ones).view(1, 1, max_T, max_T)
        # register_buffer makes sure the mask is saved as part of the module's state,
        # but is not a trainable parameter, so it never gets updated during backpropagation.
        self.register_buffer('mask', mask)
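
        # A small illustration (not in the original code): for max_T = 3,
        # torch.tril(torch.ones(3, 3)) is
        #   [[1., 0., 0.],
        #    [1., 1., 0.],
        #    [1., 1., 1.]]
        # i.e. row i has ones only at columns 0..i, so after the masked_fill in
        # ``forward`` below, token i cannot attend to any future token.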

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            MaskedCausalAttention forward computation graph, input a sequence tensor \
            and return a tensor with the same shape.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - out (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.
        Examples:
            >>> inputs = torch.randn(2, 4, 64)
            >>> model = MaskedCausalAttention(64, 5, 4, 0.1)
            >>> outputs = model(inputs)
            >>> assert outputs.shape == torch.Size([2, 4, 64])
        """
        B, T, C = x.shape  # batch size, seq length, h_dim
        N, D = self.n_heads, C // self.n_heads  # N = num heads, D = attention dim per head

        # rearrange q, k, v as (B, N, T, D)
        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)

        # weights (B, N, T, T)
        weights = q @ k.transpose(2, 3) / math.sqrt(D)
        # causal mask applied to weights
        weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))
        # normalize weights, all -inf -> 0 after softmax
        normalized_weights = F.softmax(weights, dim=-1)

        # attention (B, N, T, D)
        attention = self.att_drop(normalized_weights @ v)

        # gather heads and project (B, N, T, D) -> (B, T, N*D)
        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)

        out = self.proj_drop(self.proj_net(attention))
        return out
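
# In summary, each attention head computes softmax(mask(Q @ K^T / sqrt(D))) @ V.
# A quick sanity check of the masking step (a sketch, not part of the library):
#   F.softmax(torch.tensor([0., float('-inf')]), dim=-1)  # -> tensor([1., 0.])
# i.e. positions filled with -inf receive exactly zero attention weight.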


class Block(nn.Module):
    """
    Overview:
        The implementation of a transformer block in the decision transformer.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
        """
        Overview:
            Initialize the Block Model according to input arguments.
        Arguments:
            - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
            - max_T (:obj:`int`): The max context length of the attention, such as 6.
            - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
            - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
        """
        super().__init__()
        self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
        self.mlp = nn.Sequential(
            nn.Linear(h_dim, 4 * h_dim),
            nn.GELU(),
            nn.Linear(4 * h_dim, h_dim),
            nn.Dropout(drop_p),
        )
        self.ln1 = nn.LayerNorm(h_dim)
        self.ln2 = nn.LayerNorm(h_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Forward computation graph of the decision transformer block, input a sequence tensor \
            and return a tensor with the same shape.
        Arguments:
            - x (:obj:`torch.Tensor`): The input tensor.
        Returns:
            - output (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.
        Examples:
            >>> inputs = torch.randn(2, 4, 64)
            >>> model = Block(64, 5, 4, 0.1)
            >>> outputs = model(inputs)
            >>> assert outputs.shape == torch.Size([2, 4, 64])
        """
        # Attention -> LayerNorm -> MLP -> LayerNorm (post-LN ordering)
        x = x + self.attention(x)  # residual
        x = self.ln1(x)
        x = x + self.mlp(x)  # residual
        x = self.ln2(x)
        # The pre-LN variant would instead be:
        # x = x + self.attention(self.ln1(x))
        # x = x + self.mlp(self.ln2(x))
        return x
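

# A minimal usage sketch (hypothetical values, not from the original source):
# stacking Blocks with nn.Sequential, as DecisionTransformer does below, keeps
# the (B, T, h_dim) shape through every layer.
#   blocks = nn.Sequential(*[Block(h_dim=64, max_T=12, n_heads=4, drop_p=0.1) for _ in range(3)])
#   y = blocks(torch.randn(2, 12, 64))  # -> shape (2, 12, 64)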


class DecisionTransformer(nn.Module):
    """
    Overview:
        The implementation of the decision transformer.
    Interfaces:
        ``__init__``, ``forward``, ``configure_optimizers``
    """

    def __init__(
            self,
            state_dim: Union[int, SequenceType],
            act_dim: int,
            n_blocks: int,
            h_dim: int,
            context_len: int,
            n_heads: int,
            drop_p: float,
            max_timestep: int = 4096,
            state_encoder: Optional[nn.Module] = None,
            continuous: bool = False
    ):
        """
        Overview:
            Initialize the DecisionTransformer Model according to input arguments.
        Arguments:
            - state_dim (:obj:`Union[int, SequenceType]`): The dimension of states, such as 128 or (4, 84, 84).
            - act_dim (:obj:`int`): The dimension of actions, such as 6.
            - n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3.
            - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
            - context_len (:obj:`int`): The max context length of the attention, such as 6.
            - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
            - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
            - max_timestep (:obj:`int`): The max length of the total sequence, defaults to 4096.
            - state_encoder (:obj:`Optional[nn.Module]`): The encoder to pre-process the given input. If it is set \
                to None, the raw state will be pushed into the transformer.
            - continuous (:obj:`bool`): Whether the action space is continuous, defaults to ``False``.
        """
        super().__init__()
        self.state_dim = state_dim
        self.act_dim = act_dim
        self.h_dim = h_dim

        # transformer blocks: each timestep contributes 3 tokens (return-to-go, state, action)
        input_seq_len = 3 * context_len

        # projection heads (project to embedding)
        self.embed_ln = nn.LayerNorm(h_dim)
        self.embed_timestep = nn.Embedding(max_timestep, h_dim)
        self.drop = nn.Dropout(drop_p)

        self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim))
        self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim))

        if state_encoder is None:
            self.state_encoder = None
            blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)]
            self.embed_rtg = torch.nn.Linear(1, h_dim)
            self.embed_state = torch.nn.Linear(state_dim, h_dim)
            self.predict_rtg = torch.nn.Linear(h_dim, 1)
            self.predict_state = torch.nn.Linear(h_dim, state_dim)
            if continuous:
                # continuous actions
                self.embed_action = torch.nn.Linear(act_dim, h_dim)
                use_action_tanh = True  # True for continuous actions
            else:
                # discrete actions
                self.embed_action = torch.nn.Embedding(act_dim, h_dim)
                use_action_tanh = False  # False for discrete actions
            self.predict_action = nn.Sequential(
                *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))
            )
        else:
            blocks = [Block(h_dim, input_seq_len + 1, n_heads, drop_p) for _ in range(n_blocks)]
            self.state_encoder = state_encoder
            self.embed_rtg = nn.Sequential(nn.Linear(1, h_dim), nn.Tanh())
            self.head = nn.Linear(h_dim, act_dim, bias=False)
            self.embed_action = nn.Sequential(nn.Embedding(act_dim, h_dim), nn.Tanh())
        self.transformer = nn.Sequential(*blocks)
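
        # A note on sequence length (illustrative, not from the original source):
        # with context_len = K, the default branch interleaves tokens as
        # (r_0, s_0, a_0, ..., r_{K-1}, s_{K-1}, a_{K-1}), i.e. 3 * K tokens in total,
        # which is why its attention blocks are built with max_T = 3 * K; the
        # state_encoder branch reserves one extra position (max_T = 3 * K + 1).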

    def forward(
            self,
            timesteps: torch.Tensor,
            states: torch.Tensor,
            actions: torch.Tensor,
            returns_to_go: torch.Tensor,
            tar: Optional[int] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Overview:
            Forward computation graph of the decision transformer, input the sequences of timesteps, states, \
            actions and returns-to-go, and return the predicted states, actions and returns-to-go.
        Arguments:
            - timesteps (:obj:`torch.Tensor`): The timesteps of the input sequence.
            - states (:obj:`torch.Tensor`): The sequence of states.
            - actions (:obj:`torch.Tensor`): The sequence of actions.
            - returns_to_go (:obj:`torch.Tensor`): The sequence of returns-to-go.
            - tar (:obj:`Optional[int]`): Only used in the ``state_encoder`` branch. If None, the last action \
                token is omitted from the input sequence, i.e. the model predicts the action for the latest timestep.
        Returns:
            - output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Output contains three tensors, \
                which are respectively the predicted states, predicted actions and predicted returns-to-go.
        Examples:
            >>> B, T = 4, 6
            >>> state_dim = 3
            >>> act_dim = 2
            >>> DT_model = DecisionTransformer(\
                state_dim=state_dim,\
                act_dim=act_dim,\
                n_blocks=3,\
                h_dim=8,\
                context_len=T,\
                n_heads=2,\
                drop_p=0.1,\
            )
            >>> timesteps = torch.randint(0, 100, [B, T], dtype=torch.long)  # B x T
            >>> states = torch.randn([B, T, state_dim])  # B x T x state_dim
            >>> actions = torch.randint(0, act_dim, [B, T, 1])
            >>> action_target = torch.randint(0, act_dim, [B, T, 1])
            >>> returns_to_go = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]).repeat([B, 1]).unsqueeze(-1).float()
            >>> traj_mask = torch.ones([B, T], dtype=torch.long)  # B x T
            >>> actions = actions.squeeze(-1)
            >>> state_preds, action_preds, return_preds = DT_model.forward(\
                timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go\
            )
            >>> assert state_preds.shape == torch.Size([B, T, state_dim])
            >>> assert return_preds.shape == torch.Size([B, T, 1])
            >>> assert action_preds.shape == torch.Size([B, T, act_dim])
        """
        B, T = states.shape[0], states.shape[1]
        if self.state_encoder is None:
            time_embeddings = self.embed_timestep(timesteps)

            # time embeddings are treated similar to positional embeddings
            state_embeddings = self.embed_state(states) + time_embeddings
            action_embeddings = self.embed_action(actions) + time_embeddings
            returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings

            # stack rtg, states and actions and reshape sequence as
            # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...)
            t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings),
                              dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)
            h = self.embed_ln(t_p)
            # transformer and prediction
            h = self.transformer(h)
            # get h reshaped such that its size = (B x 3 x T x h_dim) and
            # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t
            # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t
            # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t
            # that is, for each timestep (t) we have 3 output embeddings from the transformer,
            # each conditioned on all previous timesteps plus
            # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence.
            h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)

            return_preds = self.predict_rtg(h[:, 2])  # predict next rtg given r, s, a
            state_preds = self.predict_state(h[:, 2])  # predict next state given r, s, a
            action_preds = self.predict_action(h[:, 1])  # predict action given r, s
        else:
            state_embeddings = self.state_encoder(
                states.reshape(-1, *self.state_dim).type(torch.float32).contiguous()
            )  # (batch * block_size, h_dim)
            state_embeddings = state_embeddings.reshape(B, T, self.h_dim)  # (batch, block_size, h_dim)
            returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32))
            action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1))  # (batch, block_size, h_dim)

            token_embeddings = torch.zeros(
                (B, T * 3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device
            )
            token_embeddings[:, ::3, :] = returns_embeddings
            token_embeddings[:, 1::3, :] = state_embeddings
            token_embeddings[:, 2::3, :] = action_embeddings[:, -T + int(tar is None):, :]

            all_global_pos_emb = torch.repeat_interleave(
                self.global_pos_emb, B, dim=0
            )  # batch_size, traj_length, h_dim
            position_embeddings = torch.gather(
                all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)
            ) + self.pos_emb[:, :token_embeddings.shape[1], :]

            t_p = token_embeddings + position_embeddings
            h = self.drop(t_p)
            h = self.transformer(h)
            h = self.embed_ln(h)
            logits = self.head(h)

            return_preds = None
            state_preds = None
            action_preds = logits[:, 1::3, :]  # only keep predictions from state_embeddings

        return state_preds, action_preds, return_preds
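
    # A small worked sketch of the interleaving above (illustrative, not from the
    # original source): stacking (returns, states, actions) on dim=1 gives
    # (B, 3, T, h_dim); permute(0, 2, 1, 3) reorders it to (B, T, 3, h_dim), so that
    # reshape(B, 3 * T, h_dim) lays the tokens out as r_0, s_0, a_0, r_1, s_1, a_1, ...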

    def configure_optimizers(
            self, weight_decay: float, learning_rate: float, betas: Tuple[float, float] = (0.9, 0.95)
    ) -> torch.optim.Optimizer:
        """
        Overview:
            This function returns an optimizer given the input arguments. \
            We are separating out all parameters of the model into two buckets: those that will experience \
            weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        Arguments:
            - weight_decay (:obj:`float`): The weight decay of the optimizer.
            - learning_rate (:obj:`float`): The learning rate of the optimizer.
            - betas (:obj:`Tuple[float, float]`): The betas of the Adam optimizer.
        Returns:
            - optimizer (:obj:`torch.optim.Optimizer`): The desired optimizer.
        """
        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        # whitelist_weight_modules = (torch.nn.Linear, )
        whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn  # full param name
                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameters in the root GPT module as not decayed
        no_decay.add('pos_emb')
        no_decay.add('global_pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, \
            "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {
                "params": [param_dict[pn] for pn in sorted(list(decay))],
                "weight_decay": weight_decay
            },
            {
                "params": [param_dict[pn] for pn in sorted(list(no_decay))],
                "weight_decay": 0.0
            },
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
        return optimizer
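

# A minimal end-to-end sketch (hypothetical hyper-parameters, not part of the
# library): build a small discrete-action DecisionTransformer, run one forward
# pass, and create its decayed/non-decayed AdamW optimizer. Only runs when the
# module is executed directly.
if __name__ == "__main__":
    B, T, state_dim, act_dim = 4, 6, 3, 2
    model = DecisionTransformer(
        state_dim=state_dim, act_dim=act_dim, n_blocks=2, h_dim=8, context_len=T, n_heads=2, drop_p=0.1
    )
    timesteps = torch.randint(0, 100, [B, T], dtype=torch.long)
    states = torch.randn(B, T, state_dim)
    actions = torch.randint(0, act_dim, [B, T])
    returns_to_go = torch.rand(B, T, 1)
    # default branch (state_encoder is None): predictions for all three modalities
    state_preds, action_preds, return_preds = model(timesteps, states, actions, returns_to_go)
    assert state_preds.shape == torch.Size([B, T, state_dim])
    assert action_preds.shape == torch.Size([B, T, act_dim])
    assert return_preds.shape == torch.Size([B, T, 1])
    # AdamW with weight decay applied only to the whitelisted weights
    optimizer = model.configure_optimizers(weight_decay=0.1, learning_rate=3e-4)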