# Building blocks for 1-D diffusion models: schedules, conditioning helpers,
# a diffusion U-Net and a temporal value network.
from typing import Union, List, Dict | |
from collections import namedtuple | |
import numpy as np | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType | |
def extract(a, t, x_shape):
    """
    Overview:
        Gather elements of ``a`` at the indices ``t`` along the last axis, reshaped so the
        result broadcasts against a tensor of shape ``x_shape``.
    Arguments:
        - a (:obj:`torch.Tensor`): input tensor
        - t (:obj:`torch.Tensor`): index tensor
        - x_shape (:obj:`torch.Tensor`): shape of x
    """
    batch_size = t.shape[0]
    gathered = a.gather(-1, t)
    # Append singleton dims so the result broadcasts over every non-batch axis of x.
    return gathered.reshape(batch_size, *([1] * (len(x_shape) - 1)))
def cosine_beta_schedule(timesteps: int, s: float = 0.008, dtype=torch.float32):
    """
    Overview:
        Cosine noise schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    Arguments:
        - timesteps (:obj:`int`): timesteps of diffusion step
        - s (:obj:`float`): small offset keeping the schedule away from zero near t = 0
        - dtype (:obj:`torch.dtype`): dtype of beta
    Return:
        Tensor of beta [timesteps,], computing by cosine.
    """
    steps = timesteps + 1
    grid = np.linspace(0, steps, steps)
    # Cumulative alpha follows a squared-cosine curve, normalized to start at 1.
    alphas_cumprod = np.cos(((grid / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    # Betas are the per-step decay ratios, clipped for numerical stability.
    clipped = np.clip(1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]), a_min=0, a_max=0.999)
    return torch.tensor(clipped, dtype=dtype)
def apply_conditioning(x, conditions, action_dim):
    """
    Overview:
        Write each conditioning state into the trajectory tensor at its timestep (in place).
    Arguments:
        - x (:obj:`torch.Tensor`): input tensor
        - conditions (:obj:`dict`): condition dict, key is timestep, value is condition
        - action_dim (:obj:`int`): action dim
    """
    # Overwrite the observation slice (columns after the action dims) at each
    # conditioned timestep; clone() avoids aliasing the stored condition tensor.
    for timestep, state in conditions.items():
        x[:, timestep, action_dim:] = state.clone()
    return x
class DiffusionConv1d(nn.Module):
    """
    Overview:
        Conv1d with activation and normalization for diffusion models.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            padding: int,
            activation: nn.Module = None,
            n_groups: int = 8
    ) -> None:
        """
        Overview:
            Create a 1-dim convolution layer with activation and GroupNorm normalization.
            The norm is computed on a temporarily 4-dim view of the activations.
        Arguments:
            - in_channels (:obj:`int`): Number of channels in the input tensor
            - out_channels (:obj:`int`): Number of channels in the output tensor
            - kernel_size (:obj:`int`): Size of the convolving kernel
            - padding (:obj:`int`): Zero-padding added to both sides of the input
            - activation (:obj:`nn.Module`): the optional activation function; when ``None``
              (the default) an identity mapping is used
            - n_groups (:obj:`int`): number of GroupNorm groups; must divide out_channels
        """
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm = nn.GroupNorm(n_groups, out_channels)
        # Bug fix: the default activation=None previously raised TypeError in forward()
        # ('NoneType' object is not callable); fall back to an identity module so the
        # documented default actually works.
        self.act = activation if activation is not None else nn.Identity()

    def forward(self, inputs) -> torch.Tensor:
        """
        Overview:
            compute conv1d for inputs.
        Arguments:
            - inputs (:obj:`torch.Tensor`): input tensor
        Return:
            - out (:obj:`torch.Tensor`): output tensor
        """
        x = self.conv1(inputs)
        # [batch, channels, horizon] -> [batch, channels, 1, horizon]
        x = x.unsqueeze(-2)
        x = self.norm(x)
        # [batch, channels, 1, horizon] -> [batch, channels, horizon]
        x = x.squeeze(-2)
        out = self.act(x)
        return out
class SinusoidalPosEmb(nn.Module):
    """
    Overview:
        Sinusoidal position embedding, used here to encode diffusion timesteps.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim: int) -> None:
        """
        Overview:
            Initialization of SinusoidalPosEmb class
        Arguments:
            - dim (:obj:`int`): dimension of embedding (half sine, half cosine)
        """
        super().__init__()
        self.dim = dim

    def forward(self, x) -> torch.Tensor:
        """
        Overview:
            Map a batch of scalar positions to ``[batch, dim]`` sinusoidal embeddings.
        Arguments:
            - x (:obj:`torch.Tensor`): 1-dim tensor of positions
        Return:
            - emb (:obj:`torch.Tensor`): output tensor of shape ``[len(x), dim]``
        """
        half = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(torch.arange(half, device=x.device) * -scale)
        angles = x[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=1)
class Residual(nn.Module):
    """
    Overview:
        Wrap a module with a skip connection: ``out = fn(x, ...) + x``.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, fn):
        """
        Overview:
            Initialization of Residual class
        Arguments:
            - fn (:obj:`nn.Module`): the wrapped residual function
        """
        super().__init__()
        self.fn = fn

    def forward(self, x, *arg, **kwargs):
        """
        Overview:
            Apply the wrapped function and add the input back.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        skip = x
        return skip + self.fn(x, *arg, **kwargs)
class LayerNorm(nn.Module):
    """
    Overview:
        LayerNorm over dim = 1 (the channel axis), because Temporal input x is
        ``[batch, dim, horizon]``.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, eps=1e-5) -> None:
        """
        Overview:
            Initialization of LayerNorm class
        Arguments:
            - dim (:obj:`int`): dimension of input
            - eps (:obj:`float`): numerical-stability epsilon added to the variance
        """
        super().__init__()
        self.eps = eps
        # Learnable per-channel scale (g) and shift (b), broadcast over batch and horizon.
        self.g = nn.Parameter(torch.ones(1, dim, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1))

    def forward(self, x):
        """
        Overview:
            Normalize x over the channel axis, then apply the learnable affine transform.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor of shape ``[batch, dim, horizon]``
        """
        # Bug fix: removed a leftover debug print of x.shape that spammed stdout on
        # every forward pass.
        var = torch.var(x, dim=1, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
class PreNorm(nn.Module):
    """
    Overview:
        Apply LayerNorm (over dim = 1, the channel axis of ``[batch, dim, horizon]``
        inputs) before running the wrapped module.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(self, dim, fn) -> None:
        """
        Overview:
            Initialization of PreNorm class
        Arguments:
            - dim (:obj:`int`): dimension of input
            - fn (:obj:`nn.Module`): module to run on the normalized input
        """
        super().__init__()
        self.fn = fn
        self.norm = LayerNorm(dim)

    def forward(self, x):
        """
        Overview:
            Normalize, then delegate to the wrapped module.
        Arguments:
            - x (:obj:`torch.Tensor`): input tensor
        """
        return self.fn(self.norm(x))
class LinearAttention(nn.Module): | |
""" | |
Overview: | |
Linear Attention head | |
Interfaces: | |
``__init__``, ``forward`` | |
""" | |
def __init__(self, dim, heads=4, dim_head=32) -> None: | |
""" | |
Overview: | |
Initialization of LinearAttention class | |
Arguments: | |
- dim (:obj:`int`): dimension of input | |
- heads (:obj:`int`): heads of attention | |
- dim_head (:obj:`int`): dim of head | |
""" | |
super().__init__() | |
self.scale = dim_head ** -0.5 | |
self.heads = heads | |
hidden_dim = dim_head * heads | |
self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False) | |
self.to_out = nn.Conv1d(hidden_dim, dim, 1) | |
def forward(self, x): | |
""" | |
Overview: | |
compute LinearAttention | |
Arguments: | |
- x (:obj:`torch.Tensor`): input tensor | |
""" | |
qkv = self.to_qkv(x).chunk(3, dim=1) | |
q, k, v = map(lambda t: t.reshape(t.shape[0], self.heads, -1, t.shape[-1]), qkv) | |
q = q * self.scale | |
k = k.softmax(dim=-1) | |
context = torch.einsum('b h d n, b h e n -> b h d e', k, v) | |
out = torch.einsum('b h d e, b h d n -> b h e n', context, q) | |
out = out.reshape(out.shape[0], -1, out.shape[-1]) | |
return self.to_out(out) | |
class ResidualTemporalBlock(nn.Module):
    """
    Overview:
        Temporal residual block: two diffusion conv layers with a time-embedding
        injection in between, plus a skip connection.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self, in_channels: int, out_channels: int, embed_dim: int, kernel_size: int = 5, mish: bool = True
    ) -> None:
        """
        Overview:
            Initialization of ResidualTemporalBlock class
        Arguments:
            - in_channels (:obj:'int'): dim of in_channels
            - out_channels (:obj:'int'): dim of out_channels
            - embed_dim (:obj:'int'): dim of the time embedding
            - kernel_size (:obj:'int'): kernel_size of conv1d
            - mish (:obj:'bool'): use Mish activation when True, SiLU otherwise
        """
        super().__init__()
        act = nn.Mish() if mish else nn.SiLU()
        self.blocks = nn.ModuleList(
            [
                DiffusionConv1d(in_channels, out_channels, kernel_size, kernel_size // 2, act),
                DiffusionConv1d(out_channels, out_channels, kernel_size, kernel_size // 2, act),
            ]
        )
        # Projects the time embedding into per-channel biases.
        self.time_mlp = nn.Sequential(
            act,
            nn.Linear(embed_dim, out_channels),
        )
        # 1x1 conv on the skip path only when the channel count changes.
        if in_channels == out_channels:
            self.residual_conv = nn.Identity()
        else:
            self.residual_conv = nn.Conv1d(in_channels, out_channels, 1)

    def forward(self, x, t):
        """
        Overview:
            Run both conv blocks, adding the time-embedding bias after the first,
            then add the (possibly projected) input back.
        Arguments:
            - x (:obj:'tensor'): input tensor ``[batch, in_channels, horizon]``
            - t (:obj:'tensor'): time embedding ``[batch, embed_dim]``
        """
        hidden = self.blocks[0](x) + self.time_mlp(t).unsqueeze(-1)
        hidden = self.blocks[1](hidden)
        return hidden + self.residual_conv(x)
class DiffusionUNet1d(nn.Module):
    """
    Overview:
        Diffusion unet for 1d vector data
    Interfaces:
        ``__init__``, ``forward``, ``get_pred``
    """

    def __init__(
            self,
            transition_dim: int,
            dim: int = 32,
            dim_mults: SequenceType = [1, 2, 4, 8],
            returns_condition: bool = False,
            condition_dropout: float = 0.1,
            calc_energy: bool = False,
            kernel_size: int = 5,
            attention: bool = False,
    ) -> None:
        """
        Overview:
            Initialization of DiffusionUNet1d class
        Arguments:
            - transition_dim (:obj:'int'): dim of transition, it is obs_dim + action_dim
            - dim (:obj:'int'): dim of layer
            - dim_mults (:obj:'SequenceType'): mults of dim
            - returns_condition (:obj:'bool'): whether use return as a condition
            - condition_dropout (:obj:'float'): dropout of returns condition
            - calc_energy (:obj:'bool'): whether use calc_energy
            - kernel_size (:obj:'int'): kernel_size of conv1d
            - attention (:obj:'bool'): whether use attention
        """
        super().__init__()
        # Channel sizes per resolution: transition_dim -> dim * dim_mults[0] -> ...
        dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        if calc_energy:
            mish = False
            act = nn.SiLU()
        else:
            mish = True
            act = nn.Mish()
        self.time_dim = dim
        self.returns_dim = dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            act,
            nn.Linear(dim * 4, dim),
        )
        self.returns_condition = returns_condition
        self.condition_dropout = condition_dropout
        # NOTE(review): attribute keeps its historical (misspelled) name 'cale_energy'
        # so any external readers of this flag keep working; it stores calc_energy.
        self.cale_energy = calc_energy
        if self.returns_condition:
            self.returns_mlp = nn.Sequential(
                nn.Linear(1, dim),
                act,
                nn.Linear(dim, dim * 4),
                act,
                nn.Linear(dim * 4, dim),
            )
            # Bernoulli mask used to randomly drop the returns embedding during training.
            self.mask_dist = torch.distributions.Bernoulli(probs=1 - self.condition_dropout)
            embed_dim = 2 * dim
        else:
            embed_dim = dim
        self.downs = nn.ModuleList([])
        self.ups = nn.ModuleList([])
        num_resolution = len(in_out)
        # Each stage stores exactly 4 modules: two residual blocks, an (optional)
        # attention layer and a resampling conv. forward() and get_pred() rely on
        # this 4-tuple layout when unpacking.
        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolution - 1)
            self.downs.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_in, dim_out, embed_dim, kernel_size, mish=mish),
                        ResidualTemporalBlock(dim_out, dim_out, embed_dim, kernel_size, mish=mish),
                        Residual(PreNorm(dim_out, LinearAttention(dim_out))) if attention else nn.Identity(),
                        nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity()
                    ]
                )
            )
        mid_dim = dims[-1]
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
        self.mid_atten = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if attention else nn.Identity()
        self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolution - 1)
            self.ups.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim, kernel_size, mish=mish),
                        ResidualTemporalBlock(dim_in, dim_in, embed_dim, kernel_size, mish=mish),
                        Residual(PreNorm(dim_in, LinearAttention(dim_in))) if attention else nn.Identity(),
                        nn.ConvTranspose1d(dim_in, dim_in, 4, 2, 1) if not is_last else nn.Identity()
                    ]
                )
            )
        self.final_conv = nn.Sequential(
            DiffusionConv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, activation=act),
            nn.Conv1d(dim, transition_dim, 1),
        )

    def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False):
        """
        Overview:
            compute diffusion unet forward
        Arguments:
            - x (:obj:'tensor'): noise trajectory
            - cond (:obj:'tuple'): [ (time, state), ... ] state is init state of env, time = 0;
              unused in this method (conditioning is applied by the caller)
            - time (:obj:'int'): timestep of diffusion step
            - returns (:obj:'tensor'): condition returns of trajectory, returns is normal return
            - use_dropout (:obj:'bool'): Whether use returns condition mask
            - force_dropout (:obj:'bool'): Whether use returns condition
        """
        if self.cale_energy:
            # Keep a handle on the input so the energy gradient can be taken w.r.t. it.
            x_inp = x
        # [batch, horizon, transition ] -> [batch, transition , horizon]
        x = x.transpose(1, 2)
        t = self.time_mlp(time)
        if self.returns_condition:
            assert returns is not None
            returns_embed = self.returns_mlp(returns)
            if use_dropout:
                # Randomly mask the returns embedding (classifier-free-guidance style).
                mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
                returns_embed = mask * returns_embed
            if force_dropout:
                returns_embed = 0 * returns_embed
            t = torch.cat([t, returns_embed], dim=-1)
        h = []
        for resnet, resnet2, atten, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_atten(x)
        x = self.mid_block2(x, t)
        for resnet, resnet2, atten, upsample in self.ups:
            # Concatenate the matching skip feature along the channel axis.
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            x = upsample(x)
        x = self.final_conv(x)
        # [batch, transition , horizon] -> [batch, horizon, transition ]
        x = x.transpose(1, 2)
        if self.cale_energy:
            # Energy function: return the gradient of the reconstruction energy
            # with respect to the input trajectory.
            energy = ((x - x_inp) ** 2).mean()
            grad = torch.autograd.grad(outputs=energy, inputs=x_inp, create_graph=True)
            return grad[0]
        else:
            return x

    def get_pred(self, x, cond, time, returns: torch.Tensor = None, use_dropout: bool = True,
                 force_dropout: bool = False):
        """
        Overview:
            Same U-Net pass as ``forward`` but always returns the denoised trajectory
            directly, skipping the energy-gradient branch.
        Arguments:
            - x (:obj:'tensor'): noise trajectory
            - cond (:obj:'tuple'): [ (time, state), ... ] state is init state of env, time = 0;
              unused in this method (conditioning is applied by the caller)
            - time (:obj:'int'): timestep of diffusion step
            - returns (:obj:'tensor'): condition returns of trajectory, returns is normal return
            - use_dropout (:obj:'bool'): Whether use returns condition mask
            - force_dropout (:obj:'bool'): Whether use returns condition
        """
        # [batch, horizon, transition ] -> [batch, transition , horizon]
        x = x.transpose(1, 2)
        t = self.time_mlp(time)
        if self.returns_condition:
            assert returns is not None
            returns_embed = self.returns_mlp(returns)
            if use_dropout:
                mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
                returns_embed = mask * returns_embed
            if force_dropout:
                returns_embed = 0 * returns_embed
            t = torch.cat([t, returns_embed], dim=-1)
        h = []
        # Bug fix: every stage in self.downs/self.ups holds 4 modules
        # (resnet, resnet2, attention, resample); the previous 3-way unpacking
        # raised ValueError on every call. Attention is now applied as well,
        # mirroring forward().
        for resnet, resnet2, atten, downsample in self.downs:
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            h.append(x)
            x = downsample(x)
        x = self.mid_block1(x, t)
        x = self.mid_atten(x)
        x = self.mid_block2(x, t)
        for resnet, resnet2, atten, upsample in self.ups:
            x = torch.cat((x, h.pop()), dim=1)
            x = resnet(x, t)
            x = resnet2(x, t)
            x = atten(x)
            x = upsample(x)
        x = self.final_conv(x)
        # [batch, transition , horizon] -> [batch, horizon, transition ]
        x = x.transpose(1, 2)
        return x
class TemporalValue(nn.Module):
    """
    Overview:
        Temporal convolutional network that maps a trajectory (plus a diffusion
        timestep) to a scalar value.
    Interfaces:
        ``__init__``, ``forward``
    """

    def __init__(
            self,
            horizon: int,
            transition_dim: int,
            dim: int = 32,
            time_dim: int = None,
            out_dim: int = 1,
            kernel_size: int = 5,
            dim_mults: SequenceType = [1, 2, 4, 8],
    ) -> None:
        """
        Overview:
            Initialization of TemporalValue class
        Arguments:
            - horizon (:obj:'int'): horizon of trajectory
            - transition_dim (:obj:'int'): dim of transition, it is obs_dim + action_dim
            - dim (:obj:'int'): dim of layer
            - time_dim (:obj:'int'): dim of time
            - out_dim (:obj:'int'): dim of output
            - kernel_size (:obj:'int'): kernel_size of conv1d
            - dim_mults (:obj:'SequenceType'): mults of dim
        """
        super().__init__()
        channel_sizes = [transition_dim] + [dim * m for m in dim_mults]
        stage_io = list(zip(channel_sizes[:-1], channel_sizes[1:]))
        time_dim = time_dim or dim
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(dim),
            nn.Linear(dim, dim * 4),
            nn.Mish(),
            nn.Linear(dim * 4, dim),
        )
        self.blocks = nn.ModuleList([])
        # Each stage: two residual blocks followed by a stride-2 conv, which halves
        # the horizon; track the shrinking horizon for the final flatten size.
        for c_in, c_out in stage_io:
            self.blocks.append(
                nn.ModuleList(
                    [
                        ResidualTemporalBlock(c_in, c_out, kernel_size=kernel_size, embed_dim=time_dim),
                        ResidualTemporalBlock(c_out, c_out, kernel_size=kernel_size, embed_dim=time_dim),
                        nn.Conv1d(c_out, c_out, 3, 2, 1)
                    ]
                )
            )
            horizon = horizon // 2
        mid_dim = channel_sizes[-1]
        mid_dim_2 = mid_dim // 2
        mid_dim_3 = mid_dim // 4
        self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim_2, kernel_size=kernel_size, embed_dim=time_dim)
        self.mid_down1 = nn.Conv1d(mid_dim_2, mid_dim_2, 3, 2, 1)
        horizon = horizon // 2
        self.mid_block2 = ResidualTemporalBlock(mid_dim_2, mid_dim_3, kernel_size=kernel_size, embed_dim=time_dim)
        self.mid_down2 = nn.Conv1d(mid_dim_3, mid_dim_3, 3, 2, 1)
        horizon = horizon // 2
        # Flattened feature size entering the value head (at least one timestep).
        fc_dim = mid_dim_3 * max(horizon, 1)
        self.final_block = nn.Sequential(
            nn.Linear(fc_dim + time_dim, fc_dim // 2),
            nn.Mish(),
            nn.Linear(fc_dim // 2, out_dim),
        )

    def forward(self, x, cond, time, *args):
        """
        Overview:
            Compute the scalar value of a trajectory.
        Arguments:
            - x (:obj:'tensor'): noise trajectory
            - cond (:obj:'tuple'): [ (time, state), ... ] state is init state of env, time = 0;
              unused in this method
            - time (:obj:'int'): timestep of diffusion step
        """
        # [batch, horizon, transition ] -> [batch, transition , horizon]
        x = x.transpose(1, 2)
        t = self.time_mlp(time)
        for resnet, resnet2, downsample in self.blocks:
            x = downsample(resnet2(resnet(x, t), t))
        x = self.mid_down1(self.mid_block1(x, t))
        x = self.mid_down2(self.mid_block2(x, t))
        # Flatten channels x horizon and append the time embedding for the value head.
        flat = x.view(len(x), -1)
        return self.final_block(torch.cat([flat, t], dim=-1))