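# NOTE: the three lines that stood here ("Spaces:", "Sleeping", "Sleeping") look like
# web-page residue picked up when this file was copied; they are not part of the module.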
import copy

import numpy as np
import torch
from torch import nn

from ding.utils import WORLD_MODEL_REGISTRY, lists_to_dicts
from ding.utils.data import default_collate
from ding.model import ConvEncoder
from ding.world_model.base_world_model import WorldModel
from ding.world_model.model.networks import RSSM, ConvDecoder
from ding.torch_utils import to_device
from ding.torch_utils.network.dreamer import DenseHead
class DREAMERWorldModel(WorldModel, nn.Module):
    """World model built from a conv encoder, an RSSM and prediction heads.

    The model owns:

    * ``encoder``  - a ``ConvEncoder`` that embeds image observations,
    * ``dynamics`` - an ``RSSM`` that produces posterior/prior latent states,
    * ``heads``    - an ``nn.ModuleDict`` with an ``"image"`` decoder, a
      ``"reward"`` head and (when ``pred_discount`` is set) a ``"discount"``
      head,
    * ``optimizer`` - a single Adam optimizer over all parameters.

    NOTE: ``train`` and ``eval`` below shadow ``nn.Module.train(mode)`` and
    ``nn.Module.eval()`` with different signatures, so the usual
    ``module.train()`` / ``module.eval()`` mode switches are not available on
    this class.
    """

    config = dict(
        pretrain=100,
        train_freq=2,
        model=dict(
            state_size=None,
            action_size=None,
            model_lr=1e-4,
            reward_size=1,
            hidden_size=200,
            batch_size=256,
            max_epochs_since_update=5,
            dyn_stoch=32,
            dyn_deter=512,
            dyn_hidden=512,
            dyn_input_layers=1,
            dyn_output_layers=1,
            dyn_rec_depth=1,
            dyn_shared=False,
            dyn_discrete=32,
            act='SiLU',
            norm='LayerNorm',
            grad_heads=['image', 'reward', 'discount'],
            units=512,
            reward_layers=2,
            discount_layers=2,
            value_layers=2,
            actor_layers=2,
            cnn_depth=32,
            encoder_kernels=[4, 4, 4, 4],
            decoder_kernels=[4, 4, 4, 4],
            reward_head='twohot_symlog',
            kl_lscale=0.1,
            kl_rscale=0.5,
            kl_free=1.0,
            kl_forward=False,
            pred_discount=True,
            dyn_mean_act='none',
            dyn_std_act='sigmoid2',
            dyn_temp_post=True,
            dyn_min_std=0.1,
            dyn_cell='gru_layer_norm',
            unimix_ratio=0.01,
            device='cuda' if torch.cuda.is_available() else 'cpu',
        ),
    )

    def __init__(self, cfg, env, tb_logger):
        """Build the encoder, dynamics model, prediction heads and optimizer.

        Args:
            cfg: Configuration object; ``cfg.model`` must expose the keys
                listed under ``config['model']``.
            env: Forwarded unchanged to ``WorldModel.__init__``.
            tb_logger: Forwarded unchanged to ``WorldModel.__init__``; later
                read as ``self.tb_logger`` in ``train``.
        """
        WorldModel.__init__(self, cfg, env, tb_logger)
        nn.Module.__init__(self)

        self.pretrain_flag = True
        self._cfg = cfg.model
        # The configured ``act`` / ``norm`` names are replaced by the SiLU and
        # LayerNorm classes regardless of their configured value. Note that
        # this writes into ``cfg.model`` itself, not a copy.
        self._cfg.act = nn.modules.activation.SiLU  # nn.SiLU
        self._cfg.norm = nn.modules.normalization.LayerNorm  # nn.LayerNorm
        self.state_size = self._cfg.state_size
        self.action_size = self._cfg.action_size
        self.reward_size = self._cfg.reward_size
        self.hidden_size = self._cfg.hidden_size
        self.batch_size = self._cfg.batch_size

        # Early-stopping bookkeeping consumed by ``_save_best``. These were
        # previously never assigned, so ``_save_best`` failed on first use.
        self.max_epochs_since_update = getattr(
            self._cfg, 'max_epochs_since_update', self.config['model']['max_epochs_since_update']
        )
        self._epochs_since_update = 0

        self.encoder = ConvEncoder(
            self.state_size,
            hidden_size_list=[32, 64, 128, 256, 4096],  # to last layer 128?
            activation=torch.nn.SiLU(),
            kernel_size=self._cfg.encoder_kernels,
            layer_norm=True
        )
        # (state_size[1] // 2**n)**2 * cnn_depth * 2**(n - 1), n = number of encoder kernels.
        # Presumably state_size[1] is a spatial dimension of a square image - confirm.
        num_encoder_layers = len(self._cfg.encoder_kernels)
        self.embed_size = (
            (self.state_size[1] // 2 ** num_encoder_layers) ** 2 * self._cfg.cnn_depth *
            2 ** (num_encoder_layers - 1)
        )
        self.dynamics = RSSM(
            self._cfg.dyn_stoch,
            self._cfg.dyn_deter,
            self._cfg.dyn_hidden,
            self._cfg.dyn_input_layers,
            self._cfg.dyn_output_layers,
            self._cfg.dyn_rec_depth,
            self._cfg.dyn_shared,
            self._cfg.dyn_discrete,
            self._cfg.act,
            self._cfg.norm,
            self._cfg.dyn_mean_act,
            self._cfg.dyn_std_act,
            self._cfg.dyn_temp_post,
            self._cfg.dyn_min_std,
            self._cfg.dyn_cell,
            self._cfg.unimix_ratio,
            self._cfg.action_size,
            self.embed_size,
            self._cfg.device,
        )

        self.heads = nn.ModuleDict()
        # Width of the feature vector fed to every head.
        if self._cfg.dyn_discrete:
            feat_size = self._cfg.dyn_stoch * self._cfg.dyn_discrete + self._cfg.dyn_deter
        else:
            feat_size = self._cfg.dyn_stoch + self._cfg.dyn_deter
        self.heads["image"] = ConvDecoder(
            feat_size,  # pytorch version
            self._cfg.cnn_depth,
            self._cfg.act,
            self._cfg.norm,
            self.state_size,
            self._cfg.decoder_kernels,
        )
        self.heads["reward"] = DenseHead(
            feat_size,  # dyn_stoch * dyn_discrete + dyn_deter
            (255, ),
            self._cfg.reward_layers,
            self._cfg.units,
            'SiLU',  # self._cfg.act
            'LN',  # self._cfg.norm
            dist=self._cfg.reward_head,
            outscale=0.0,
            device=self._cfg.device,
        )
        if self._cfg.pred_discount:
            self.heads["discount"] = DenseHead(
                feat_size,  # pytorch version
                [],
                self._cfg.discount_layers,
                self._cfg.units,
                'SiLU',  # self._cfg.act
                'LN',  # self._cfg.norm
                dist="binary",
                device=self._cfg.device,
            )

        # ``_cuda`` is not assigned in this class; presumably WorldModel.__init__ sets it.
        if self._cuda:
            self.cuda()
        # TODO: grad_clip, weight_decay
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self._cfg.model_lr)

    def step(self, obs, act):
        """Placeholder; intentionally does nothing and returns ``None``."""
        pass

    def eval(self, env_buffer, envstep, train_iter):
        """Placeholder; intentionally does nothing and returns ``None``.

        NOTE: shadows ``nn.Module.eval()`` with a different signature.
        """
        pass

    def should_pretrain(self):
        """Return ``True`` on the first call only, ``False`` on every later call."""
        if self.pretrain_flag:
            self.pretrain_flag = False
            return True
        return False

    def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
        """Sample a batch of sequences and run one optimizer step on the world model.

        NOTE: shadows ``nn.Module.train(mode)`` with a different signature.

        Args:
            env_buffer: Buffer providing ``sample(batch_size, batch_length, train_iter)``.
            envstep: Environment step counter; stored as ``last_train_step``
                and used as the x-axis for the logged scalars.
            train_iter: Training iteration, forwarded to ``env_buffer.sample``.
            batch_size: Number of sequences to sample.
            batch_length: Length of each sampled sequence.

        Returns:
            A ``(post, context)`` tuple: ``post`` is the posterior state dict
            with every tensor detached; ``context`` is a dict with keys
            ``embed``, ``feat``, ``kl`` and ``postent``.
        """
        self.last_train_step = envstep
        data = env_buffer.sample(
            batch_size, batch_length, train_iter
        )  # [len=B, ele=[len=T, ele={dict_key: Tensor(any_dims)}]]
        data = default_collate(data)  # -> [len=T, ele={dict_key: Tensor(B, any_dims)}]
        data = lists_to_dicts(data, recursive=True)  # -> {some_key: T lists}, each list is [B, some_dim]
        data = {k: torch.stack(data[k], dim=1) for k in data}  # -> {dict_key: Tensor([B, T, any_dims])}

        # Derive the discount from ``done`` only when the buffer did not supply
        # one, so that ``done`` is not required when ``discount`` is present.
        if 'discount' not in data:
            data['discount'] = 1.0 - data['done'].float()
        # Out-of-place so that a non-float discount tensor is promoted rather than rejected.
        data['discount'] = data['discount'] * 0.997
        data['weight'] = data.get('weight', None)
        data['image'] = data['obs'] - 0.5
        data = to_device(data, self._cfg.device)

        # Give 2-D entries a trailing singleton dimension.
        for key in ('reward', 'action', 'discount'):
            if len(data[key].shape) == 2:
                data[key] = data[key].unsqueeze(-1)

        self.requires_grad_(requires_grad=True)

        # Flatten all leading dims for the encoder, then restore them on the embedding.
        image = data['image'].reshape([-1] + list(data['image'].shape[-3:]))
        embed = self.encoder(image)
        embed = embed.reshape(list(data['image'].shape[:-3]) + [embed.shape[-1]])

        post, prior = self.dynamics.observe(embed, data["action"])
        kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
            post, prior, self._cfg.kl_forward, self._cfg.kl_free, self._cfg.kl_lscale, self._cfg.kl_rscale
        )

        losses = {}
        likes = {}
        for name, head in self.heads.items():
            grad_head = name in self._cfg.grad_heads
            feat = self.dynamics.get_feat(post)
            # Heads not listed in ``grad_heads`` are trained on detached features.
            feat = feat if grad_head else feat.detach()
            pred = head(feat)
            like = pred.log_prob(data[name])
            likes[name] = like
            losses[name] = -torch.mean(like)
        model_loss = sum(losses.values()) + kl_loss

        # ====================
        # world model update
        # ====================
        self.optimizer.zero_grad()
        model_loss.backward()
        self.optimizer.step()

        self.requires_grad_(requires_grad=False)

        # log
        if self.tb_logger is not None:
            for name, loss in losses.items():
                self.tb_logger.add_scalar(name + '_loss', loss.detach().cpu().numpy().item(), envstep)
            self.tb_logger.add_scalar('kl_free', self._cfg.kl_free, envstep)
            self.tb_logger.add_scalar('kl_lscale', self._cfg.kl_lscale, envstep)
            self.tb_logger.add_scalar('kl_rscale', self._cfg.kl_rscale, envstep)
            self.tb_logger.add_scalar('loss_lhs', loss_lhs.detach().cpu().numpy().item(), envstep)
            self.tb_logger.add_scalar('loss_rhs', loss_rhs.detach().cpu().numpy().item(), envstep)
            self.tb_logger.add_scalar('kl', torch.mean(kl_value).detach().cpu().numpy().item(), envstep)

            prior_ent = torch.mean(self.dynamics.get_dist(prior).entropy()).detach().cpu().numpy()
            post_ent = torch.mean(self.dynamics.get_dist(post).entropy()).detach().cpu().numpy()

            self.tb_logger.add_scalar('prior_ent', prior_ent.item(), envstep)
            self.tb_logger.add_scalar('post_ent', post_ent.item(), envstep)

        context = dict(
            embed=embed,
            feat=self.dynamics.get_feat(post),
            kl=kl_value,
            postent=self.dynamics.get_dist(post).entropy(),
        )
        post = {k: v.detach() for k, v in post.items()}
        return post, context

    def _save_states(self, ):
        """Store a deep copy of the current ``state_dict`` in ``self._states``."""
        self._states = copy.deepcopy(self.state_dict())

    def _save_state(self, id):
        """Copy slice ``id`` of every weight/bias tensor into ``self._states``.

        Requires ``self._states`` to exist already (see ``_save_states``).
        The parameter name ``id`` shadows the builtin; it is kept so existing
        keyword callers continue to work.
        """
        state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'weight' in k or 'bias' in k:
                self._states[k].data[id] = copy.deepcopy(v.data[id])

    def _load_states(self):
        """Load the parameters stored in ``self._states`` back into the module."""
        self.load_state_dict(self._states)

    def _save_best(self, epoch, holdout_losses):
        """Record improved holdout losses and report whether to stop.

        For every index ``i``, the loss is compared with the best value kept
        in ``self._snapshots[i]``; a relative improvement above 1% replaces
        the snapshot and saves slice ``i`` via ``_save_state``.

        ``self._snapshots`` is not initialised anywhere in this class and must
        be populated (as ``(epoch, loss)`` pairs) before this method is used.

        Args:
            epoch: Epoch number stored alongside an improved loss.
            holdout_losses: Sequence of current holdout losses.

        Returns:
            ``True`` when the number of consecutive calls without any
            improvement exceeds ``self.max_epochs_since_update``.
        """
        updated = False
        for i, current in enumerate(holdout_losses):
            _, best = self._snapshots[i]
            improvement = (best - current) / best
            if improvement > 0.01:
                self._snapshots[i] = (epoch, current)
                self._save_state(i)
                updated = True

        if updated:
            self._epochs_since_update = 0
        else:
            self._epochs_since_update += 1
        return self._epochs_since_update > self.max_epochs_since_update