Wyatt-Huang
commited on
Upload 10 files
Browse files- README.md +62 -0
- agent/DiPo.py +173 -0
- agent/diffusion.py +178 -0
- agent/helpers.py +130 -0
- agent/model.py +100 -0
- agent/replay_memory.py +79 -0
- agent/vae.py +95 -0
- main.py +163 -0
- requirements.txt +103 -0
- run_dipo +25 -0
README.md
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Policy Representation via Diffusion Probability Model for Reinforcement Learning
|
2 |
+
|
3 |
+
**Policy Representation via Diffusion Probability Model for Reinforcement Learning**<br>
|
4 |
+
Anonymous <br>
|
5 |
+
|
6 |
+
Abstract: *Popular reinforcement learning (RL) algorithms tend to produce a unimodal policy distribution, which weakens the expressiveness of complicated policy and decays the ability of exploration. The diffusion probability model is powerful to learn complicated multimodal distributions, which has shown promising and potential applications to RL. In this paper, we formally build a theoretical foundation of policy representation via the diffusion probability model and provide practical implementations of diffusion policy for online model-free RL. Concretely, we character diffusion policy as a stochastic process, which is a new approach to representing a policy. Then we present a convergence guarantee for diffusion policy, which provides a theory to understand the multimodality of diffusion policy. Furthermore, we propose the DIPO which is an implementation for model-free online RL with \textbf{DI}ffusion \textbf{PO}licy. To the best of our knowledge, DIPO is the first algorithm to solve model-free online RL problems with the diffusion model. Finally, extensive empirical results show the effectiveness and superiority of DIPO on the standard continuous control MoJoCo benchmark.*
|
7 |
+
|
8 |
+
## Experiments
|
9 |
+
|
10 |
+
### Requirements
|
11 |
+
Installations of [PyTorch](https://pytorch.org/) and [MuJoCo](https://github.com/deepmind/mujoco) are needed.
|
12 |
+
A suitable [conda](https://conda.io) environment named `DIPO` can be created and activated with:
|
13 |
+
```.bash
|
14 |
+
conda create DIPO
|
15 |
+
conda activate DIPO
|
16 |
+
```
|
17 |
+
To get started, install the additionally required python packages into you environment.
|
18 |
+
```.bash
|
19 |
+
pip install -r requirements.txt
|
20 |
+
```
|
21 |
+
|
22 |
+
### Running
|
23 |
+
Running experiments based our code could be quite easy, so below we use `Hopper-v3` task as an example.
|
24 |
+
|
25 |
+
```.bash
|
26 |
+
python main.py --env_name Hopper-v3 --num_steps 1000000 --n_timesteps 100 --cuda 0 --seed 0
|
27 |
+
```
|
28 |
+
|
29 |
+
|
30 |
+
### Hyperparameters
|
31 |
+
Hyperparameters for DIPO have been shown as follow for easily reproducing our reported results.
|
32 |
+
|
33 |
+
#### Hyper-parameters for algorithms
|
34 |
+
| Hyperparameter | DIPO | SAC | TD3 | PPO |
|
35 |
+
| -------------- | ---- | --- | --- | --- |
|
36 |
+
| No. of hidden layers | 2 | 2 | 2 | 2 |
|
37 |
+
| No. of hidden nodes | 256 | 256 | 256 | 256 |
|
38 |
+
| Activation | mish | relu | relu | tanh |
|
39 |
+
| Batch size | 256 | 256 | 256 | 256 |
|
40 |
+
| Discount for reward $\gamma$ | 0.99 | 0.99 | 0.99 | 0.99 |
|
41 |
+
| Target smoothing coefficient $\tau$ | 0.005 | 0.005 | 0.005 | 0.005 |
|
42 |
+
| Learning rate for actor | $3 × 10^{-4}$ | $3 × 10^{-4}$ | $3 × 10^{-4}$ | $7 × 10^{-4}$ |
|
43 |
+
| Learning rate for actor | $3 × 10^{-4}$ | $3 × 10^{-4}$ | $3 × 10^{-4}$ | $7 × 10^{-4}$ |
|
44 |
+
| Actor Critic grad norm | 2 | N/A | N/A | 0.5 |
|
45 |
+
| Memeroy size | $1 × 10^6$ | $1 × 10^6$ | $1 × 10^6$ | $1 × 10^6$ |
|
46 |
+
| Entropy coefficient | N/A | 0.2 | N/A | 0.01 |
|
47 |
+
| Value loss coefficient | N/A | N/A | N/A | 0.5 |
|
48 |
+
| Exploration noise | N/A | N/A | $\mathcal{N}$(0, 0.1) | N/A |
|
49 |
+
| Policy noise | N/A | N/A | $\mathcal{N}$(0, 0.2) | N/A |
|
50 |
+
| Noise clip | N/A | N/A | 0.5 | N/A |
|
51 |
+
| Use gae | N/A | N/A | N/A | True |
|
52 |
+
|
53 |
+
#### Hyper-parameters for MuJoCo.(DIPO)
|
54 |
+
| Hyperparameter | Hopper-v3 | Walker2d-v3 | Ant-v3 | HalfCheetah-v3 | Humanoid-v3 |
|
55 |
+
| --- | --- | --- | --- | --- | --- |
|
56 |
+
| Learning rate for action | 0.03 | 0.03 | 0.03 | 0.03 | 0.03 |
|
57 |
+
| Actor Critic grad norm | 1 | 2 | 0.8 | 2 | 2 |
|
58 |
+
| Action grad norm ratio | 0.3 | 0.08 | 0.1 | 0.08 | 0.1 |
|
59 |
+
| Action gradient steps | 20 | 20 | 20 | 40 | 20 |
|
60 |
+
| Diffusion inference timesteps | 100 | 100 | 100 | 100 | 100 |
|
61 |
+
| Diffusion beta schedule | cosine | cosine | cosine | cosine | cosine |
|
62 |
+
| Update actor target every | 1 | 1 | 1 | 2 | 1 |
|
agent/DiPo.py
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
8 |
+
|
9 |
+
from agent.model import MLP, Critic
|
10 |
+
from agent.diffusion import Diffusion
|
11 |
+
from agent.vae import VAE
|
12 |
+
from agent.helpers import EMA
|
13 |
+
|
14 |
+
|
15 |
+
class DiPo(object):
|
16 |
+
def __init__(self,
|
17 |
+
args,
|
18 |
+
state_dim,
|
19 |
+
action_space,
|
20 |
+
memory,
|
21 |
+
diffusion_memory,
|
22 |
+
device,
|
23 |
+
):
|
24 |
+
action_dim = np.prod(action_space.shape)
|
25 |
+
|
26 |
+
self.policy_type = args.policy_type
|
27 |
+
if self.policy_type == 'Diffusion':
|
28 |
+
self.actor = Diffusion(state_dim=state_dim, action_dim=action_dim, noise_ratio=args.noise_ratio,
|
29 |
+
beta_schedule=args.beta_schedule, n_timesteps=args.n_timesteps).to(device)
|
30 |
+
elif self.policy_type == 'VAE':
|
31 |
+
self.actor = VAE(state_dim=state_dim, action_dim=action_dim, device=device).to(device)
|
32 |
+
else:
|
33 |
+
self.actor = MLP(state_dim=state_dim, action_dim=action_dim).to(device)
|
34 |
+
|
35 |
+
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=args.diffusion_lr, eps=1e-5)
|
36 |
+
|
37 |
+
self.memory = memory
|
38 |
+
self.diffusion_memory = diffusion_memory
|
39 |
+
self.action_gradient_steps = args.action_gradient_steps
|
40 |
+
|
41 |
+
self.action_grad_norm = action_dim * args.ratio
|
42 |
+
self.ac_grad_norm = args.ac_grad_norm
|
43 |
+
|
44 |
+
self.step = 0
|
45 |
+
self.tau = args.tau
|
46 |
+
self.actor_target = copy.deepcopy(self.actor)
|
47 |
+
self.update_actor_target_every = args.update_actor_target_every
|
48 |
+
|
49 |
+
self.critic = Critic(state_dim, action_dim).to(device)
|
50 |
+
self.critic_target = copy.deepcopy(self.critic)
|
51 |
+
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=args.critic_lr, eps=1e-5)
|
52 |
+
|
53 |
+
self.action_dim = action_dim
|
54 |
+
|
55 |
+
self.action_lr = args.action_lr
|
56 |
+
|
57 |
+
self.device = device
|
58 |
+
|
59 |
+
if action_space is None:
|
60 |
+
self.action_scale = 1.
|
61 |
+
self.action_bias = 0.
|
62 |
+
else:
|
63 |
+
self.action_scale = (action_space.high - action_space.low) / 2.
|
64 |
+
self.action_bias = (action_space.high + action_space.low) / 2.
|
65 |
+
|
66 |
+
def append_memory(self, state, action, reward, next_state, mask):
|
67 |
+
action = (action - self.action_bias) / self.action_scale
|
68 |
+
|
69 |
+
self.memory.append(state, action, reward, next_state, mask)
|
70 |
+
self.diffusion_memory.append(state, action)
|
71 |
+
|
72 |
+
def sample_action(self, state, eval=False):
|
73 |
+
state = torch.FloatTensor(state.reshape(1, -1)).to(self.device)
|
74 |
+
|
75 |
+
action = self.actor(state, eval).cpu().data.numpy().flatten()
|
76 |
+
action = action.clip(-1, 1)
|
77 |
+
action = action * self.action_scale + self.action_bias
|
78 |
+
return action
|
79 |
+
|
80 |
+
def action_gradient(self, batch_size, log_writer):
|
81 |
+
states, best_actions, idxs = self.diffusion_memory.sample(batch_size)
|
82 |
+
|
83 |
+
actions_optim = torch.optim.Adam([best_actions], lr=self.action_lr, eps=1e-5)
|
84 |
+
|
85 |
+
|
86 |
+
for i in range(self.action_gradient_steps):
|
87 |
+
best_actions.requires_grad_(True)
|
88 |
+
q1, q2 = self.critic(states, best_actions)
|
89 |
+
loss = -torch.min(q1, q2)
|
90 |
+
|
91 |
+
actions_optim.zero_grad()
|
92 |
+
|
93 |
+
loss.backward(torch.ones_like(loss))
|
94 |
+
if self.action_grad_norm > 0:
|
95 |
+
actions_grad_norms = nn.utils.clip_grad_norm_([best_actions], max_norm=self.action_grad_norm, norm_type=2)
|
96 |
+
|
97 |
+
actions_optim.step()
|
98 |
+
|
99 |
+
best_actions.requires_grad_(False)
|
100 |
+
best_actions.clamp_(-1., 1.)
|
101 |
+
|
102 |
+
# if self.step % 10 == 0:
|
103 |
+
# log_writer.add_scalar('Action Grad Norm', actions_grad_norms.max().item(), self.step)
|
104 |
+
|
105 |
+
best_actions = best_actions.detach()
|
106 |
+
|
107 |
+
self.diffusion_memory.replace(idxs, best_actions.cpu().numpy())
|
108 |
+
|
109 |
+
return states, best_actions
|
110 |
+
|
111 |
+
def train(self, iterations, batch_size=256, log_writer=None):
|
112 |
+
for _ in range(iterations):
|
113 |
+
# Sample replay buffer / batch
|
114 |
+
states, actions, rewards, next_states, masks = self.memory.sample(batch_size)
|
115 |
+
|
116 |
+
""" Q Training """
|
117 |
+
current_q1, current_q2 = self.critic(states, actions)
|
118 |
+
|
119 |
+
next_actions = self.actor_target(next_states, eval=False)
|
120 |
+
target_q1, target_q2 = self.critic_target(next_states, next_actions)
|
121 |
+
target_q = torch.min(target_q1, target_q2)
|
122 |
+
|
123 |
+
target_q = (rewards + masks * target_q).detach()
|
124 |
+
|
125 |
+
critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)
|
126 |
+
|
127 |
+
self.critic_optimizer.zero_grad()
|
128 |
+
critic_loss.backward()
|
129 |
+
if self.ac_grad_norm > 0:
|
130 |
+
critic_grad_norms = nn.utils.clip_grad_norm_(self.critic.parameters(), max_norm=self.ac_grad_norm, norm_type=2)
|
131 |
+
# if self.step % 10 == 0:
|
132 |
+
# log_writer.add_scalar('Critic Grad Norm', critic_grad_norms.max().item(), self.step)
|
133 |
+
self.critic_optimizer.step()
|
134 |
+
|
135 |
+
""" Policy Training """
|
136 |
+
states, best_actions = self.action_gradient(batch_size, log_writer)
|
137 |
+
|
138 |
+
actor_loss = self.actor.loss(best_actions, states)
|
139 |
+
|
140 |
+
self.actor_optimizer.zero_grad()
|
141 |
+
actor_loss.backward()
|
142 |
+
if self.ac_grad_norm > 0:
|
143 |
+
actor_grad_norms = nn.utils.clip_grad_norm_(self.actor.parameters(), max_norm=self.ac_grad_norm, norm_type=2)
|
144 |
+
# if self.step % 10 == 0:
|
145 |
+
# log_writer.add_scalar('Actor Grad Norm', actor_grad_norms.max().item(), self.step)
|
146 |
+
self.actor_optimizer.step()
|
147 |
+
|
148 |
+
""" Step Target network """
|
149 |
+
for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
|
150 |
+
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
|
151 |
+
|
152 |
+
if self.step % self.update_actor_target_every == 0:
|
153 |
+
for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
|
154 |
+
target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)
|
155 |
+
|
156 |
+
self.step += 1
|
157 |
+
|
158 |
+
def save_model(self, dir, id=None):
|
159 |
+
if id is not None:
|
160 |
+
torch.save(self.actor.state_dict(), f'{dir}/actor_{id}.pth')
|
161 |
+
torch.save(self.critic.state_dict(), f'{dir}/critic_{id}.pth')
|
162 |
+
else:
|
163 |
+
torch.save(self.actor.state_dict(), f'{dir}/actor.pth')
|
164 |
+
torch.save(self.critic.state_dict(), f'{dir}/critic.pth')
|
165 |
+
|
166 |
+
def load_model(self, dir, id=None):
|
167 |
+
if id is not None:
|
168 |
+
self.actor.load_state_dict(torch.load(f'{dir}/actor_{id}.pth'))
|
169 |
+
self.critic.load_state_dict(torch.load(f'{dir}/critic_{id}.pth'))
|
170 |
+
else:
|
171 |
+
self.actor.load_state_dict(torch.load(f'{dir}/actor.pth'))
|
172 |
+
self.critic.load_state_dict(torch.load(f'{dir}/critic.pth'))
|
173 |
+
|
agent/diffusion.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
from agent.helpers import (cosine_beta_schedule,
|
9 |
+
linear_beta_schedule,
|
10 |
+
vp_beta_schedule,
|
11 |
+
extract,
|
12 |
+
Losses)
|
13 |
+
|
14 |
+
from agent.model import Model
|
15 |
+
|
16 |
+
|
17 |
+
class Diffusion(nn.Module):
|
18 |
+
def __init__(self, state_dim, action_dim, noise_ratio,
|
19 |
+
beta_schedule='vp', n_timesteps=1000,
|
20 |
+
loss_type='l2', clip_denoised=True, predict_epsilon=True):
|
21 |
+
super(Diffusion, self).__init__()
|
22 |
+
|
23 |
+
self.state_dim = state_dim
|
24 |
+
self.action_dim = action_dim
|
25 |
+
self.model = Model(state_dim, action_dim)
|
26 |
+
|
27 |
+
self.max_noise_ratio = noise_ratio
|
28 |
+
self.noise_ratio = noise_ratio
|
29 |
+
|
30 |
+
if beta_schedule == 'linear':
|
31 |
+
betas = linear_beta_schedule(n_timesteps)
|
32 |
+
elif beta_schedule == 'cosine':
|
33 |
+
betas = cosine_beta_schedule(n_timesteps)
|
34 |
+
elif beta_schedule == 'vp':
|
35 |
+
betas = vp_beta_schedule(n_timesteps)
|
36 |
+
|
37 |
+
alphas = 1. - betas
|
38 |
+
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
39 |
+
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
|
40 |
+
|
41 |
+
self.n_timesteps = int(n_timesteps)
|
42 |
+
self.clip_denoised = clip_denoised
|
43 |
+
self.predict_epsilon = predict_epsilon
|
44 |
+
|
45 |
+
self.register_buffer('betas', betas)
|
46 |
+
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
47 |
+
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
48 |
+
|
49 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
50 |
+
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
51 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
52 |
+
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
53 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
54 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
55 |
+
|
56 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
57 |
+
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
58 |
+
self.register_buffer('posterior_variance', posterior_variance)
|
59 |
+
|
60 |
+
## log calculation clipped because the posterior variance
|
61 |
+
## is 0 at the beginning of the diffusion chain
|
62 |
+
self.register_buffer('posterior_log_variance_clipped',
|
63 |
+
torch.log(torch.clamp(posterior_variance, min=1e-20)))
|
64 |
+
self.register_buffer('posterior_mean_coef1',
|
65 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
66 |
+
self.register_buffer('posterior_mean_coef2',
|
67 |
+
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))
|
68 |
+
|
69 |
+
self.loss_fn = Losses[loss_type]()
|
70 |
+
|
71 |
+
# ------------------------------------------ sampling ------------------------------------------#
|
72 |
+
|
73 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
74 |
+
'''
|
75 |
+
if self.predict_epsilon, model output is (scaled) noise;
|
76 |
+
otherwise, model predicts x0 directly
|
77 |
+
'''
|
78 |
+
if self.predict_epsilon:
|
79 |
+
return (
|
80 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
81 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
return noise
|
85 |
+
|
86 |
+
def q_posterior(self, x_start, x_t, t):
|
87 |
+
posterior_mean = (
|
88 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
89 |
+
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
90 |
+
)
|
91 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
92 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
93 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
94 |
+
|
95 |
+
def p_mean_variance(self, x, t, s):
|
96 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=self.model(x, t, s))
|
97 |
+
|
98 |
+
if self.clip_denoised:
|
99 |
+
x_recon.clamp_(-1., 1.)
|
100 |
+
else:
|
101 |
+
assert RuntimeError()
|
102 |
+
|
103 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
104 |
+
return model_mean, posterior_variance, posterior_log_variance
|
105 |
+
|
106 |
+
@torch.no_grad()
|
107 |
+
def p_sample(self, x, t, s):
|
108 |
+
b, *_, device = *x.shape, x.device
|
109 |
+
|
110 |
+
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, s=s)
|
111 |
+
|
112 |
+
noise = torch.randn_like(x)
|
113 |
+
# no noise when t == 0
|
114 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
115 |
+
|
116 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise * self.noise_ratio
|
117 |
+
|
118 |
+
|
119 |
+
@torch.no_grad()
|
120 |
+
def p_sample_loop(self, state, shape):
|
121 |
+
device = self.betas.device
|
122 |
+
|
123 |
+
batch_size = shape[0]
|
124 |
+
x = torch.randn(shape, device=device)
|
125 |
+
|
126 |
+
for i in reversed(range(0, self.n_timesteps)):
|
127 |
+
timesteps = torch.full((batch_size,), i, device=device, dtype=torch.long)
|
128 |
+
x = self.p_sample(x, timesteps, state)
|
129 |
+
|
130 |
+
return x
|
131 |
+
|
132 |
+
@torch.no_grad()
|
133 |
+
def sample(self, state, eval=False):
|
134 |
+
self.noise_ratio = 0 if eval else self.max_noise_ratio
|
135 |
+
|
136 |
+
batch_size = state.shape[0]
|
137 |
+
shape = (batch_size, self.action_dim)
|
138 |
+
action = self.p_sample_loop(state, shape)
|
139 |
+
return action.clamp_(-1., 1.)
|
140 |
+
|
141 |
+
# ------------------------------------------ training ------------------------------------------#
|
142 |
+
|
143 |
+
def q_sample(self, x_start, t, noise=None):
|
144 |
+
if noise is None:
|
145 |
+
noise = torch.randn_like(x_start)
|
146 |
+
|
147 |
+
sample = (
|
148 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
149 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
150 |
+
)
|
151 |
+
|
152 |
+
return sample
|
153 |
+
|
154 |
+
def p_losses(self, x_start, state, t, weights=1.0):
|
155 |
+
noise = torch.randn_like(x_start)
|
156 |
+
|
157 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
158 |
+
|
159 |
+
x_recon = self.model(x_noisy, t, state)
|
160 |
+
|
161 |
+
assert noise.shape == x_recon.shape
|
162 |
+
|
163 |
+
if self.predict_epsilon:
|
164 |
+
loss = self.loss_fn(x_recon, noise, weights)
|
165 |
+
else:
|
166 |
+
loss = self.loss_fn(x_recon, x_start, weights)
|
167 |
+
|
168 |
+
return loss
|
169 |
+
|
170 |
+
|
171 |
+
def loss(self, x, state, weights=1.0):
|
172 |
+
batch_size = len(x)
|
173 |
+
t = torch.randint(0, self.n_timesteps, (batch_size,), device=x.device).long()
|
174 |
+
return self.p_losses(x, state, t, weights)
|
175 |
+
|
176 |
+
def forward(self, state, eval=False):
|
177 |
+
return self.sample(state, eval)
|
178 |
+
|
agent/helpers.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import time
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
|
9 |
+
def init_weights(m):
|
10 |
+
def truncated_normal_init(t, mean=0.0, std=0.01):
|
11 |
+
torch.nn.init.normal_(t, mean=mean, std=std)
|
12 |
+
while True:
|
13 |
+
cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
|
14 |
+
if not torch.sum(cond):
|
15 |
+
break
|
16 |
+
t = torch.where(cond, torch.nn.init.normal_(torch.ones(t.shape), mean=mean, std=std), t)
|
17 |
+
return t
|
18 |
+
|
19 |
+
if type(m) == nn.Linear:
|
20 |
+
input_dim = m.in_features
|
21 |
+
truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(input_dim)))
|
22 |
+
m.bias.data.fill_(0.0)
|
23 |
+
|
24 |
+
|
25 |
+
class SinusoidalPosEmb(nn.Module):
|
26 |
+
def __init__(self, dim):
|
27 |
+
super().__init__()
|
28 |
+
self.dim = dim
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
device = x.device
|
32 |
+
half_dim = self.dim // 2
|
33 |
+
emb = math.log(10000) / (half_dim - 1)
|
34 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
35 |
+
emb = x[:, None] * emb[None, :]
|
36 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
37 |
+
return emb
|
38 |
+
|
39 |
+
#-----------------------------------------------------------------------------#
|
40 |
+
#---------------------------------- sampling ---------------------------------#
|
41 |
+
#-----------------------------------------------------------------------------#
|
42 |
+
|
43 |
+
|
44 |
+
def extract(a, t, x_shape):
|
45 |
+
b, *_ = t.shape
|
46 |
+
out = a.gather(-1, t)
|
47 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
48 |
+
|
49 |
+
|
50 |
+
def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
|
51 |
+
"""
|
52 |
+
cosine schedule
|
53 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
54 |
+
"""
|
55 |
+
steps = timesteps + 1
|
56 |
+
x = np.linspace(0, steps, steps)
|
57 |
+
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
|
58 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
59 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
60 |
+
betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
|
61 |
+
return torch.tensor(betas_clipped, dtype=dtype)
|
62 |
+
|
63 |
+
|
64 |
+
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=2e-2, dtype=torch.float32):
|
65 |
+
betas = np.linspace(
|
66 |
+
beta_start, beta_end, timesteps
|
67 |
+
)
|
68 |
+
return torch.tensor(betas, dtype=dtype)
|
69 |
+
|
70 |
+
|
71 |
+
def vp_beta_schedule(timesteps, dtype=torch.float32):
|
72 |
+
t = np.arange(1, timesteps + 1)
|
73 |
+
T = timesteps
|
74 |
+
b_max = 10.
|
75 |
+
b_min = 0.1
|
76 |
+
alpha = np.exp(-b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T ** 2)
|
77 |
+
betas = 1 - alpha
|
78 |
+
return torch.tensor(betas, dtype=dtype)
|
79 |
+
|
80 |
+
#-----------------------------------------------------------------------------#
|
81 |
+
#---------------------------------- losses -----------------------------------#
|
82 |
+
#-----------------------------------------------------------------------------#
|
83 |
+
|
84 |
+
class WeightedLoss(nn.Module):
|
85 |
+
|
86 |
+
def __init__(self):
|
87 |
+
super().__init__()
|
88 |
+
|
89 |
+
def forward(self, pred, targ, weights=1.0):
|
90 |
+
'''
|
91 |
+
pred, targ : tensor [ batch_size x action_dim ]
|
92 |
+
'''
|
93 |
+
loss = self._loss(pred, targ)
|
94 |
+
weighted_loss = (loss * weights).mean()
|
95 |
+
return weighted_loss
|
96 |
+
|
97 |
+
class WeightedL1(WeightedLoss):
|
98 |
+
|
99 |
+
def _loss(self, pred, targ):
|
100 |
+
return torch.abs(pred - targ)
|
101 |
+
|
102 |
+
class WeightedL2(WeightedLoss):
|
103 |
+
|
104 |
+
def _loss(self, pred, targ):
|
105 |
+
return F.mse_loss(pred, targ, reduction='none')
|
106 |
+
|
107 |
+
|
108 |
+
Losses = {
|
109 |
+
'l1': WeightedL1,
|
110 |
+
'l2': WeightedL2,
|
111 |
+
}
|
112 |
+
|
113 |
+
|
114 |
+
class EMA():
|
115 |
+
'''
|
116 |
+
empirical moving average
|
117 |
+
'''
|
118 |
+
def __init__(self, beta):
|
119 |
+
super().__init__()
|
120 |
+
self.beta = beta
|
121 |
+
|
122 |
+
def update_model_average(self, ma_model, current_model):
|
123 |
+
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
124 |
+
old_weight, up_weight = ma_params.data, current_params.data
|
125 |
+
ma_params.data = self.update_average(old_weight, up_weight)
|
126 |
+
|
127 |
+
def update_average(self, old, new):
|
128 |
+
if old is None:
|
129 |
+
return new
|
130 |
+
return old * self.beta + (1 - self.beta) * new
|
agent/model.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from agent.helpers import SinusoidalPosEmb, init_weights
|
7 |
+
|
8 |
+
|
9 |
+
class Critic(nn.Module):
|
10 |
+
def __init__(self, state_dim, action_dim, hidden_dim=256):
|
11 |
+
super(Critic, self).__init__()
|
12 |
+
self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
|
13 |
+
nn.Mish(),
|
14 |
+
nn.Linear(hidden_dim, hidden_dim),
|
15 |
+
nn.Mish(),
|
16 |
+
nn.Linear(hidden_dim, hidden_dim),
|
17 |
+
nn.Mish(),
|
18 |
+
nn.Linear(hidden_dim, 1))
|
19 |
+
|
20 |
+
self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
|
21 |
+
nn.Mish(),
|
22 |
+
nn.Linear(hidden_dim, hidden_dim),
|
23 |
+
nn.Mish(),
|
24 |
+
nn.Linear(hidden_dim, hidden_dim),
|
25 |
+
nn.Mish(),
|
26 |
+
nn.Linear(hidden_dim, 1))
|
27 |
+
|
28 |
+
self.apply(init_weights)
|
29 |
+
|
30 |
+
def forward(self, state, action):
|
31 |
+
x = torch.cat([state, action], dim=-1)
|
32 |
+
return self.q1_model(x), self.q2_model(x)
|
33 |
+
|
34 |
+
def q1(self, state, action):
|
35 |
+
x = torch.cat([state, action], dim=-1)
|
36 |
+
return self.q1_model(x)
|
37 |
+
|
38 |
+
def q_min(self, state, action):
|
39 |
+
q1, q2 = self.forward(state, action)
|
40 |
+
return torch.min(q1, q2)
|
41 |
+
|
42 |
+
|
43 |
+
class Model(nn.Module):
|
44 |
+
def __init__(self, state_dim, action_dim, hidden_size=256, time_dim=32):
|
45 |
+
super(Model, self).__init__()
|
46 |
+
|
47 |
+
self.time_mlp = nn.Sequential(
|
48 |
+
SinusoidalPosEmb(time_dim),
|
49 |
+
nn.Linear(time_dim, hidden_size),
|
50 |
+
nn.Mish(),
|
51 |
+
nn.Linear(hidden_size, time_dim),
|
52 |
+
)
|
53 |
+
|
54 |
+
input_dim = state_dim + action_dim + time_dim
|
55 |
+
self.layer = nn.Sequential(nn.Linear(input_dim, hidden_size),
|
56 |
+
nn.Mish(),
|
57 |
+
nn.Linear(hidden_size, hidden_size),
|
58 |
+
nn.Mish(),
|
59 |
+
nn.Linear(hidden_size, hidden_size),
|
60 |
+
nn.Mish(),
|
61 |
+
nn.Linear(hidden_size, action_dim))
|
62 |
+
self.apply(init_weights)
|
63 |
+
|
64 |
+
|
65 |
+
def forward(self, x, time, state):
|
66 |
+
|
67 |
+
t = self.time_mlp(time)
|
68 |
+
out = torch.cat([x, t, state], dim=-1)
|
69 |
+
out = self.layer(out)
|
70 |
+
|
71 |
+
return out
|
72 |
+
|
73 |
+
|
74 |
+
class MLP(nn.Module):
|
75 |
+
def __init__(self, state_dim, action_dim, hidden_size=256):
|
76 |
+
super(MLP, self).__init__()
|
77 |
+
|
78 |
+
input_dim = state_dim
|
79 |
+
self.mid_layer = nn.Sequential(nn.Linear(input_dim, hidden_size),
|
80 |
+
nn.Mish(),
|
81 |
+
nn.Linear(hidden_size, hidden_size),
|
82 |
+
nn.Mish(),
|
83 |
+
nn.Linear(hidden_size, hidden_size),
|
84 |
+
nn.Mish())
|
85 |
+
|
86 |
+
self.final_layer = nn.Linear(hidden_size, action_dim)
|
87 |
+
|
88 |
+
self.apply(init_weights)
|
89 |
+
|
90 |
+
def forward(self, state, eval=False):
|
91 |
+
out = self.mid_layer(state)
|
92 |
+
out = self.final_layer(out)
|
93 |
+
|
94 |
+
if not eval:
|
95 |
+
out += torch.randn_like(out) * 0.1
|
96 |
+
|
97 |
+
return out
|
98 |
+
|
99 |
+
def loss(self, action, state):
|
100 |
+
return F.mse_loss(self.forward(state), action, reduction='mean')
|
agent/replay_memory.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class ReplayMemory():
|
6 |
+
"""Buffer to store environment transitions."""
|
7 |
+
def __init__(self, state_dim, action_dim, capacity, device):
|
8 |
+
self.capacity = int(capacity)
|
9 |
+
self.device = device
|
10 |
+
|
11 |
+
self.states = np.empty((self.capacity, int(state_dim)), dtype=np.float32)
|
12 |
+
self.actions = np.empty((self.capacity, int(action_dim)), dtype=np.float32)
|
13 |
+
self.rewards = np.empty((self.capacity, 1), dtype=np.float32)
|
14 |
+
self.next_states = np.empty((self.capacity, int(state_dim)), dtype=np.float32)
|
15 |
+
self.masks = np.empty((self.capacity, 1), dtype=np.float32)
|
16 |
+
|
17 |
+
self.idx = 0
|
18 |
+
self.full = False
|
19 |
+
|
20 |
+
def append(self, state, action, reward, next_state, mask):
|
21 |
+
|
22 |
+
np.copyto(self.states[self.idx], state)
|
23 |
+
np.copyto(self.actions[self.idx], action)
|
24 |
+
np.copyto(self.rewards[self.idx], reward)
|
25 |
+
np.copyto(self.next_states[self.idx], next_state)
|
26 |
+
np.copyto(self.masks[self.idx], mask)
|
27 |
+
|
28 |
+
self.idx = (self.idx + 1) % self.capacity
|
29 |
+
self.full = self.full or self.idx == 0
|
30 |
+
|
31 |
+
def sample(self, batch_size):
|
32 |
+
idxs = np.random.randint(
|
33 |
+
0, self.capacity if self.full else self.idx, size=batch_size
|
34 |
+
)
|
35 |
+
|
36 |
+
states = torch.as_tensor(self.states[idxs], device=self.device)
|
37 |
+
actions = torch.as_tensor(self.actions[idxs], device=self.device)
|
38 |
+
rewards = torch.as_tensor(self.rewards[idxs], device=self.device)
|
39 |
+
next_states = torch.as_tensor(self.next_states[idxs], device=self.device)
|
40 |
+
masks = torch.as_tensor(self.masks[idxs], device=self.device)
|
41 |
+
|
42 |
+
return states, actions, rewards, next_states, masks
|
43 |
+
|
44 |
+
|
45 |
+
class DiffusionMemory():
|
46 |
+
"""Buffer to store best actions."""
|
47 |
+
def __init__(self, state_dim, action_dim, capacity, device):
|
48 |
+
self.capacity = int(capacity)
|
49 |
+
self.device = device
|
50 |
+
|
51 |
+
self.states = np.empty((self.capacity, int(state_dim)), dtype=np.float32)
|
52 |
+
self.best_actions = np.empty((self.capacity, int(action_dim)), dtype=np.float32)
|
53 |
+
|
54 |
+
self.idx = 0
|
55 |
+
self.full = False
|
56 |
+
|
57 |
+
def append(self, state, action):
|
58 |
+
|
59 |
+
np.copyto(self.states[self.idx], state)
|
60 |
+
np.copyto(self.best_actions[self.idx], action)
|
61 |
+
|
62 |
+
self.idx = (self.idx + 1) % self.capacity
|
63 |
+
self.full = self.full or self.idx == 0
|
64 |
+
|
65 |
+
def sample(self, batch_size):
|
66 |
+
idxs = np.random.randint(
|
67 |
+
0, self.capacity if self.full else self.idx, size=batch_size
|
68 |
+
)
|
69 |
+
|
70 |
+
states = torch.as_tensor(self.states[idxs], device=self.device)
|
71 |
+
best_actions = torch.as_tensor(self.best_actions[idxs], device=self.device)
|
72 |
+
|
73 |
+
best_actions.requires_grad_(True)
|
74 |
+
|
75 |
+
return states, best_actions, idxs
|
76 |
+
|
77 |
+
def replace(self, idxs, best_actions):
|
78 |
+
np.copyto(self.best_actions[idxs], best_actions)
|
79 |
+
|
agent/vae.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
from agent.helpers import init_weights
|
5 |
+
|
6 |
+
|
7 |
+
class VAE(nn.Module):
|
8 |
+
def __init__(self, state_dim, action_dim, device, hidden_size=256) -> None:
|
9 |
+
super(VAE, self).__init__()
|
10 |
+
|
11 |
+
self.hidden_size = hidden_size
|
12 |
+
self.action_dim = action_dim
|
13 |
+
|
14 |
+
input_dim = state_dim + action_dim
|
15 |
+
|
16 |
+
self.encoder = nn.Sequential(nn.Linear(input_dim, hidden_size),
|
17 |
+
nn.Mish(),
|
18 |
+
nn.Linear(hidden_size, hidden_size),
|
19 |
+
nn.Mish(),
|
20 |
+
nn.Linear(hidden_size, hidden_size),
|
21 |
+
nn.Mish())
|
22 |
+
|
23 |
+
self.fc_mu = nn.Linear(hidden_size, hidden_size)
|
24 |
+
self.fc_var = nn.Linear(hidden_size, hidden_size)
|
25 |
+
|
26 |
+
self.decoder = nn.Sequential(nn.Linear(hidden_size + state_dim, hidden_size),
|
27 |
+
nn.Mish(),
|
28 |
+
nn.Linear(hidden_size, hidden_size),
|
29 |
+
nn.Mish(),
|
30 |
+
nn.Linear(hidden_size, hidden_size),
|
31 |
+
nn.Mish())
|
32 |
+
|
33 |
+
self.final_layer = nn.Sequential(nn.Linear(hidden_size, action_dim))
|
34 |
+
|
35 |
+
self.apply(init_weights)
|
36 |
+
|
37 |
+
self.device = device
|
38 |
+
|
39 |
+
def encode(self, action, state):
|
40 |
+
x = torch.cat([action, state], dim=-1)
|
41 |
+
result = self.encoder(x)
|
42 |
+
result = torch.flatten(result, start_dim=1)
|
43 |
+
|
44 |
+
# Split the result into mu and var components
|
45 |
+
# of the latent Gaussian distribution
|
46 |
+
mu = self.fc_mu(result)
|
47 |
+
log_var = self.fc_var(result)
|
48 |
+
|
49 |
+
return mu, log_var
|
50 |
+
|
51 |
+
def decode(self, z, state):
|
52 |
+
x = torch.cat([z, state], dim=-1)
|
53 |
+
result = self.decoder(x)
|
54 |
+
result = self.final_layer(result)
|
55 |
+
return result
|
56 |
+
|
57 |
+
def reparameterize(self, mu, logvar):
|
58 |
+
"""
|
59 |
+
Will a single z be enough ti compute the expectation
|
60 |
+
for the loss??
|
61 |
+
:param mu: (Tensor) Mean of the latent Gaussian
|
62 |
+
:param logvar: (Tensor) Standard deviation of the latent Gaussian
|
63 |
+
:return:
|
64 |
+
"""
|
65 |
+
std = torch.exp(0.5 * logvar)
|
66 |
+
eps = torch.randn_like(std)
|
67 |
+
return eps * std + mu
|
68 |
+
|
69 |
+
def loss(self, action, state):
|
70 |
+
mu, log_var = self.encode(action, state)
|
71 |
+
z = self.reparameterize(mu, log_var)
|
72 |
+
recons = self.decode(z, state)
|
73 |
+
|
74 |
+
kld_weight = 0.1 # Account for the minibatch samples from the dataset
|
75 |
+
recons_loss = F.mse_loss(recons, action)
|
76 |
+
|
77 |
+
kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
|
78 |
+
|
79 |
+
# print('recons_loss: ', recons_loss)
|
80 |
+
# print('kld_loss: ', kld_loss)
|
81 |
+
|
82 |
+
loss = recons_loss + kld_weight * kld_loss
|
83 |
+
return loss
|
84 |
+
|
85 |
+
def forward(self, state, eval=False):
|
86 |
+
batch_size = state.shape[0]
|
87 |
+
shape = (batch_size, self.hidden_size)
|
88 |
+
|
89 |
+
if eval:
|
90 |
+
z = torch.zeros(shape, device=self.device)
|
91 |
+
else:
|
92 |
+
z = torch.randn(shape, device=self.device)
|
93 |
+
samples = self.decode(z, state)
|
94 |
+
|
95 |
+
return samples.clamp(-1., 1.)
|
main.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from agent.DiPo import DiPo
|
6 |
+
from agent.replay_memory import ReplayMemory, DiffusionMemory
|
7 |
+
|
8 |
+
from tensorboardX import SummaryWriter
|
9 |
+
import gym
|
10 |
+
import os
|
11 |
+
|
12 |
+
|
13 |
+
def readParser():
|
14 |
+
parser = argparse.ArgumentParser(description='Diffusion Policy')
|
15 |
+
parser.add_argument('--env_name', default="Hopper-v3",
|
16 |
+
help='Mujoco Gym environment (default: Hopper-v3)')
|
17 |
+
parser.add_argument('--seed', type=int, default=0, metavar='N',
|
18 |
+
help='random seed (default: 0)')
|
19 |
+
|
20 |
+
parser.add_argument('--num_steps', type=int, default=1000000, metavar='N',
|
21 |
+
help='env timesteps (default: 1000000)')
|
22 |
+
|
23 |
+
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
|
24 |
+
help='batch size (default: 256)')
|
25 |
+
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
|
26 |
+
help='discount factor for reward (default: 0.99)')
|
27 |
+
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
|
28 |
+
help='target smoothing coefficient(τ) (default: 0.005)')
|
29 |
+
parser.add_argument('--update_actor_target_every', type=int, default=1, metavar='N',
|
30 |
+
help='update actor target per iteration (default: 1)')
|
31 |
+
|
32 |
+
parser.add_argument("--policy_type", type=str, default="Diffusion", metavar='S',
|
33 |
+
help="Diffusion, VAE or MLP")
|
34 |
+
parser.add_argument("--beta_schedule", type=str, default="cosine", metavar='S',
|
35 |
+
help="linear, cosine or vp")
|
36 |
+
parser.add_argument('--n_timesteps', type=int, default=100, metavar='N',
|
37 |
+
help='diffusion timesteps (default: 100)')
|
38 |
+
parser.add_argument('--diffusion_lr', type=float, default=0.0003, metavar='G',
|
39 |
+
help='diffusion learning rate (default: 0.0003)')
|
40 |
+
parser.add_argument('--critic_lr', type=float, default=0.0003, metavar='G',
|
41 |
+
help='critic learning rate (default: 0.0003)')
|
42 |
+
parser.add_argument('--action_lr', type=float, default=0.03, metavar='G',
|
43 |
+
help='diffusion learning rate (default: 0.03)')
|
44 |
+
parser.add_argument('--noise_ratio', type=float, default=1.0, metavar='G',
|
45 |
+
help='noise ratio in sample process (default: 1.0)')
|
46 |
+
|
47 |
+
parser.add_argument('--action_gradient_steps', type=int, default=20, metavar='N',
|
48 |
+
help='action gradient steps (default: 20)')
|
49 |
+
parser.add_argument('--ratio', type=float, default=0.1, metavar='G',
|
50 |
+
help='the ratio of action grad norm to action_dim (default: 0.1)')
|
51 |
+
parser.add_argument('--ac_grad_norm', type=float, default=2.0, metavar='G',
|
52 |
+
help='actor and critic grad norm (default: 1.0)')
|
53 |
+
|
54 |
+
parser.add_argument('--cuda', default='cuda:0',
|
55 |
+
help='run on CUDA (default: cuda:0)')
|
56 |
+
|
57 |
+
return parser.parse_args()
|
58 |
+
|
59 |
+
|
60 |
+
def evaluate(env, agent, writer, steps):
|
61 |
+
episodes = 10
|
62 |
+
returns = np.zeros((episodes,), dtype=np.float32)
|
63 |
+
|
64 |
+
for i in range(episodes):
|
65 |
+
state = env.reset()
|
66 |
+
episode_reward = 0.
|
67 |
+
done = False
|
68 |
+
while not done:
|
69 |
+
action = agent.sample_action(state, eval=True)
|
70 |
+
next_state, reward, done, _ = env.step(action)
|
71 |
+
episode_reward += reward
|
72 |
+
state = next_state
|
73 |
+
returns[i] = episode_reward
|
74 |
+
|
75 |
+
mean_return = np.mean(returns)
|
76 |
+
|
77 |
+
writer.add_scalar(
|
78 |
+
'reward/test', mean_return, steps)
|
79 |
+
print('-' * 60)
|
80 |
+
print(f'Num steps: {steps:<5} '
|
81 |
+
f'reward: {mean_return:<5.1f}')
|
82 |
+
print('-' * 60)
|
83 |
+
|
84 |
+
|
85 |
+
def main(args=None):
|
86 |
+
if args is None:
|
87 |
+
args = readParser()
|
88 |
+
|
89 |
+
device = torch.device(args.cuda)
|
90 |
+
|
91 |
+
dir = "record"
|
92 |
+
# dir = "test"
|
93 |
+
log_dir = os.path.join(dir, f'{args.env_name}', f'policy_type={args.policy_type}', f'ratio={args.ratio}', f'seed={args.seed}')
|
94 |
+
writer = SummaryWriter(log_dir)
|
95 |
+
|
96 |
+
# Initial environment
|
97 |
+
env = gym.make(args.env_name)
|
98 |
+
state_size = int(np.prod(env.observation_space.shape))
|
99 |
+
action_size = int(np.prod(env.action_space.shape))
|
100 |
+
print(action_size)
|
101 |
+
|
102 |
+
# Set random seed
|
103 |
+
torch.manual_seed(args.seed)
|
104 |
+
np.random.seed(args.seed)
|
105 |
+
env.seed(args.seed)
|
106 |
+
|
107 |
+
memory_size = 1e6
|
108 |
+
num_steps = args.num_steps
|
109 |
+
start_steps = 10000
|
110 |
+
eval_interval = 10000
|
111 |
+
updates_per_step = 1
|
112 |
+
batch_size = args.batch_size
|
113 |
+
log_interval = 10
|
114 |
+
|
115 |
+
memory = ReplayMemory(state_size, action_size, memory_size, device)
|
116 |
+
diffusion_memory = DiffusionMemory(state_size, action_size, memory_size, device)
|
117 |
+
|
118 |
+
agent = DiPo(args, state_size, env.action_space, memory, diffusion_memory, device)
|
119 |
+
|
120 |
+
steps = 0
|
121 |
+
episodes = 0
|
122 |
+
|
123 |
+
while steps < num_steps:
|
124 |
+
episode_reward = 0.
|
125 |
+
episode_steps = 0
|
126 |
+
done = False
|
127 |
+
state = env.reset()
|
128 |
+
episodes += 1
|
129 |
+
while not done:
|
130 |
+
if start_steps > steps:
|
131 |
+
action = env.action_space.sample()
|
132 |
+
else:
|
133 |
+
action = agent.sample_action(state, eval=False)
|
134 |
+
next_state, reward, done, _ = env.step(action)
|
135 |
+
|
136 |
+
mask = 0.0 if done else args.gamma
|
137 |
+
|
138 |
+
steps += 1
|
139 |
+
episode_steps += 1
|
140 |
+
episode_reward += reward
|
141 |
+
|
142 |
+
agent.append_memory(state, action, reward, next_state, mask)
|
143 |
+
|
144 |
+
if steps >= start_steps:
|
145 |
+
agent.train(updates_per_step, batch_size=batch_size, log_writer=writer)
|
146 |
+
|
147 |
+
if steps % eval_interval == 0:
|
148 |
+
evaluate(env, agent, writer, steps)
|
149 |
+
# self.save_models()
|
150 |
+
done =True
|
151 |
+
|
152 |
+
state = next_state
|
153 |
+
|
154 |
+
# if episodes % log_interval == 0:
|
155 |
+
# writer.add_scalar('reward/train', episode_reward, steps)
|
156 |
+
|
157 |
+
print(f'episode: {episodes:<4} '
|
158 |
+
f'episode steps: {episode_steps:<4} '
|
159 |
+
f'reward: {episode_reward:<5.1f}')
|
160 |
+
|
161 |
+
|
162 |
+
if __name__ == "__main__":
|
163 |
+
main()
|
requirements.txt
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==1.4.0
|
2 |
+
ale-py==0.8.1
|
3 |
+
asttokens==2.2.1
|
4 |
+
attrs==22.2.0
|
5 |
+
backcall==0.2.0
|
6 |
+
box2d-py==2.3.5
|
7 |
+
cachetools==5.3.0
|
8 |
+
certifi==2022.12.7
|
9 |
+
cffi==1.15.1
|
10 |
+
charset-normalizer==3.1.0
|
11 |
+
click==8.1.3
|
12 |
+
cloudpickle==2.2.1
|
13 |
+
cmake==3.26.0
|
14 |
+
contourpy==1.0.7
|
15 |
+
cycler==0.11.0
|
16 |
+
Cython==0.29.33
|
17 |
+
decorator==4.4.2
|
18 |
+
docopt==0.6.2
|
19 |
+
executing==1.2.0
|
20 |
+
fasteners==0.18
|
21 |
+
filelock==3.10.0
|
22 |
+
fonttools==4.39.2
|
23 |
+
glfw==2.5.7
|
24 |
+
grpcio==1.51.3
|
25 |
+
gym==0.21.0
|
26 |
+
gym-notices==0.0.8
|
27 |
+
h5py==3.8.0
|
28 |
+
idna==3.4
|
29 |
+
imageio==2.26.0
|
30 |
+
imageio-ffmpeg==0.4.8
|
31 |
+
importlib-metadata==4.13.0
|
32 |
+
importlib-resources==5.12.0
|
33 |
+
iniconfig==2.0.0
|
34 |
+
ipython==8.11.0
|
35 |
+
jedi==0.18.2
|
36 |
+
Jinja2==3.1.2
|
37 |
+
kiwisolver==1.4.4
|
38 |
+
labmaze==1.0.6
|
39 |
+
lit==15.0.7
|
40 |
+
lxml==4.9.2
|
41 |
+
lz4==4.3.2
|
42 |
+
Markdown==3.4.1
|
43 |
+
MarkupSafe==2.1.2
|
44 |
+
matplotlib==3.7.1
|
45 |
+
matplotlib-inline==0.1.6
|
46 |
+
mjrl @ git+https://github.com/aravindr93/mjrl@3871d93763d3b49c4741e6daeaebbc605fe140dc
|
47 |
+
moviepy==1.0.3
|
48 |
+
mpmath==1.3.0
|
49 |
+
mujoco==2.3.2
|
50 |
+
mujoco-py==2.1.2.14
|
51 |
+
networkx==3.0
|
52 |
+
numpy==1.24.2
|
53 |
+
oauthlib==3.2.2
|
54 |
+
packaging==23.0
|
55 |
+
pandas==1.5.3
|
56 |
+
parso==0.8.3
|
57 |
+
pexpect==4.8.0
|
58 |
+
pickleshare==0.7.5
|
59 |
+
Pillow==9.4.0
|
60 |
+
pipreqs==0.4.13
|
61 |
+
pluggy==1.0.0
|
62 |
+
proglog==0.1.10
|
63 |
+
prompt-toolkit==3.0.38
|
64 |
+
protobuf==3.20.3
|
65 |
+
ptyprocess==0.7.0
|
66 |
+
pure-eval==0.2.2
|
67 |
+
py==1.11.0
|
68 |
+
pyasn1==0.4.8
|
69 |
+
pyasn1-modules==0.2.8
|
70 |
+
pybullet==3.2.5
|
71 |
+
pycparser==2.21
|
72 |
+
pygame==2.1.0
|
73 |
+
Pygments==2.14.0
|
74 |
+
PyOpenGL==3.1.6
|
75 |
+
pyparsing==3.0.9
|
76 |
+
pytest==7.0.1
|
77 |
+
python-dateutil==2.8.2
|
78 |
+
pytz==2022.7.1
|
79 |
+
requests==2.28.2
|
80 |
+
requests-oauthlib==1.3.1
|
81 |
+
rsa==4.9
|
82 |
+
scipy==1.10.1
|
83 |
+
six==1.16.0
|
84 |
+
stable-baselines3==1.7.0
|
85 |
+
stack-data==0.6.2
|
86 |
+
swig==4.1.1
|
87 |
+
sympy==1.11.1
|
88 |
+
tensorboard==2.12.0
|
89 |
+
tensorboard-data-server==0.7.0
|
90 |
+
tensorboard-plugin-wit==1.8.1
|
91 |
+
tensorboardX==2.6
|
92 |
+
termcolor==2.3.0
|
93 |
+
tomli==2.0.1
|
94 |
+
torch==2.0.0
|
95 |
+
tqdm==4.65.0
|
96 |
+
traitlets==5.9.0
|
97 |
+
triton==2.0.0
|
98 |
+
typing_extensions==4.5.0
|
99 |
+
urllib3==1.26.15
|
100 |
+
wcwidth==0.2.6
|
101 |
+
Werkzeug==2.2.3
|
102 |
+
yarg==0.1.9
|
103 |
+
zipp==3.15.0
|
run_dipo
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Script to reproduce results
|
4 |
+
|
5 |
+
envs=(Hopper-v3 Walker2d-v3 Ant-v3 HalfCheetah-v3 Humanoid-v3)
|
6 |
+
steps=(1000000 1000000 3000000 3000000 10000000)
|
7 |
+
cnt=0
|
8 |
+
i=3
|
9 |
+
n_timesteps=100
|
10 |
+
for ((j=0;j<5;j+=1))
|
11 |
+
do
|
12 |
+
nohup python -u main.py \
|
13 |
+
--env_name ${envs[i]} \
|
14 |
+
--num_steps 1000000 \
|
15 |
+
--policy_type 'MLP' \
|
16 |
+
--beta_schedule 'cosine' \
|
17 |
+
--n_timesteps ${n_timesteps}\
|
18 |
+
--ratio 0.08 \
|
19 |
+
--ac_grad_norm 2 \
|
20 |
+
--action_gradient_steps 40 \
|
21 |
+
--update_actor_target_every 2 \
|
22 |
+
--seed $j \
|
23 |
+
--cuda "cuda:${cnt}" \
|
24 |
+
> "log/MLP-a_steps=40-%2-${envs[i]}-seed=${j}.log" 2>&1 &
|
25 |
+
done
|