import copy
import datetime
import random
import time
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


class DQNet(nn.Module):
    """Small fully-connected Q-network with separate online and target copies."""

    def __init__(self, input_dim, output_dim):
        super().__init__()

        self.online = nn.Sequential(
            nn.Linear(input_dim, 100),
            nn.ReLU(),
            nn.Linear(100, 120),
            nn.ReLU(),
            nn.Linear(120, output_dim),
        )

        # The target network starts as an exact copy of the online network and
        # is only updated by syncing weights, never by gradient descent.
        self.target = copy.deepcopy(self.online)

        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, state, model):
        if model == "online":
            return self.online(state)
        elif model == "target":
            return self.target(state)
        raise ValueError(f"Unknown model '{model}', expected 'online' or 'target'")


class MetricLogger:
    def __init__(self, save_dir):
        self.writer = SummaryWriter(log_dir=save_dir)
        self.save_log = save_dir / "log"
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # Per-episode history
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages over the last 100 episodes, appended on each record() call
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current-episode accumulators
        self.init_episode()

        # Timing
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        if loss is not None:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self, episode_number):
        """Mark the end of an episode and reset the per-episode accumulators."""
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)
        self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
        self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
        self.writer.flush()
        self.init_episode()

    def init_episode(self):
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )
        self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
        self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
        self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
        self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
        self.writer.add_scalar("Epsilon value", epsilon, episode)
        self.writer.flush()
        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()


class DQNAgent:
    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        # Epsilon-greedy exploration schedule
        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts

        self.training_frequency = training_frequency
        self.target_network_sync_frequency = target_network_sync_frequency

        self.save_every = save_frequency
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.net = DQNet(self.state_dim, self.action_dim).float()
        if self.use_cuda:
            self.net = self.net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update the step counter.

        Inputs:
            state (LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
            action_idx (int): An integer representing which action the agent will perform
        """
        if np.random.rand() < self.exploration_rate:
            # Explore: pick a random action
            action_idx = np.random.randint(self.action_dim)
        else:
            # Exploit: pick the action with the highest predicted Q-value
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)
            action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, dim=1).item()

        # Decay the exploration rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """
        Store the experience in self.memory (replay buffer).

        Inputs:
            state (LazyFrame),
            next_state (LazyFrame),
            action (int),
            reward (float),
            done (bool)
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])

        self.memory.append((state, next_state, action, reward, done))

    def recall(self):
        """
        Retrieve a batch of experiences from memory.
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def td_estimate(self, states, actions):
        actions = actions.reshape(-1, 1)
        predicted_qs = self.net(states, model='online')
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        target_qs = self.net(next_states, model='target')
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        # Terminal transitions have no bootstrapped future value
        target_qs[dones] = 0.0
        return rewards + (self.gamma * target_qs)

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate.float(), td_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.net.target.load_state_dict(self.net.online.state_dict())

    def learn(self):
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        state, next_state, action, reward, done = self.recall()

        td_est = self.td_estimate(state, action)

        td_tgt = self.td_target(reward, next_state, done)

        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.net.load_state_dict(state_dict)

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate


class DDQNAgent(DQNAgent):
    """Double DQN: the online network selects the next action, the target network evaluates it."""

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        # Action selection with the online network
        q_vals = self.net(next_states, model='online')
        target_actions = torch.argmax(q_vals, dim=1)
        target_actions = target_actions.reshape(-1, 1)

        # Action evaluation with the target network
        target_qs = self.net(next_states, model='target')
        target_qs = target_qs.gather(1, target_actions)
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        return rewards + (self.gamma * target_qs)


class DuelingDQNet(nn.Module):
    """Dueling architecture: a shared feature trunk feeds separate state-value and advantage heads."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.feature_layer = nn.Sequential(
            nn.Linear(input_dim, 150),
            nn.ReLU(),
            nn.Linear(150, 120),
            nn.ReLU()
        )

        self.value_layer = nn.Sequential(
            nn.Linear(120, 120),
            nn.ReLU(),
            nn.Linear(120, 1)
        )

        self.advantage_layer = nn.Sequential(
            nn.Linear(120, 120),
            nn.ReLU(),
            nn.Linear(120, output_dim)
        )

    def forward(self, state):
        feature_output = self.feature_layer(state)

        value = self.value_layer(feature_output)
        advantage = self.advantage_layer(feature_output)
        # Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)); the mean is taken over the
        # action dimension for each state in the batch.
        q_value = value + (advantage - advantage.mean(dim=1, keepdim=True))

        return q_value


class DuelingDQNAgent:
    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        # Epsilon-greedy exploration schedule
        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts

        self.training_frequency = training_frequency
        self.target_network_sync_frequency = target_network_sync_frequency

        self.save_every = save_frequency
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
        self.target_net = copy.deepcopy(self.online_net)

        for p in self.target_net.parameters():
            p.requires_grad = False

        if self.use_cuda:
            self.online_net = self.online_net.to(device='cuda')
            self.target_net = self.target_net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update the step counter.

        Inputs:
            state (LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
            action_idx (int): An integer representing which action the agent will perform
        """
        if np.random.rand() < self.exploration_rate:
            # Explore: pick a random action
            action_idx = np.random.randint(self.action_dim)
        else:
            # Exploit: pick the action with the highest predicted Q-value
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)
            action_values = self.online_net(state)
            action_idx = torch.argmax(action_values, dim=1).item()

        # Decay the exploration rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """
        Store the experience in self.memory (replay buffer).

        Inputs:
            state (LazyFrame),
            next_state (LazyFrame),
            action (int),
            reward (float),
            done (bool)
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])

        self.memory.append((state, next_state, action, reward, done))

    def recall(self):
        """
        Retrieve a batch of experiences from memory.
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def td_estimate(self, states, actions):
        actions = actions.reshape(-1, 1)
        predicted_qs = self.online_net(states)
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        target_qs = self.target_net(next_states)
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        # Terminal transitions have no bootstrapped future value
        target_qs[dones] = 0.0
        return rewards + (self.gamma * target_qs)

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate.float(), td_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def learn(self):
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        state, next_state, action, reward, done = self.recall()

        td_est = self.td_estimate(state, action)

        td_tgt = self.td_target(reward, next_state, done)

        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.online_net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.online_net.load_state_dict(state_dict)
        # Copy the restored weights into the frozen target network
        self.sync_Q_target()

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate


class DuelingDDQNAgent(DuelingDQNAgent):
    """Dueling Double DQN: dueling network heads combined with the Double DQN target."""

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        # Action selection with the online network
        q_vals = self.online_net(next_states)
        target_actions = torch.argmax(q_vals, dim=1)
        target_actions = target_actions.reshape(-1, 1)

        # Action evaluation with the target network
        target_qs = self.target_net(next_states)
        target_qs = target_qs.gather(1, target_actions)
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        return rewards + (self.gamma * target_qs)


class DQNAgentWithStepDecay:
    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        # Epsilon-greedy exploration schedule
        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts

        self.training_frequency = training_frequency
        self.target_network_sync_frequency = target_network_sync_frequency

        self.save_every = save_frequency
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.net = DQNet(self.state_dim, self.action_dim).float()
        if self.use_cuda:
            self.net = self.net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update the step counter.

        Inputs:
            state (LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
            action_idx (int): An integer representing which action the agent will perform
        """
        if np.random.rand() < self.exploration_rate:
            # Explore: pick a random action
            action_idx = np.random.randint(self.action_dim)
        else:
            # Exploit: pick the action with the highest predicted Q-value
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)
            action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, dim=1).item()

        # Decay the exploration rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """
        Store the experience in self.memory (replay buffer).

        Inputs:
            state (LazyFrame),
            next_state (LazyFrame),
            action (int),
            reward (float),
            done (bool)
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])

        self.memory.append((state, next_state, action, reward, done))

    def recall(self):
        """
        Retrieve a batch of experiences from memory.
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def td_estimate(self, states, actions):
        actions = actions.reshape(-1, 1)
        predicted_qs = self.net(states, model='online')
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        target_qs = self.net(next_states, model='target')
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        # Terminal transitions have no bootstrapped future value
        target_qs[dones] = 0.0
        return rewards + (self.gamma * target_qs)

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate.float(), td_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.net.target.load_state_dict(self.net.online.state_dict())

    def learn(self):
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        state, next_state, action, reward, done = self.recall()

        td_est = self.td_estimate(state, action)

        td_tgt = self.td_target(reward, next_state, done)

        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.net.load_state_dict(state_dict)

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate
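

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original training script).
# Assumptions: the `gymnasium` package and its CartPole-v1 environment are
# installed; any flat-observation, discrete-action environment would work the
# same way, and the "checkpoints" directory name is purely illustrative.
# The sketch shows how act/cache/learn interact with the MetricLogger.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from pathlib import Path

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    save_dir = Path("checkpoints")
    save_dir.mkdir(parents=True, exist_ok=True)

    agent = DDQNAgent(
        state_dim=env.observation_space.shape[0],
        action_dim=env.action_space.n,
        save_dir=save_dir,
    )
    logger = MetricLogger(save_dir)

    for episode in range(10):
        state, _ = env.reset()
        done = False
        while not done:
            # Epsilon-greedy action, then store the transition and (maybe) train
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            agent.cache(state, next_state, action, reward, done)
            q, loss = agent.learn()  # returns (None, None) before learning_starts
            logger.log_step(reward, loss, q)

            state = next_state

        logger.log_episode(episode)
        if episode % 5 == 0:
            logger.record(episode=episode, epsilon=agent.exploration_rate, step=agent.curr_step)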