import torch
import torch.nn as nn
import torch.nn.functional as F

from agent.helpers import SinusoidalPosEmb, init_weights


class Critic(nn.Module):
    """Twin Q-networks: two independent state-action value heads with the
    same architecture, enabling clipped double Q-learning via q_min."""

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.q1_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, 1))

        self.q2_model = nn.Sequential(nn.Linear(state_dim + action_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, hidden_dim),
                                      nn.Mish(),
                                      nn.Linear(hidden_dim, 1))

        self.apply(init_weights)

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x), self.q2_model(x)

    def q1(self, state, action):
        x = torch.cat([state, action], dim=-1)
        return self.q1_model(x)

    def q_min(self, state, action):
        # Pessimistic estimate: element-wise minimum of the two Q-heads.
        q1, q2 = self.forward(state, action)
        return torch.min(q1, q2)


class Model(nn.Module):
    """Conditional denoising network for a diffusion policy: predicts the
    noise for a noisy action `x` given the diffusion timestep `time` and
    the conditioning `state`."""

    def __init__(self, state_dim, action_dim, hidden_size=256, time_dim=32):
        super().__init__()

        # Sinusoidal embedding of the diffusion timestep, refined by a
        # small MLP and projected back down to time_dim features.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, time_dim),
        )

        input_dim = state_dim + action_dim + time_dim
        self.layer = nn.Sequential(nn.Linear(input_dim, hidden_size),
                                   nn.Mish(),
                                   nn.Linear(hidden_size, hidden_size),
                                   nn.Mish(),
                                   nn.Linear(hidden_size, hidden_size),
                                   nn.Mish(),
                                   nn.Linear(hidden_size, action_dim))

        self.apply(init_weights)

    def forward(self, x, time, state):
        t = self.time_mlp(time)
        out = torch.cat([x, t, state], dim=-1)
        return self.layer(out)


class MLP(nn.Module):
    """Deterministic MLP policy baseline. At training time, isotropic
    Gaussian noise (std 0.1) is added to the output for exploration;
    pass eval=True for the noise-free action."""

    def __init__(self, state_dim, action_dim, hidden_size=256):
        super().__init__()

        self.mid_layer = nn.Sequential(nn.Linear(state_dim, hidden_size),
                                       nn.Mish(),
                                       nn.Linear(hidden_size, hidden_size),
                                       nn.Mish(),
                                       nn.Linear(hidden_size, hidden_size),
                                       nn.Mish())

        self.final_layer = nn.Linear(hidden_size, action_dim)

        self.apply(init_weights)

    def forward(self, state, eval=False):
        out = self.mid_layer(state)
        out = self.final_layer(out)

        if not eval:
            # Exploration noise; out-of-place add keeps autograd safe.
            out = out + torch.randn_like(out) * 0.1

        return out

    def loss(self, action, state):
        # Behavior-cloning objective: regress the policy output onto the
        # dataset action (noise is included since eval defaults to False).
        return F.mse_loss(self.forward(state), action, reduction='mean')
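
# Minimal smoke-test sketch, runnable as a script. The dimensions below are
# toy values chosen for illustration, and it assumes SinusoidalPosEmb maps a
# (batch,) tensor of timesteps to (batch, time_dim) features, as its use in
# Model.forward implies.
if __name__ == "__main__":
    state_dim, action_dim, batch = 17, 6, 4
    state = torch.randn(batch, state_dim)
    action = torch.randn(batch, action_dim)

    # Twin Q-heads each produce a (batch, 1) value estimate.
    critic = Critic(state_dim, action_dim)
    q1, q2 = critic(state, action)
    assert q1.shape == q2.shape == (batch, 1)
    assert critic.q_min(state, action).shape == (batch, 1)

    # The denoiser consumes a noisy action, a timestep, and the state.
    model = Model(state_dim, action_dim)
    t = torch.randint(0, 100, (batch,))           # random diffusion timesteps
    noisy_action = torch.randn(batch, action_dim)
    assert model(noisy_action, t, state).shape == (batch, action_dim)

    # The MLP baseline maps states directly to actions.
    mlp = MLP(state_dim, action_dim)
    assert mlp(state, eval=True).shape == (batch, action_dim)
    print("forward passes OK; BC loss:", mlp.loss(action, state).item())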