"""
Vanilla DFO and EBM are adapted from https://github.com/kevinzakka/ibc.
MCMC is adapted from https://github.com/google-research/ibc.
"""
from typing import Callable, Tuple
from functools import wraps
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from abc import ABC, abstractmethod
from ding.utils import MODEL_REGISTRY, STOCHASTIC_OPTIMIZER_REGISTRY
from ding.torch_utils import unsqueeze_repeat
from ding.model.wrapper import IModelWrapper
from ding.model.common import RegressionHead
def create_stochastic_optimizer(device: str, stochastic_optimizer_config: dict):
    """
    Overview:
        Create a stochastic optimizer from its registered ``type`` key.
    Arguments:
        - device (:obj:`str`): Device the optimizer should operate on, e.g. ``'cpu'`` or ``'cuda'``.
        - stochastic_optimizer_config (:obj:`dict`): Stochastic optimizer config. Must contain a \
            ``type`` key naming a registered optimizer; all remaining items are forwarded as kwargs.
    Returns:
        - optimizer (:obj:`StochasticOptimizer`): The constructed stochastic optimizer.
    """
    # Work on a shallow copy so the caller's dict is left intact: the previous
    # implementation popped "type" from the caller's config, which made a second
    # call with the same config dict raise KeyError.
    config = dict(stochastic_optimizer_config)
    return STOCHASTIC_OPTIMIZER_REGISTRY.build(config.pop("type"), device=device, **config)
def no_ebm_grad():
    """
    Overview:
        Decorator factory that freezes the energy based model's parameters around the wrapped call.
        The wrapped function must receive the EBM as its *last* positional argument.
    Returns:
        - ebm_disable_grad_wrapper (:obj:`Callable`): The actual decorator.
    """

    def ebm_disable_grad_wrapper(func: Callable):

        # ``functools.wraps`` preserves the wrapped function's metadata
        # (it was imported at the top of the file but never used before).
        @wraps(func)
        def wrapper(*args, **kwargs):
            ebm = args[-1]
            assert isinstance(ebm, (IModelWrapper, nn.Module)),\
                'Make sure ebm is the last positional arguments.'
            # Temporarily disable gradients on the EBM so the wrapped call cannot
            # accumulate parameter gradients, then restore them afterwards.
            ebm.requires_grad_(False)
            result = func(*args, **kwargs)
            ebm.requires_grad_(True)
            return result

        return wrapper

    return ebm_disable_grad_wrapper
class StochasticOptimizer(ABC):
    """
    Overview:
        Base class for stochastic optimizers.
    Interface:
        ``__init__``, ``_sample``, ``_get_best_action_sample``, ``set_action_bounds``, ``sample``, ``infer``
    """

    def _sample(self, obs: torch.Tensor, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Drawing action samples from the uniform random distribution \
            and tiling observations to the same shape as action samples.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation.
            - num_samples (:obj:`int`): The number of negative samples.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Observations tiled.
            - action (:obj:`torch.Tensor`): Action sampled.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - num_samples (:obj:`int`): :math:`N`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> opt = StochasticOptimizer()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> tiled_obs, action = opt._sample(obs, 8)
        """
        size = (obs.shape[0], num_samples, self.action_bounds.shape[1])
        low, high = self.action_bounds[0, :], self.action_bounds[1, :]
        # Uniform samples inside the per-dimension action box.
        action_samples = low + (high - low) * torch.rand(size).to(self.device)
        tiled_obs = unsqueeze_repeat(obs, num_samples, 1)
        return tiled_obs, action_samples

    def _get_best_action_sample(self, obs: torch.Tensor, action_samples: torch.Tensor, ebm: nn.Module):
        """
        Overview:
            Return one action for each batch with highest probability (lowest energy).
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation.
            - action_samples (:obj:`torch.Tensor`): Action from uniform distributions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Best action.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> action_samples = torch.randn(2, 8, 5)
            >>> ebm = EBM(4, 5)
            >>> opt = StochasticOptimizer()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt._get_best_action_sample(obs, action_samples, ebm)
        """
        # NOTE: ``self`` was previously missing from this signature although the
        # method is invoked as ``self._get_best_action_sample(...)`` by subclasses,
        # which raised a TypeError at call time.
        # (B, N)
        energies = ebm.forward(obs, action_samples)
        probs = F.softmax(-1.0 * energies, dim=-1)
        # (B, ): argmax of softmax(-E) == argmin of the energies.
        best_idxs = probs.argmax(dim=-1)
        return action_samples[torch.arange(action_samples.size(0)), best_idxs]

    def set_action_bounds(self, action_bounds: np.ndarray):
        """
        Overview:
            Set action bounds calculated from the dataset statistics.
        Arguments:
            - action_bounds (:obj:`np.ndarray`): Array of shape (2, A), \
                where action_bounds[0] is lower bound and action_bounds[1] is upper bound.
        Shapes:
            - action_bounds (:obj:`np.ndarray`): :math:`(2, A)`.
        Examples:
            >>> opt = StochasticOptimizer()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
        """
        self.action_bounds = torch.as_tensor(action_bounds, dtype=torch.float32).to(self.device)

    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Create tiled observations and sample counter-negatives for InfoNCE loss.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
            - action (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.

        .. note:: In the case of derivative-free optimization, this function will simply call _sample.
        """
        raise NotImplementedError

    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Best actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        """
        raise NotImplementedError
class DFO(StochasticOptimizer):
    """
    Overview:
        Derivative-Free Optimizer in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``init``, ``sample``, ``infer``
    """

    def __init__(
        self,
        noise_scale: float = 0.33,
        noise_shrink: float = 0.5,
        iters: int = 3,
        train_samples: int = 8,
        inference_samples: int = 16384,
        device: str = 'cpu',
    ):
        """
        Overview:
            Initialize the Derivative-Free Optimizer
        Arguments:
            - noise_scale (:obj:`float`): Initial noise scale.
            - noise_shrink (:obj:`float`): Noise scale shrink rate.
            - iters (:obj:`int`): Number of iterations.
            - train_samples (:obj:`int`): Number of samples for training.
            - inference_samples (:obj:`int`): Number of samples for inference.
            - device (:obj:`str`): Device.
        """
        self.device = device
        self.noise_scale = noise_scale
        self.noise_shrink = noise_shrink
        self.iters = iters
        self.train_samples = train_samples
        self.inference_samples = inference_samples
        # Filled in later via ``set_action_bounds``.
        self.action_bounds = None

    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Drawing action samples from the uniform random distribution \
            and tiling observations to the same shape as action samples.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model (unused here).
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observation.
            - action_samples (:obj:`torch.Tensor`): Action samples.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = DFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> tiled_obs, action_samples = opt.sample(obs, ebm)
        """
        # DFO negatives are plain uniform samples; the EBM is not consulted.
        return self._sample(obs, self.train_samples)

    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation
            via iterative resampling with shrinking exploration noise.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = DFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt.infer(obs, ebm)
        """
        sigma = self.noise_scale
        # Tile observations and draw uniform candidates: (B, N, O) / (B, N, A).
        obs, candidates = self._sample(obs, self.inference_samples)
        for _ in range(self.iters):
            # Candidate energies: (B, N).
            energies = ebm.forward(obs, candidates)
            weights = F.softmax(-1.0 * energies, dim=-1)
            # Resample candidates with replacement, favouring low-energy ones.
            picked = torch.multinomial(weights, self.inference_samples, replacement=True)
            rows = torch.arange(candidates.size(0)).unsqueeze(-1)
            candidates = candidates[rows, picked]
            # Perturb, then project back into the valid action box.
            candidates = candidates + torch.randn_like(candidates) * sigma
            candidates = candidates.clamp(min=self.action_bounds[0, :], max=self.action_bounds[1, :])
            sigma *= self.noise_shrink
        # Hand back the single lowest-energy candidate per batch element.
        return self._get_best_action_sample(obs, candidates, ebm)
class AutoRegressiveDFO(DFO):
    """
    Overview:
        AutoRegressive Derivative-Free Optimizer in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``infer``
    """

    def __init__(
        self,
        noise_scale: float = 0.33,
        noise_shrink: float = 0.5,
        iters: int = 3,
        train_samples: int = 8,
        inference_samples: int = 4096,
        device: str = 'cpu',
    ):
        """
        Overview:
            Initialize the AutoRegressive Derivative-Free Optimizer
        Arguments:
            - noise_scale (:obj:`float`): Initial noise scale.
            - noise_shrink (:obj:`float`): Noise scale shrink rate.
            - iters (:obj:`int`): Number of iterations.
            - train_samples (:obj:`int`): Number of samples for training.
            - inference_samples (:obj:`int`): Number of samples for inference.
            - device (:obj:`str`): Device.
        """
        # Same state as DFO; only the default ``inference_samples`` differs.
        super().__init__(noise_scale, noise_shrink, iters, train_samples, inference_samples, device)

    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation.
            Unlike ``DFO.infer``, candidates are refined one action dimension at a
            time using the per-dimension energies of an autoregressive EBM
            (whose forward returns shape (B, N, A), e.g. ``AutoregressiveEBM``).
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = AutoRegressiveDFO()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt.infer(obs, ebm)
        """
        noise_scale = self.noise_scale
        # (B, N, O), (B, N, A)
        obs, action_samples = self._sample(obs, self.inference_samples)
        for i in range(self.iters):
            # j: action_dim index
            for j in range(action_samples.shape[-1]):
                # Energies of dimension j only: (B, N).
                energies = ebm.forward(obs, action_samples)[..., j]
                probs = F.softmax(-1.0 * energies, dim=-1)
                # Resample with replacement; note that the *whole* candidate rows
                # are re-drawn based on dimension j's energies.
                idxs = torch.multinomial(probs, self.inference_samples, replacement=True)
                action_samples = action_samples[torch.arange(action_samples.size(0)).unsqueeze(-1), idxs]
                # Add noise to dimension j only and clip to its target bounds (in place).
                action_samples[..., j] = action_samples[..., j] + torch.randn_like(action_samples[..., j]) * noise_scale
                action_samples[..., j] = action_samples[..., j].clamp(
                    min=self.action_bounds[0, j], max=self.action_bounds[1, j]
                )
            noise_scale *= self.noise_shrink
        # Final selection uses the last dimension's energies: (B, N).
        energies = ebm.forward(obs, action_samples)[..., -1]
        probs = F.softmax(-1.0 * energies, dim=-1)
        # (B, )
        best_idxs = probs.argmax(dim=-1)
        return action_samples[torch.arange(action_samples.size(0)), best_idxs]
class MCMC(StochasticOptimizer):
    """
    Overview:
        MCMC method as stochastic optimizers in paper Implicit Behavioral Cloning.
        https://arxiv.org/abs/2109.00137
    Interface:
        ``__init__``, ``sample``, ``infer``, ``grad_penalty``
    """

    class BaseScheduler(ABC):
        """
        Overview:
            Base class for learning rate scheduler.
        Interface:
            ``get_rate``
        """

        @abstractmethod
        def get_rate(self, index):
            """
            Overview:
                Abstract method for getting learning rate.
            """
            raise NotImplementedError

    class ExponentialScheduler(BaseScheduler):
        """
        Overview:
            Exponential learning rate schedule for Langevin sampler.
        Interface:
            ``__init__``, ``get_rate``
        """

        def __init__(self, init, decay):
            """
            Overview:
                Initialize the ExponentialScheduler.
            Arguments:
                - init (:obj:`float`): Initial learning rate.
                - decay (:obj:`float`): Decay rate applied once per ``get_rate`` call.
            """
            self._decay = decay
            self._latest_lr = init

        def get_rate(self, index):
            """
            Overview:
                Get learning rate. Assumes calling sequentially.
            Arguments:
                - index (:obj:`int`): Current iteration (unused, kept for interface compatibility).
            """
            del index
            lr = self._latest_lr
            self._latest_lr *= self._decay
            return lr

    class PolynomialScheduler(BaseScheduler):
        """
        Overview:
            Polynomial learning rate schedule for Langevin sampler.
        Interface:
            ``__init__``, ``get_rate``
        """

        def __init__(self, init, final, power, num_steps):
            """
            Overview:
                Initialize the PolynomialScheduler.
            Arguments:
                - init (:obj:`float`): Initial learning rate.
                - final (:obj:`float`): Final learning rate.
                - power (:obj:`float`): Power of polynomial.
                - num_steps (:obj:`int`): Number of steps.
            """
            self._init = init
            self._final = final
            self._power = power
            self._num_steps = num_steps

        def get_rate(self, index):
            """
            Overview:
                Get learning rate for index.
            Arguments:
                - index (:obj:`int`): Current iteration; ``-1`` returns the initial rate.
            """
            if index == -1:
                return self._init
            return (
                (self._init - self._final) * ((1 - (float(index) / float(self._num_steps - 1))) ** (self._power))
            ) + self._final

    def __init__(
        self,
        iters: int = 100,
        use_langevin_negative_samples: bool = True,
        train_samples: int = 8,
        inference_samples: int = 512,
        stepsize_scheduler: dict = None,
        optimize_again: bool = True,
        again_stepsize_scheduler: dict = None,
        device: str = 'cpu',
        # langevin_step
        noise_scale: float = 0.5,
        grad_clip=None,
        delta_action_clip: float = 0.5,
        add_grad_penalty: bool = True,
        grad_norm_type: str = 'inf',
        grad_margin: float = 1.0,
        grad_loss_weight: float = 1.0,
        **kwargs,
    ):
        """
        Overview:
            Initialize the MCMC.
        Arguments:
            - iters (:obj:`int`): Number of iterations.
            - use_langevin_negative_samples (:obj:`bool`): Whether to use Langevin sampler.
            - train_samples (:obj:`int`): Number of samples for training.
            - inference_samples (:obj:`int`): Number of samples for inference.
            - stepsize_scheduler (:obj:`dict`): Step size scheduler config for Langevin sampler. \
                Defaults to ``dict(init=0.5, final=1e-5, power=2.0)``; ``num_steps`` is filled in later.
            - optimize_again (:obj:`bool`): Whether to run a second optimization.
            - again_stepsize_scheduler (:obj:`dict`): Step size scheduler config for the second \
                optimization. Defaults to ``dict(init=1e-5, final=1e-5, power=2.0)``.
            - device (:obj:`str`): Device.
            - noise_scale (:obj:`float`): Initial noise scale.
            - grad_clip (:obj:`float`): Gradient clip.
            - delta_action_clip (:obj:`float`): Action clip.
            - add_grad_penalty (:obj:`bool`): Whether to add gradient penalty.
            - grad_norm_type (:obj:`str`): Gradient norm type ('1', '2' or 'inf').
            - grad_margin (:obj:`float`): Gradient margin.
            - grad_loss_weight (:obj:`float`): Gradient loss weight.
        """
        # ``None`` sentinels avoid the mutable-default-argument pitfall: these
        # dicts are later mutated (``['num_steps'] = self.iters``), which would
        # otherwise corrupt the shared default across all MCMC instances.
        if stepsize_scheduler is None:
            stepsize_scheduler = dict(init=0.5, final=1e-5, power=2.0)
        if again_stepsize_scheduler is None:
            again_stepsize_scheduler = dict(init=1e-5, final=1e-5, power=2.0)
        self.iters = iters
        self.use_langevin_negative_samples = use_langevin_negative_samples
        self.train_samples = train_samples
        self.inference_samples = inference_samples
        self.stepsize_scheduler = stepsize_scheduler
        self.optimize_again = optimize_again
        self.again_stepsize_scheduler = again_stepsize_scheduler
        self.device = device
        self.noise_scale = noise_scale
        self.grad_clip = grad_clip
        self.delta_action_clip = delta_action_clip
        self.add_grad_penalty = add_grad_penalty
        self.grad_norm_type = grad_norm_type
        self.grad_margin = grad_margin
        self.grad_loss_weight = grad_loss_weight

    @staticmethod
    def _gradient_wrt_act(
        obs: torch.Tensor,
        action: torch.Tensor,
        ebm: nn.Module,
        create_graph: bool = False,
    ) -> torch.Tensor:
        """
        Overview:
            Calculate gradient w.r.t action. This is a static method (it was always
            invoked as ``MCMC._gradient_wrt_act``); the decorator makes that explicit.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
            - create_graph (:obj:`bool`): Whether to create graph.
        Returns:
            - grad (:obj:`torch.Tensor`): Gradient w.r.t action.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
            - grad (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        action.requires_grad_(True)
        energy = ebm.forward(obs, action).sum()
        # `create_graph` set to `True` when second order derivative
        # is needed i.e, d(de/da)/d_param
        grad = torch.autograd.grad(energy, action, create_graph=create_graph)[0]
        action.requires_grad_(False)
        return grad

    def grad_penalty(self, obs: torch.Tensor, action: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Calculate gradient penalty.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - loss (:obj:`torch.Tensor`): Gradient penalty (scalar), or ``0.`` when disabled.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N+1, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N+1, A)`.
        """
        if not self.add_grad_penalty:
            return 0.
        # (B, N+1, A), this gradient is differentiable w.r.t model parameters
        de_dact = MCMC._gradient_wrt_act(obs, action, ebm, create_graph=True)

        def compute_grad_norm(grad_norm_type, de_dact) -> torch.Tensor:
            # de_dact: (B, N+1, A) -> (B, N+1)
            grad_norm_type_to_ord = {
                '1': 1,
                '2': 2,
                'inf': float('inf'),
            }
            # ``norm_ord`` avoids shadowing the builtin ``ord``.
            norm_ord = grad_norm_type_to_ord[grad_norm_type]
            return torch.linalg.norm(de_dact, norm_ord, dim=-1)

        # (B, N+1): hinge on the gradient norm margin, then square.
        grad_norms = compute_grad_norm(self.grad_norm_type, de_dact)
        grad_norms = grad_norms - self.grad_margin
        grad_norms = grad_norms.clamp(min=0., max=1e10)
        grad_norms = grad_norms.pow(2)

        grad_loss = grad_norms.mean()
        return grad_loss * self.grad_loss_weight

    # can not use @torch.no_grad() during the inference
    # because we need to calculate gradient w.r.t inputs as MCMC updates.
    def _langevin_step(self, obs: torch.Tensor, action: torch.Tensor, stepsize: float, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Run one langevin MCMC step.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - stepsize (:obj:`float`): Step size.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - action (:obj:`torch.Tensor`): Updated actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        l_lambda = 1.0
        de_dact = MCMC._gradient_wrt_act(obs, action, ebm)

        if self.grad_clip:
            de_dact = de_dact.clamp(min=-self.grad_clip, max=self.grad_clip)

        gradient_scale = 0.5
        # Langevin dynamics: descend the energy gradient plus injected Gaussian noise.
        de_dact = (gradient_scale * l_lambda * de_dact + torch.randn_like(de_dact) * l_lambda * self.noise_scale)

        delta_action = stepsize * de_dact
        # Per-dimension cap on the update, proportional to half the action range.
        delta_action_clip = self.delta_action_clip * 0.5 * (self.action_bounds[1] - self.action_bounds[0])
        delta_action = delta_action.clamp(min=-delta_action_clip, max=delta_action_clip)

        action = action - delta_action
        action = action.clamp(min=self.action_bounds[0], max=self.action_bounds[1])

        return action

    def _langevin_action_given_obs(
        self,
        obs: torch.Tensor,
        action: torch.Tensor,
        ebm: nn.Module,
        scheduler: BaseScheduler = None
    ) -> torch.Tensor:
        """
        Overview:
            Run langevin MCMC for `self.iters` steps.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - action (:obj:`torch.Tensor`): Actions.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
            - scheduler (:obj:`BaseScheduler`): Learning rate scheduler; a \
                ``PolynomialScheduler`` built from ``self.stepsize_scheduler`` is used by default.
        Returns:
            - action (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        """
        if not scheduler:
            self.stepsize_scheduler['num_steps'] = self.iters
            scheduler = MCMC.PolynomialScheduler(**self.stepsize_scheduler)
        stepsize = scheduler.get_rate(-1)
        for i in range(self.iters):
            action = self._langevin_step(obs, action, stepsize, ebm)
            stepsize = scheduler.get_rate(i)
        return action

    def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Overview:
            Create tiled observations and sample counter-negatives for InfoNCE loss.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
            - action_samples (:obj:`torch.Tensor`): Action samples.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
            - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> tiled_obs, action_samples = opt.sample(obs, ebm)
        """
        obs, uniform_action_samples = self._sample(obs, self.train_samples)
        if not self.use_langevin_negative_samples:
            return obs, uniform_action_samples
        # Refine the uniform negatives with Langevin dynamics for harder negatives.
        langevin_action_samples = self._langevin_action_given_obs(obs, uniform_action_samples, ebm)
        return obs, langevin_action_samples

    def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
        """
        Overview:
            Optimize for the best action conditioned on the current observation.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observations.
            - ebm (:obj:`torch.nn.Module`): Energy based model.
        Returns:
            - best_action_samples (:obj:`torch.Tensor`): Actions.
        Shapes:
            - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
            - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
        Examples:
            >>> obs = torch.randn(2, 4)
            >>> ebm = EBM(4, 5)
            >>> opt = MCMC()
            >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
            >>> best_action_samples = opt.infer(obs, ebm)
        """
        # (B, N, O), (B, N, A)
        obs, uniform_action_samples = self._sample(obs, self.inference_samples)
        action_samples = self._langevin_action_given_obs(
            obs,
            uniform_action_samples,
            ebm,
        )

        # Run a second optimization, a trick for more precise inference
        if self.optimize_again:
            self.again_stepsize_scheduler['num_steps'] = self.iters
            action_samples = self._langevin_action_given_obs(
                obs,
                action_samples,
                ebm,
                scheduler=MCMC.PolynomialScheduler(**self.again_stepsize_scheduler),
            )

        # action_samples: B, N, A
        return self._get_best_action_sample(obs, action_samples, ebm)
class EBM(nn.Module):
    """
    Overview:
        Energy based model.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        obs_shape: int,
        action_shape: int,
        hidden_size: int = 512,
        hidden_layer_num: int = 4,
        **kwargs,
    ):
        """
        Overview:
            Initialize the EBM.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - action_shape (:obj:`int`): Action shape.
            - hidden_size (:obj:`int`): Hidden size.
            - hidden_layer_num (:obj:`int`): Number of hidden layers.
        """
        super().__init__()
        # The energy head consumes the concatenated (obs, action) vector.
        self.net = nn.Sequential(
            nn.Linear(obs_shape + action_shape, hidden_size),
            nn.ReLU(),
            RegressionHead(
                hidden_size,
                1,
                hidden_layer_num,
                final_tanh=False,
            ),
        )

    def forward(self, obs, action):
        """
        Overview:
            Forward computation graph of EBM.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
            - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
        Returns:
            - pred (:obj:`torch.Tensor`): Energy of shape (B, N).
        Examples:
            >>> obs = torch.randn(2, 3, 4)
            >>> action = torch.randn(2, 3, 5)
            >>> ebm = EBM(4, 5)
            >>> pred = ebm(obs, action)
        """
        joint = torch.cat([obs, action], dim=-1)
        # RegressionHead returns a dict; 'pred' carries the scalar energy per sample.
        return self.net(joint)['pred']
class AutoregressiveEBM(nn.Module):
    """
    Overview:
        Autoregressive energy based model: one EBM per action dimension, where the
        ``i``-th sub-model scores the first ``i + 1`` action dimensions.
    Interface:
        ``__init__``, ``forward``
    """

    def __init__(
        self,
        obs_shape: int,
        action_shape: int,
        hidden_size: int = 512,
        hidden_layer_num: int = 4,
    ):
        """
        Overview:
            Initialize the AutoregressiveEBM.
        Arguments:
            - obs_shape (:obj:`int`): Observation shape.
            - action_shape (:obj:`int`): Action shape.
            - hidden_size (:obj:`int`): Hidden size.
            - hidden_layer_num (:obj:`int`): Number of hidden layers.
        """
        super().__init__()
        self.ebm_list = nn.ModuleList()
        # Sub-model i sees action dims [0, i], hence action input size i + 1.
        for i in range(action_shape):
            self.ebm_list.append(EBM(obs_shape, i + 1, hidden_size, hidden_layer_num))

    def forward(self, obs, action):
        """
        Overview:
            Forward computation graph of AutoregressiveEBM.
        Arguments:
            - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
            - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
        Returns:
            - pred (:obj:`torch.Tensor`): Energy of shape (B, N, A), one energy per action dimension.
        Examples:
            >>> obs = torch.randn(2, 3, 4)
            >>> action = torch.randn(2, 3, 5)
            >>> arebm = AutoregressiveEBM(4, 5)
            >>> pred = arebm(obs, action)
        """
        output_list = []
        for i, ebm in enumerate(self.ebm_list):
            output_list.append(ebm(obs, action[..., :i + 1]))
        # ``dim`` is the canonical torch keyword; ``axis`` is only a numpy-compat alias.
        return torch.stack(output_list, dim=-1)