import torch
import numpy as np
import random
import torch.nn as nn
import copy
import time, datetime
import matplotlib.pyplot as plt
from collections import deque
from torch.utils.tensorboard import SummaryWriter
class DQNet(nn.Module):
"""mini cnn structure"""
def __init__(self, input_dim, output_dim):
super().__init__()
self.online = nn.Sequential(
nn.Linear(input_dim, 100),
nn.ReLU(),
nn.Linear(100, 120),
nn.ReLU(),
nn.Linear(120, output_dim),
)
self.target = copy.deepcopy(self.online)
# Q_target parameters are frozen.
for p in self.target.parameters():
p.requires_grad = False
    def forward(self, x, model):
        if model == "online":
            return self.online(x)
        if model == "target":
            return self.target(x)
        raise ValueError(f"Unknown model: {model!r}")
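
# A minimal forward-pass sketch (not part of the original file): the 8-dim
# observations and 4 discrete actions assume Gym's LunarLander-v2 spaces.
# "online" is the trained head; "target" is its frozen copy used for TD targets.
def _example_dqnet_forward():
    net = DQNet(input_dim=8, output_dim=4)
    obs = torch.randn(1, 8)              # dummy batch holding a single state
    q_online = net(obs, model="online")  # shape (1, 4); gradients flow here
    q_target = net(obs, model="target")  # shape (1, 4); parameters are frozen
    return q_online, q_target
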
class MetricLogger:
def __init__(self, save_dir):
self.writer = SummaryWriter(log_dir=save_dir)
self.save_log = save_dir / "log"
with open(self.save_log, "w") as f:
f.write(
f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
f"{'TimeDelta':>15}{'Time':>20}\n"
)
self.ep_rewards_plot = save_dir / "reward_plot.jpg"
self.ep_lengths_plot = save_dir / "length_plot.jpg"
self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
# History metrics
self.ep_rewards = []
self.ep_lengths = []
self.ep_avg_losses = []
self.ep_avg_qs = []
# Moving averages, added for every call to record()
self.moving_avg_ep_rewards = []
self.moving_avg_ep_lengths = []
self.moving_avg_ep_avg_losses = []
self.moving_avg_ep_avg_qs = []
# Current episode metric
self.init_episode()
# Timing
self.record_time = time.time()
def log_step(self, reward, loss, q):
self.curr_ep_reward += reward
self.curr_ep_length += 1
        if loss is not None:  # loss is None on steps without a gradient update
self.curr_ep_loss += loss
self.curr_ep_q += q
self.curr_ep_loss_length += 1
def log_episode(self, episode_number):
"Mark end of episode"
self.ep_rewards.append(self.curr_ep_reward)
self.ep_lengths.append(self.curr_ep_length)
if self.curr_ep_loss_length == 0:
ep_avg_loss = 0
ep_avg_q = 0
else:
ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
self.ep_avg_losses.append(ep_avg_loss)
self.ep_avg_qs.append(ep_avg_q)
self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
self.writer.flush()
self.init_episode()
def init_episode(self):
self.curr_ep_reward = 0.0
self.curr_ep_length = 0
self.curr_ep_loss = 0.0
self.curr_ep_q = 0.0
self.curr_ep_loss_length = 0
def record(self, episode, epsilon, step):
mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
self.moving_avg_ep_rewards.append(mean_ep_reward)
self.moving_avg_ep_lengths.append(mean_ep_length)
self.moving_avg_ep_avg_losses.append(mean_ep_loss)
self.moving_avg_ep_avg_qs.append(mean_ep_q)
last_record_time = self.record_time
self.record_time = time.time()
time_since_last_record = np.round(self.record_time - last_record_time, 3)
print(
f"Episode {episode} - "
f"Step {step} - "
f"Epsilon {epsilon} - "
f"Mean Reward {mean_ep_reward} - "
f"Mean Length {mean_ep_length} - "
f"Mean Loss {mean_ep_loss} - "
f"Mean Q Value {mean_ep_q} - "
f"Time Delta {time_since_last_record} - "
f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
)
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
self.writer.add_scalar("Epsilon value", epsilon, episode)
self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
self.writer.flush()
with open(self.save_log, "a") as f:
f.write(
f"{episode:8d}{step:8d}{epsilon:10.3f}"
f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
f"{time_since_last_record:15.3f}"
f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
)
for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
plt.plot(getattr(self, f"moving_avg_{metric}"))
plt.savefig(getattr(self, f"{metric}_plot"))
plt.clf()
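
# A minimal usage sketch (hypothetical directory name, dummy metric values):
# the logger is driven once per environment step, once per finished episode,
# and record() periodically prints/plots 100-episode moving averages.
def _example_metric_logger():
    from pathlib import Path
    run_dir = Path("runs/example")
    run_dir.mkdir(parents=True, exist_ok=True)
    logger = MetricLogger(run_dir)
    logger.log_step(reward=1.0, loss=0.05, q=0.3)  # one environment step
    logger.log_episode(episode_number=0)           # close out the episode
    logger.record(episode=0, epsilon=1.0, step=1)  # periodic summary line
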
class DQNAgent:
def __init__(self,
state_dim,
action_dim,
save_dir,
checkpoint=None,
learning_rate=0.00025,
max_memory_size=100000,
batch_size=32,
exploration_rate=1,
exploration_rate_decay=0.9999999,
exploration_rate_min=0.1,
training_frequency=1,
learning_starts=1000,
target_network_sync_frequency=500,
reset_exploration_rate=False,
save_frequency=100000,
gamma=0.9,
load_replay_buffer=True):
self.state_dim = state_dim
self.action_dim = action_dim
self.max_memory_size = max_memory_size
self.memory = deque(maxlen=max_memory_size)
self.batch_size = batch_size
self.exploration_rate = exploration_rate
self.exploration_rate_decay = exploration_rate_decay
self.exploration_rate_min = exploration_rate_min
self.gamma = gamma
self.curr_step = 0
self.learning_starts = learning_starts # min. experiences before training
self.training_frequency = training_frequency # no. of experiences between updates to Q_online
self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
self.save_every = save_frequency # no. of experiences between saving the network
self.save_dir = save_dir
self.use_cuda = torch.cuda.is_available()
self.net = DQNet(self.state_dim, self.action_dim).float()
if self.use_cuda:
self.net = self.net.to(device='cuda')
if checkpoint:
self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
self.loss_fn = torch.nn.SmoothL1Loss()
# self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate)
# self.loss_fn = torch.nn.MSELoss()
def act(self, state):
"""
Given a state, choose an epsilon-greedy action and update value of step.
Inputs:
state(LazyFrame): A single observation of the current state, dimension is (state_dim)
Outputs:
action_idx (int): An integer representing which action the agent will perform
"""
# EXPLORE
if np.random.rand() < self.exploration_rate:
action_idx = np.random.randint(self.action_dim)
# EXPLOIT
else:
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
state = state.unsqueeze(0)
action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, dim=1).item()
# decrease exploration_rate
self.exploration_rate *= self.exploration_rate_decay
self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
# increment step
self.curr_step += 1
return action_idx
def cache(self, state, next_state, action, reward, done):
"""
Store the experience to self.memory (replay buffer)
Inputs:
state (LazyFrame),
next_state (LazyFrame),
action (int),
reward (float),
done(bool))
"""
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
        self.memory.append((state, next_state, action, reward, done))
def recall(self):
"""
Retrieve a batch of experiences from memory
"""
batch = random.sample(self.memory, self.batch_size)
state, next_state, action, reward, done = map(torch.stack, zip(*batch))
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
def td_estimate(self, states, actions):
actions = actions.reshape(-1, 1)
        predicted_qs = self.net(states, model='online')  # Q_online(s, a)
predicted_qs = predicted_qs.gather(1, actions)
return predicted_qs
@torch.no_grad()
def td_target(self, rewards, next_states, dones):
rewards = rewards.reshape(-1, 1)
dones = dones.reshape(-1, 1)
target_qs = self.net(next_states, model='target')
target_qs = torch.max(target_qs, dim=1).values
target_qs = target_qs.reshape(-1, 1)
target_qs[dones] = 0.0
return (rewards + (self.gamma * target_qs))
    def update_Q_online(self, td_estimate, td_target):
loss = self.loss_fn(td_estimate.float(), td_target.float())
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
def sync_Q_target(self):
self.net.target.load_state_dict(self.net.online.state_dict())
def learn(self):
if self.curr_step % self.target_network_sync_frequency == 0:
self.sync_Q_target()
if self.curr_step % self.save_every == 0:
self.save()
if self.curr_step < self.learning_starts:
return None, None
if self.curr_step % self.training_frequency != 0:
return None, None
# Sample from memory
state, next_state, action, reward, done = self.recall()
# Get TD Estimate
td_est = self.td_estimate(state, action)
# Get TD Target
td_tgt = self.td_target(reward, next_state, done)
# Backpropagate loss through Q_online
loss = self.update_Q_online(td_est, td_tgt)
return (td_est.mean().item(), loss)
def save(self):
save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
torch.save(
dict(
model=self.net.state_dict(),
exploration_rate=self.exploration_rate,
replay_memory=self.memory
),
save_path
)
print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
def load(self, load_path, reset_exploration_rate, load_replay_buffer):
if not load_path.exists():
raise ValueError(f"{load_path} does not exist")
ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
exploration_rate = ckp.get('exploration_rate')
state_dict = ckp.get('model')
print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
self.net.load_state_dict(state_dict)
if load_replay_buffer:
replay_memory = ckp.get('replay_memory')
print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
self.memory = replay_memory if replay_memory else self.memory
if reset_exploration_rate:
print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
else:
print(f"Setting exploration rate to {exploration_rate} not loaded.")
self.exploration_rate = exploration_rate
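
# A minimal interaction sketch (dummy transition, hypothetical directory;
# the 8/4 sizes assume LunarLander-v2): one act -> cache -> learn cycle.
# learning_starts and batch_size are shrunk to 1 purely so a single stored
# transition is enough to trigger a gradient step.
def _example_dqn_agent_step():
    from pathlib import Path
    agent = DQNAgent(state_dim=8, action_dim=4, save_dir=Path("checkpoints"),
                     learning_starts=1, batch_size=1)
    state = np.zeros(8, dtype=np.float32)
    action = agent.act(state)                     # epsilon-greedy action choice
    next_state = np.ones(8, dtype=np.float32)
    agent.cache(state, next_state, action, reward=0.0, done=False)
    q, loss = agent.learn()                       # one TD update on the sample
    return action, q, loss
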
class DDQNAgent(DQNAgent):
@torch.no_grad()
def td_target(self, rewards, next_states, dones):
rewards = rewards.reshape(-1, 1)
dones = dones.reshape(-1, 1)
q_vals = self.net(next_states, model='online')
        target_actions = torch.argmax(q_vals, dim=1)
target_actions = target_actions.reshape(-1, 1)
target_qs = self.net(next_states, model='target')
target_qs = target_qs.gather(1, target_actions)
target_qs = target_qs.reshape(-1, 1)
target_qs[dones] = 0.0
return (rewards + (self.gamma * target_qs))
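
# A sketch of the Double-DQN target above with toy tensors (assumption: the
# same 8/4 LunarLander-style dimensions; save_dir is unused here): the online
# net selects argmax actions, the target net evaluates them, and terminal
# transitions contribute reward only.
def _example_ddqn_target():
    agent = DDQNAgent(state_dim=8, action_dim=4, save_dir=None)
    device = 'cuda' if agent.use_cuda else 'cpu'
    rewards = torch.zeros(3, device=device)
    next_states = torch.randn(3, 8, device=device)
    dones = torch.tensor([False, False, True], device=device)
    return agent.td_target(rewards, next_states, dones)  # shape (3, 1)
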
class DuelingDQNet(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
self.feature_layer = nn.Sequential(
nn.Linear(input_dim, 150),
nn.ReLU(),
nn.Linear(150, 120),
nn.ReLU()
)
self.value_layer = nn.Sequential(
nn.Linear(120, 120),
nn.ReLU(),
nn.Linear(120, 1)
)
self.advantage_layer = nn.Sequential(
nn.Linear(120, 120),
nn.ReLU(),
nn.Linear(120, output_dim)
)
def forward(self, state):
feature_output = self.feature_layer(state)
# feature_output = feature_output.view(feature_output.size(0), -1)
value = self.value_layer(feature_output)
advantage = self.advantage_layer(feature_output)
        # Center advantages per state across actions, not across the batch.
        q_value = value + (advantage - advantage.mean(dim=1, keepdim=True))
return q_value
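
# A short sketch (assumption: LunarLander-style 8-dim states, 4 actions) of
# the dueling decomposition Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)),
# so each state's advantages are centered across its own actions.
def _example_dueling_forward():
    net = DuelingDQNet(input_dim=8, output_dim=4)
    obs = torch.randn(5, 8)  # dummy batch of five states
    q = net(obs)             # shape (5, 4)
    return q
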
class DuelingDQNAgent:
def __init__(self,
state_dim,
action_dim,
save_dir,
checkpoint=None,
learning_rate=0.00025,
max_memory_size=100000,
batch_size=32,
exploration_rate=1,
exploration_rate_decay=0.9999999,
exploration_rate_min=0.1,
training_frequency=1,
learning_starts=1000,
target_network_sync_frequency=500,
reset_exploration_rate=False,
save_frequency=100000,
gamma=0.9,
load_replay_buffer=True):
self.state_dim = state_dim
self.action_dim = action_dim
self.max_memory_size = max_memory_size
self.memory = deque(maxlen=max_memory_size)
self.batch_size = batch_size
self.exploration_rate = exploration_rate
self.exploration_rate_decay = exploration_rate_decay
self.exploration_rate_min = exploration_rate_min
self.gamma = gamma
self.curr_step = 0
self.learning_starts = learning_starts # min. experiences before training
self.training_frequency = training_frequency # no. of experiences between updates to Q_online
self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
self.save_every = save_frequency # no. of experiences between saving the network
self.save_dir = save_dir
self.use_cuda = torch.cuda.is_available()
self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
self.target_net = copy.deepcopy(self.online_net)
# Q_target parameters are frozen.
for p in self.target_net.parameters():
p.requires_grad = False
if self.use_cuda:
            self.online_net = self.online_net.to(device='cuda')
            self.target_net = self.target_net.to(device='cuda')
if checkpoint:
self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
self.loss_fn = torch.nn.SmoothL1Loss()
# self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=learning_rate)
# self.loss_fn = torch.nn.MSELoss()
def act(self, state):
"""
Given a state, choose an epsilon-greedy action and update value of step.
Inputs:
state(LazyFrame): A single observation of the current state, dimension is (state_dim)
Outputs:
action_idx (int): An integer representing which action the agent will perform
"""
# EXPLORE
if np.random.rand() < self.exploration_rate:
action_idx = np.random.randint(self.action_dim)
# EXPLOIT
else:
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
state = state.unsqueeze(0)
action_values = self.online_net(state)
            action_idx = torch.argmax(action_values, dim=1).item()
# decrease exploration_rate
self.exploration_rate *= self.exploration_rate_decay
self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
# increment step
self.curr_step += 1
return action_idx
def cache(self, state, next_state, action, reward, done):
"""
Store the experience to self.memory (replay buffer)
Inputs:
state (LazyFrame),
next_state (LazyFrame),
action (int),
reward (float),
done(bool))
"""
print("####################################")
print(state)
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
        self.memory.append((state, next_state, action, reward, done))
def recall(self):
"""
Retrieve a batch of experiences from memory
"""
batch = random.sample(self.memory, self.batch_size)
state, next_state, action, reward, done = map(torch.stack, zip(*batch))
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
def td_estimate(self, states, actions):
actions = actions.reshape(-1, 1)
        predicted_qs = self.online_net(states)  # Q_online(s, a)
predicted_qs = predicted_qs.gather(1, actions)
return predicted_qs
@torch.no_grad()
def td_target(self, rewards, next_states, dones):
rewards = rewards.reshape(-1, 1)
dones = dones.reshape(-1, 1)
        target_qs = self.target_net(next_states)
target_qs = torch.max(target_qs, dim=1).values
target_qs = target_qs.reshape(-1, 1)
target_qs[dones] = 0.0
return (rewards + (self.gamma * target_qs))
    def update_Q_online(self, td_estimate, td_target):
loss = self.loss_fn(td_estimate.float(), td_target.float())
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
def sync_Q_target(self):
self.target_net.load_state_dict(self.online_net.state_dict())
def learn(self):
if self.curr_step % self.target_network_sync_frequency == 0:
self.sync_Q_target()
if self.curr_step % self.save_every == 0:
self.save()
if self.curr_step < self.learning_starts:
return None, None
if self.curr_step % self.training_frequency != 0:
return None, None
# Sample from memory
state, next_state, action, reward, done = self.recall()
# Get TD Estimate
td_est = self.td_estimate(state, action)
# Get TD Target
td_tgt = self.td_target(reward, next_state, done)
# Backpropagate loss through Q_online
loss = self.update_Q_online(td_est, td_tgt)
return (td_est.mean().item(), loss)
def save(self):
save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
torch.save(
dict(
model=self.online_net.state_dict(),
exploration_rate=self.exploration_rate,
replay_memory=self.memory
),
save_path
)
print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
def load(self, load_path, reset_exploration_rate, load_replay_buffer):
if not load_path.exists():
raise ValueError(f"{load_path} does not exist")
ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
exploration_rate = ckp.get('exploration_rate')
state_dict = ckp.get('model')
print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
self.online_net.load_state_dict(state_dict)
        self.target_net = copy.deepcopy(self.online_net)
        # Re-freeze the target copy; it inherits requires_grad=True from the online net.
        for p in self.target_net.parameters():
            p.requires_grad = False
if load_replay_buffer:
replay_memory = ckp.get('replay_memory')
print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
self.memory = replay_memory if replay_memory else self.memory
if reset_exploration_rate:
print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
else:
print(f"Setting exploration rate to {exploration_rate} not loaded.")
self.exploration_rate = exploration_rate
class DuelingDDQNAgent(DuelingDQNAgent):
@torch.no_grad()
def td_target(self, rewards, next_states, dones):
rewards = rewards.reshape(-1, 1)
dones = dones.reshape(-1, 1)
        q_vals = self.online_net(next_states)
        target_actions = torch.argmax(q_vals, dim=1)
target_actions = target_actions.reshape(-1, 1)
        target_qs = self.target_net(next_states)
target_qs = target_qs.gather(1, target_actions)
target_qs = target_qs.reshape(-1, 1)
target_qs[dones] = 0.0
return (rewards + (self.gamma * target_qs))
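
# The same toy target computation for the dueling Double-DQN variant
# (assumption: 8/4 dimensions as above; save_dir is unused here).
def _example_dueling_ddqn_target():
    agent = DuelingDDQNAgent(state_dim=8, action_dim=4, save_dir=None)
    device = 'cuda' if agent.use_cuda else 'cpu'
    rewards = torch.zeros(3, device=device)
    next_states = torch.randn(3, 8, device=device)
    dones = torch.tensor([False, False, True], device=device)
    return agent.td_target(rewards, next_states, dones)  # shape (3, 1)
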
class DQNAgentWithStepDecay:
    # Currently mirrors DQNAgent: epsilon still decays multiplicatively per
    # step; no separate step-based decay schedule is implemented yet.
def __init__(self,
state_dim,
action_dim,
save_dir,
checkpoint=None,
learning_rate=0.00025,
max_memory_size=100000,
batch_size=32,
exploration_rate=1,
exploration_rate_decay=0.9999999,
exploration_rate_min=0.1,
training_frequency=1,
learning_starts=1000,
target_network_sync_frequency=500,
reset_exploration_rate=False,
save_frequency=100000,
gamma=0.9,
load_replay_buffer=True):
self.state_dim = state_dim
self.action_dim = action_dim
self.max_memory_size = max_memory_size
self.memory = deque(maxlen=max_memory_size)
self.batch_size = batch_size
self.exploration_rate = exploration_rate
self.exploration_rate_decay = exploration_rate_decay
self.exploration_rate_min = exploration_rate_min
self.gamma = gamma
self.curr_step = 0
self.learning_starts = learning_starts # min. experiences before training
self.training_frequency = training_frequency # no. of experiences between updates to Q_online
self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
self.save_every = save_frequency # no. of experiences between saving the network
self.save_dir = save_dir
self.use_cuda = torch.cuda.is_available()
self.net = DQNet(self.state_dim, self.action_dim).float()
if self.use_cuda:
self.net = self.net.to(device='cuda')
if checkpoint:
self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
self.loss_fn = torch.nn.SmoothL1Loss()
# self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate)
# self.loss_fn = torch.nn.MSELoss()
def act(self, state):
"""
Given a state, choose an epsilon-greedy action and update value of step.
Inputs:
state(LazyFrame): A single observation of the current state, dimension is (state_dim)
Outputs:
action_idx (int): An integer representing which action the agent will perform
"""
# EXPLORE
if np.random.rand() < self.exploration_rate:
action_idx = np.random.randint(self.action_dim)
# EXPLOIT
else:
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
state = state.unsqueeze(0)
action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, dim=1).item()
# decrease exploration_rate
self.exploration_rate *= self.exploration_rate_decay
self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
# increment step
self.curr_step += 1
return action_idx
def cache(self, state, next_state, action, reward, done):
"""
Store the experience to self.memory (replay buffer)
Inputs:
state (LazyFrame),
next_state (LazyFrame),
action (int),
reward (float),
done(bool))
"""
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
        self.memory.append((state, next_state, action, reward, done))
def recall(self):
"""
Retrieve a batch of experiences from memory
"""
batch = random.sample(self.memory, self.batch_size)
state, next_state, action, reward, done = map(torch.stack, zip(*batch))
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
def td_estimate(self, states, actions):
actions = actions.reshape(-1, 1)
        predicted_qs = self.net(states, model='online')  # Q_online(s, a)
predicted_qs = predicted_qs.gather(1, actions)
return predicted_qs
@torch.no_grad()
def td_target(self, rewards, next_states, dones):
rewards = rewards.reshape(-1, 1)
dones = dones.reshape(-1, 1)
target_qs = self.net(next_states, model='target')
target_qs = torch.max(target_qs, dim=1).values
target_qs = target_qs.reshape(-1, 1)
target_qs[dones] = 0.0
        return rewards + (self.gamma * target_qs)
    def update_Q_online(self, td_estimate, td_target):
loss = self.loss_fn(td_estimate.float(), td_target.float())
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
def sync_Q_target(self):
self.net.target.load_state_dict(self.net.online.state_dict())
def learn(self):
if self.curr_step % self.target_network_sync_frequency == 0:
self.sync_Q_target()
if self.curr_step % self.save_every == 0:
self.save()
if self.curr_step < self.learning_starts:
return None, None
if self.curr_step % self.training_frequency != 0:
return None, None
# Sample from memory
state, next_state, action, reward, done = self.recall()
# Get TD Estimate
td_est = self.td_estimate(state, action)
# Get TD Target
td_tgt = self.td_target(reward, next_state, done)
# Backpropagate loss through Q_online
loss = self.update_Q_online(td_est, td_tgt)
return (td_est.mean().item(), loss)
def save(self):
save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
torch.save(
dict(
model=self.net.state_dict(),
exploration_rate=self.exploration_rate,
replay_memory=self.memory
),
save_path
)
print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
def load(self, load_path, reset_exploration_rate, load_replay_buffer):
if not load_path.exists():
raise ValueError(f"{load_path} does not exist")
ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
exploration_rate = ckp.get('exploration_rate')
state_dict = ckp.get('model')
print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
self.net.load_state_dict(state_dict)
if load_replay_buffer:
replay_memory = ckp.get('replay_memory')
print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
self.memory = replay_memory if replay_memory else self.memory
if reset_exploration_rate:
print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
else:
print(f"Setting exploration rate to {exploration_rate} not loaded.")
self.exploration_rate = exploration_rate
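
# An end-to-end training sketch (an assumption, not part of the original file):
# it presumes the classic Gym step/reset API (pre-0.26), an installed
# "LunarLander-v2" environment, and a local "checkpoints" directory. The
# episode count and logging cadence are illustrative.
if __name__ == "__main__":
    import gym
    from pathlib import Path

    env = gym.make("LunarLander-v2")
    save_dir = Path("checkpoints")
    save_dir.mkdir(parents=True, exist_ok=True)

    agent = DDQNAgent(
        state_dim=env.observation_space.shape[0],  # 8 for LunarLander-v2
        action_dim=env.action_space.n,             # 4 discrete actions
        save_dir=save_dir,
    )
    logger = MetricLogger(save_dir)

    for episode in range(500):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, info = env.step(action)
            agent.cache(state, next_state, action, reward, done)
            q, loss = agent.learn()
            logger.log_step(reward, loss, q)
            state = next_state
        logger.log_episode(episode)
        if episode % 20 == 0:
            logger.record(episode=episode, epsilon=agent.exploration_rate,
                          step=agent.curr_step)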