# DIPO — agent/model.py
# Networks for the diffusion-policy agent: twin-Q critic, diffusion
# noise-prediction model, and a deterministic MLP policy.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from agent.helpers import SinusoidalPosEmb, init_weights
class Critic(nn.Module):
    """Twin Q-value network (clipped double-Q style).

    Two structurally identical, independently initialized 3-hidden-layer
    Mish MLPs each map a concatenated (state, action) pair to a scalar
    Q-value estimate.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super(Critic, self).__init__()

        def build_q_head():
            # One Q head: (state ++ action) -> hidden x3 (Mish) -> scalar.
            return nn.Sequential(
                nn.Linear(state_dim + action_dim, hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, hidden_dim),
                nn.Mish(),
                nn.Linear(hidden_dim, 1),
            )

        # Two independent heads (fresh random parameters each).
        self.q1_model = build_q_head()
        self.q2_model = build_q_head()
        self.apply(init_weights)

    def forward(self, state, action):
        """Return the pair (Q1, Q2) for a batch of state-action pairs."""
        sa = torch.cat([state, action], dim=-1)
        return self.q1_model(sa), self.q2_model(sa)

    def q1(self, state, action):
        """Return only the first head's Q estimate."""
        sa = torch.cat([state, action], dim=-1)
        return self.q1_model(sa)

    def q_min(self, state, action):
        """Element-wise minimum of the two Q estimates (clipped double-Q)."""
        q1, q2 = self.forward(state, action)
        return torch.min(q1, q2)
class Model(nn.Module):
    """Conditional noise-prediction network for the diffusion policy.

    Embeds the diffusion timestep sinusoidally, concatenates it with the
    noisy action ``x`` and the conditioning ``state``, and maps the result
    through a 3-hidden-layer Mish MLP back to action space.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256, time_dim=32):
        super(Model, self).__init__()
        # Timestep embedding: sinusoidal features refined by a small MLP.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(time_dim),
            nn.Linear(time_dim, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, time_dim),
        )
        # Trunk over the concatenated (action, time-embedding, state) vector.
        self.layer = nn.Sequential(
            nn.Linear(state_dim + action_dim + time_dim, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, action_dim),
        )
        self.apply(init_weights)

    def forward(self, x, time, state):
        """Predict the noise component of ``x`` at diffusion step ``time``."""
        time_emb = self.time_mlp(time)
        return self.layer(torch.cat([x, time_emb, state], dim=-1))
class MLP(nn.Module):
    """Deterministic policy network: maps a state to an action via a
    3-hidden-layer Mish MLP. In training mode (``eval=False``), Gaussian
    exploration noise is added to the output action.
    """

    def __init__(self, state_dim, action_dim, hidden_size=256, noise_scale=0.1):
        """
        Args:
            state_dim: dimensionality of the input state vector.
            action_dim: dimensionality of the output action vector.
            hidden_size: width of the three hidden layers.
            noise_scale: std-dev of the Gaussian exploration noise added when
                ``eval=False``. Default 0.1 preserves the previous hard-coded
                behavior; new callers may tune it.
        """
        super(MLP, self).__init__()
        self.noise_scale = noise_scale
        self.mid_layer = nn.Sequential(
            nn.Linear(state_dim, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, hidden_size),
            nn.Mish(),
            nn.Linear(hidden_size, hidden_size),
            nn.Mish(),
        )
        self.final_layer = nn.Linear(hidden_size, action_dim)
        self.apply(init_weights)

    def forward(self, state, eval=False):
        # NOTE: `eval` shadows the builtin but is kept for caller compatibility.
        out = self.final_layer(self.mid_layer(state))
        if not eval:
            # Exploration noise for training-time action selection.
            out = out + torch.randn_like(out) * self.noise_scale
        return out

    def loss(self, action, state):
        """Mean-squared error between the (noisy) predicted and target action.

        Note the noise in forward() makes this a stochastic objective; it is
        intentional (the original implementation behaves the same way).
        """
        return F.mse_loss(self.forward(state), action, reduction='mean')