import torch import numpy as np import random import torch.nn as nn import copy import time, datetime import matplotlib.pyplot as plt from collections import deque from torch.utils.tensorboard import SummaryWriter class DQNet(nn.Module): """mini cnn structure""" def __init__(self, input_dim, output_dim): super().__init__() self.online = nn.Sequential( nn.Linear(input_dim, 100), nn.ReLU(), nn.Linear(100, 120), nn.ReLU(), nn.Linear(120, output_dim), ) self.target = copy.deepcopy(self.online) # Q_target parameters are frozen. for p in self.target.parameters(): p.requires_grad = False def forward(self, input, model): if model == "online": return self.online(input) elif model == "target": return self.target(input) class MetricLogger: def __init__(self, save_dir): self.writer = SummaryWriter(log_dir=save_dir) self.save_log = save_dir / "log" with open(self.save_log, "w") as f: f.write( f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}" f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}" f"{'TimeDelta':>15}{'Time':>20}\n" ) self.ep_rewards_plot = save_dir / "reward_plot.jpg" self.ep_lengths_plot = save_dir / "length_plot.jpg" self.ep_avg_losses_plot = save_dir / "loss_plot.jpg" self.ep_avg_qs_plot = save_dir / "q_plot.jpg" # History metrics self.ep_rewards = [] self.ep_lengths = [] self.ep_avg_losses = [] self.ep_avg_qs = [] # Moving averages, added for every call to record() self.moving_avg_ep_rewards = [] self.moving_avg_ep_lengths = [] self.moving_avg_ep_avg_losses = [] self.moving_avg_ep_avg_qs = [] # Current episode metric self.init_episode() # Timing self.record_time = time.time() def log_step(self, reward, loss, q): self.curr_ep_reward += reward self.curr_ep_length += 1 if loss: self.curr_ep_loss += loss self.curr_ep_q += q self.curr_ep_loss_length += 1 def log_episode(self, episode_number): "Mark end of episode" self.ep_rewards.append(self.curr_ep_reward) self.ep_lengths.append(self.curr_ep_length) if self.curr_ep_loss_length == 0: ep_avg_loss = 0 ep_avg_q = 0 else: ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5) ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5) self.ep_avg_losses.append(ep_avg_loss) self.ep_avg_qs.append(ep_avg_q) self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number) self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number) self.writer.flush() self.init_episode() def init_episode(self): self.curr_ep_reward = 0.0 self.curr_ep_length = 0 self.curr_ep_loss = 0.0 self.curr_ep_q = 0.0 self.curr_ep_loss_length = 0 def record(self, episode, epsilon, step): mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3) mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3) mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3) mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3) self.moving_avg_ep_rewards.append(mean_ep_reward) self.moving_avg_ep_lengths.append(mean_ep_length) self.moving_avg_ep_avg_losses.append(mean_ep_loss) self.moving_avg_ep_avg_qs.append(mean_ep_q) last_record_time = self.record_time self.record_time = time.time() time_since_last_record = np.round(self.record_time - last_record_time, 3) print( f"Episode {episode} - " f"Step {step} - " f"Epsilon {epsilon} - " f"Mean Reward {mean_ep_reward} - " f"Mean Length {mean_ep_length} - " f"Mean Loss {mean_ep_loss} - " f"Mean Q Value {mean_ep_q} - " f"Time Delta {time_since_last_record} - " f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}" ) self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode) self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode) self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode) self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode) self.writer.add_scalar("Epsilon value", epsilon, episode) self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode) self.writer.flush() with open(self.save_log, "a") as f: f.write( f"{episode:8d}{step:8d}{epsilon:10.3f}" f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}" f"{time_since_last_record:15.3f}" f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n" ) for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]: plt.plot(getattr(self, f"moving_avg_{metric}")) plt.savefig(getattr(self, f"{metric}_plot")) plt.clf() class DQNAgent: def __init__(self, state_dim, action_dim, save_dir, checkpoint=None, learning_rate=0.00025, max_memory_size=100000, batch_size=32, exploration_rate=1, exploration_rate_decay=0.9999999, exploration_rate_min=0.1, training_frequency=1, learning_starts=1000, target_network_sync_frequency=500, reset_exploration_rate=False, save_frequency=100000, gamma=0.9, load_replay_buffer=True): self.state_dim = state_dim self.action_dim = action_dim self.max_memory_size = max_memory_size self.memory = deque(maxlen=max_memory_size) self.batch_size = batch_size self.exploration_rate = exploration_rate self.exploration_rate_decay = exploration_rate_decay self.exploration_rate_min = exploration_rate_min self.gamma = gamma self.curr_step = 0 self.learning_starts = learning_starts # min. experiences before training self.training_frequency = training_frequency # no. of experiences between updates to Q_online self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync self.save_every = save_frequency # no. of experiences between saving the network self.save_dir = save_dir self.use_cuda = torch.cuda.is_available() self.net = DQNet(self.state_dim, self.action_dim).float() if self.use_cuda: self.net = self.net.to(device='cuda') if checkpoint: self.load(checkpoint, reset_exploration_rate, load_replay_buffer) self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True) self.loss_fn = torch.nn.SmoothL1Loss() # self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate) # self.loss_fn = torch.nn.MSELoss() def act(self, state): """ Given a state, choose an epsilon-greedy action and update value of step. Inputs: state(LazyFrame): A single observation of the current state, dimension is (state_dim) Outputs: action_idx (int): An integer representing which action the agent will perform """ # EXPLORE if np.random.rand() < self.exploration_rate: action_idx = np.random.randint(self.action_dim) # EXPLOIT else: state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state) state = state.unsqueeze(0) action_values = self.net(state, model='online') action_idx = torch.argmax(action_values, axis=1).item() # decrease exploration_rate self.exploration_rate *= self.exploration_rate_decay self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate) # increment step self.curr_step += 1 return action_idx def cache(self, state, next_state, action, reward, done): """ Store the experience to self.memory (replay buffer) Inputs: state (LazyFrame), next_state (LazyFrame), action (int), reward (float), done(bool)) """ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state) next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state) action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action]) reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward]) done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done]) self.memory.append( (state, next_state, action, reward, done,) ) def recall(self): """ Retrieve a batch of experiences from memory """ batch = random.sample(self.memory, self.batch_size) state, next_state, action, reward, done = map(torch.stack, zip(*batch)) return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze() def td_estimate(self, states, actions): actions = actions.reshape(-1, 1) predicted_qs = self.net(states, model='online')# Q_online(s,a) predicted_qs = predicted_qs.gather(1, actions) return predicted_qs @torch.no_grad() def td_target(self, rewards, next_states, dones): rewards = rewards.reshape(-1, 1) dones = dones.reshape(-1, 1) target_qs = self.net(next_states, model='target') target_qs = torch.max(target_qs, dim=1).values target_qs = target_qs.reshape(-1, 1) target_qs[dones] = 0.0 return (rewards + (self.gamma * target_qs)) def update_Q_online(self, td_estimate, td_target) : loss = self.loss_fn(td_estimate.float(), td_target.float()) self.optimizer.zero_grad() loss.backward() self.optimizer.step() return loss.item() def sync_Q_target(self): self.net.target.load_state_dict(self.net.online.state_dict()) def learn(self): if self.curr_step % self.target_network_sync_frequency == 0: self.sync_Q_target() if self.curr_step % self.save_every == 0: self.save() if self.curr_step < self.learning_starts: return None, None if self.curr_step % self.training_frequency != 0: return None, None # Sample from memory state, next_state, action, reward, done = self.recall() # Get TD Estimate td_est = self.td_estimate(state, action) # Get TD Target td_tgt = self.td_target(reward, next_state, done) # Backpropagate loss through Q_online loss = self.update_Q_online(td_est, td_tgt) return (td_est.mean().item(), loss) def save(self): save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt" torch.save( dict( model=self.net.state_dict(), exploration_rate=self.exploration_rate, replay_memory=self.memory ), save_path ) print(f"Airstriker model saved to {save_path} at step {self.curr_step}") def load(self, load_path, reset_exploration_rate, load_replay_buffer): if not load_path.exists(): raise ValueError(f"{load_path} does not exist") ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu')) exploration_rate = ckp.get('exploration_rate') state_dict = ckp.get('model') print(f"Loading model at {load_path} with exploration rate {exploration_rate}") self.net.load_state_dict(state_dict) if load_replay_buffer: replay_memory = ckp.get('replay_memory') print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.") self.memory = replay_memory if replay_memory else self.memory if reset_exploration_rate: print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}") else: print(f"Setting exploration rate to {exploration_rate} not loaded.") self.exploration_rate = exploration_rate class DDQNAgent(DQNAgent): @torch.no_grad() def td_target(self, rewards, next_states, dones): rewards = rewards.reshape(-1, 1) dones = dones.reshape(-1, 1) q_vals = self.net(next_states, model='online') target_actions = torch.argmax(q_vals, axis=1) target_actions = target_actions.reshape(-1, 1) target_qs = self.net(next_states, model='target') target_qs = target_qs.gather(1, target_actions) target_qs = target_qs.reshape(-1, 1) target_qs[dones] = 0.0 return (rewards + (self.gamma * target_qs)) class DuelingDQNet(nn.Module): def __init__(self, input_dim, output_dim): super().__init__() self.feature_layer = nn.Sequential( nn.Linear(input_dim, 150), nn.ReLU(), nn.Linear(150, 120), nn.ReLU() ) self.value_layer = nn.Sequential( nn.Linear(120, 120), nn.ReLU(), nn.Linear(120, 1) ) self.advantage_layer = nn.Sequential( nn.Linear(120, 120), nn.ReLU(), nn.Linear(120, output_dim) ) def forward(self, state): feature_output = self.feature_layer(state) # feature_output = feature_output.view(feature_output.size(0), -1) value = self.value_layer(feature_output) advantage = self.advantage_layer(feature_output) q_value = value + (advantage - advantage.mean()) return q_value class DuelingDQNAgent: def __init__(self, state_dim, action_dim, save_dir, checkpoint=None, learning_rate=0.00025, max_memory_size=100000, batch_size=32, exploration_rate=1, exploration_rate_decay=0.9999999, exploration_rate_min=0.1, training_frequency=1, learning_starts=1000, target_network_sync_frequency=500, reset_exploration_rate=False, save_frequency=100000, gamma=0.9, load_replay_buffer=True): self.state_dim = state_dim self.action_dim = action_dim self.max_memory_size = max_memory_size self.memory = deque(maxlen=max_memory_size) self.batch_size = batch_size self.exploration_rate = exploration_rate self.exploration_rate_decay = exploration_rate_decay self.exploration_rate_min = exploration_rate_min self.gamma = gamma self.curr_step = 0 self.learning_starts = learning_starts # min. experiences before training self.training_frequency = training_frequency # no. of experiences between updates to Q_online self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync self.save_every = save_frequency # no. of experiences between saving the network self.save_dir = save_dir self.use_cuda = torch.cuda.is_available() self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float() self.target_net = copy.deepcopy(self.online_net) # Q_target parameters are frozen. for p in self.target_net.parameters(): p.requires_grad = False if self.use_cuda: self.online_net = self.online_net(device='cuda') self.target_net = self.target_net(device='cuda') if checkpoint: self.load(checkpoint, reset_exploration_rate, load_replay_buffer) self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True) self.loss_fn = torch.nn.SmoothL1Loss() # self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=learning_rate) # self.loss_fn = torch.nn.MSELoss() def act(self, state): """ Given a state, choose an epsilon-greedy action and update value of step. Inputs: state(LazyFrame): A single observation of the current state, dimension is (state_dim) Outputs: action_idx (int): An integer representing which action the agent will perform """ # EXPLORE if np.random.rand() < self.exploration_rate: action_idx = np.random.randint(self.action_dim) # EXPLOIT else: state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state) state = state.unsqueeze(0) action_values = self.online_net(state) action_idx = torch.argmax(action_values, axis=1).item() # decrease exploration_rate self.exploration_rate *= self.exploration_rate_decay self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate) # increment step self.curr_step += 1 return action_idx def cache(self, state, next_state, action, reward, done): """ Store the experience to self.memory (replay buffer) Inputs: state (LazyFrame), next_state (LazyFrame), action (int), reward (float), done(bool)) """ print("####################################") print(state) state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state) next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state) action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action]) reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward]) done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done]) self.memory.append( (state, next_state, action, reward, done,) ) def recall(self): """ Retrieve a batch of experiences from memory """ batch = random.sample(self.memory, self.batch_size) state, next_state, action, reward, done = map(torch.stack, zip(*batch)) return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze() def td_estimate(self, states, actions): actions = actions.reshape(-1, 1) predicted_qs = self.online_net(states)# Q_online(s,a) predicted_qs = predicted_qs.gather(1, actions) return predicted_qs @torch.no_grad() def td_target(self, rewards, next_states, dones): rewards = rewards.reshape(-1, 1) dones = dones.reshape(-1, 1) target_qs = self.target_net.forward(next_states) target_qs = torch.max(target_qs, dim=1).values target_qs = target_qs.reshape(-1, 1) target_qs[dones] = 0.0 return (rewards + (self.gamma * target_qs)) def update_Q_online(self, td_estimate, td_target) : loss = self.loss_fn(td_estimate.float(), td_target.float()) self.optimizer.zero_grad() loss.backward() self.optimizer.step() return loss.item() def sync_Q_target(self): self.target_net.load_state_dict(self.online_net.state_dict()) def learn(self): if self.curr_step % self.target_network_sync_frequency == 0: self.sync_Q_target() if self.curr_step % self.save_every == 0: self.save() if self.curr_step < self.learning_starts: return None, None if self.curr_step % self.training_frequency != 0: return None, None # Sample from memory state, next_state, action, reward, done = self.recall() # Get TD Estimate td_est = self.td_estimate(state, action) # Get TD Target td_tgt = self.td_target(reward, next_state, done) # Backpropagate loss through Q_online loss = self.update_Q_online(td_est, td_tgt) return (td_est.mean().item(), loss) def save(self): save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt" torch.save( dict( model=self.online_net.state_dict(), exploration_rate=self.exploration_rate, replay_memory=self.memory ), save_path ) print(f"Airstriker model saved to {save_path} at step {self.curr_step}") def load(self, load_path, reset_exploration_rate, load_replay_buffer): if not load_path.exists(): raise ValueError(f"{load_path} does not exist") ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu')) exploration_rate = ckp.get('exploration_rate') state_dict = ckp.get('model') print(f"Loading model at {load_path} with exploration rate {exploration_rate}") self.online_net.load_state_dict(state_dict) self.target_net = copy.deepcopy(self.online_net) self.sync_Q_target() if load_replay_buffer: replay_memory = ckp.get('replay_memory') print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.") self.memory = replay_memory if replay_memory else self.memory if reset_exploration_rate: print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}") else: print(f"Setting exploration rate to {exploration_rate} not loaded.") self.exploration_rate = exploration_rate class DuelingDDQNAgent(DuelingDQNAgent): @torch.no_grad() def td_target(self, rewards, next_states, dones): rewards = rewards.reshape(-1, 1) dones = dones.reshape(-1, 1) q_vals = self.online_net.forward(next_states) target_actions = torch.argmax(q_vals, axis=1) target_actions = target_actions.reshape(-1, 1) target_qs = self.target_net.forward(next_states) target_qs = target_qs.gather(1, target_actions) target_qs = target_qs.reshape(-1, 1) target_qs[dones] = 0.0 return (rewards + (self.gamma * target_qs)) class DQNAgentWithStepDecay: def __init__(self, state_dim, action_dim, save_dir, checkpoint=None, learning_rate=0.00025, max_memory_size=100000, batch_size=32, exploration_rate=1, exploration_rate_decay=0.9999999, exploration_rate_min=0.1, training_frequency=1, learning_starts=1000, target_network_sync_frequency=500, reset_exploration_rate=False, save_frequency=100000, gamma=0.9, load_replay_buffer=True): self.state_dim = state_dim self.action_dim = action_dim self.max_memory_size = max_memory_size self.memory = deque(maxlen=max_memory_size) self.batch_size = batch_size self.exploration_rate = exploration_rate self.exploration_rate_decay = exploration_rate_decay self.exploration_rate_min = exploration_rate_min self.gamma = gamma self.curr_step = 0 self.learning_starts = learning_starts # min. experiences before training self.training_frequency = training_frequency # no. of experiences between updates to Q_online self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync self.save_every = save_frequency # no. of experiences between saving the network self.save_dir = save_dir self.use_cuda = torch.cuda.is_available() self.net = DQNet(self.state_dim, self.action_dim).float() if self.use_cuda: self.net = self.net.to(device='cuda') if checkpoint: self.load(checkpoint, reset_exploration_rate, load_replay_buffer) self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True) self.loss_fn = torch.nn.SmoothL1Loss() # self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate) # self.loss_fn = torch.nn.MSELoss() def act(self, state): """ Given a state, choose an epsilon-greedy action and update value of step. Inputs: state(LazyFrame): A single observation of the current state, dimension is (state_dim) Outputs: action_idx (int): An integer representing which action the agent will perform """ # EXPLORE if np.random.rand() < self.exploration_rate: action_idx = np.random.randint(self.action_dim) # EXPLOIT else: state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state) state = state.unsqueeze(0) action_values = self.net(state, model='online') action_idx = torch.argmax(action_values, axis=1).item() # decrease exploration_rate self.exploration_rate *= self.exploration_rate_decay self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate) # increment step self.curr_step += 1 return action_idx def cache(self, state, next_state, action, reward, done): """ Store the experience to self.memory (replay buffer) Inputs: state (LazyFrame), next_state (LazyFrame), action (int), reward (float), done(bool)) """ state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state) next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state) action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action]) reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward]) done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done]) self.memory.append( (state, next_state, action, reward, done) ) def recall(self): """ Retrieve a batch of experiences from memory """ batch = random.sample(self.memory, self.batch_size) state, next_state, action, reward, done = map(torch.stack, zip(*batch)) return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze() def td_estimate(self, states, actions): actions = actions.reshape(-1, 1) predicted_qs = self.net(states, model='online')# Q_online(s,a) predicted_qs = predicted_qs.gather(1, actions) return predicted_qs @torch.no_grad() def td_target(self, rewards, next_states, dones): rewards = rewards.reshape(-1, 1) dones = dones.reshape(-1, 1) target_qs = self.net(next_states, model='target') target_qs = torch.max(target_qs, dim=1).values target_qs = target_qs.reshape(-1, 1) target_qs[dones] = 0.0 val = self.gamma * target_qs return (rewards + val) def update_Q_online(self, td_estimate, td_target) : loss = self.loss_fn(td_estimate.float(), td_target.float()) self.optimizer.zero_grad() loss.backward() self.optimizer.step() return loss.item() def sync_Q_target(self): self.net.target.load_state_dict(self.net.online.state_dict()) def learn(self): if self.curr_step % self.target_network_sync_frequency == 0: self.sync_Q_target() if self.curr_step % self.save_every == 0: self.save() if self.curr_step < self.learning_starts: return None, None if self.curr_step % self.training_frequency != 0: return None, None # Sample from memory state, next_state, action, reward, done = self.recall() # Get TD Estimate td_est = self.td_estimate(state, action) # Get TD Target td_tgt = self.td_target(reward, next_state, done) # Backpropagate loss through Q_online loss = self.update_Q_online(td_est, td_tgt) return (td_est.mean().item(), loss) def save(self): save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt" torch.save( dict( model=self.net.state_dict(), exploration_rate=self.exploration_rate, replay_memory=self.memory ), save_path ) print(f"Airstriker model saved to {save_path} at step {self.curr_step}") def load(self, load_path, reset_exploration_rate, load_replay_buffer): if not load_path.exists(): raise ValueError(f"{load_path} does not exist") ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu')) exploration_rate = ckp.get('exploration_rate') state_dict = ckp.get('model') print(f"Loading model at {load_path} with exploration rate {exploration_rate}") self.net.load_state_dict(state_dict) if load_replay_buffer: replay_memory = ckp.get('replay_memory') print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.") self.memory = replay_memory if replay_memory else self.memory if reset_exploration_rate: print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}") else: print(f"Setting exploration rate to {exploration_rate} not loaded.") self.exploration_rate = exploration_rate