import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque


class QNetwork(nn.Module):
    """Fully connected network mapping a state vector to one Q-value per action."""

    def __init__(self, state_size, action_size, hidden_sizes=(256, 128, 64), dropout_rate=0.1):
        super(QNetwork, self).__init__()
        self.state_size = state_size
        self.action_size = action_size

        # Hidden stack: Linear -> BatchNorm -> ReLU -> Dropout for each hidden width.
        layers = []
        prev_size = state_size
        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            prev_size = hidden_size
        # Output layer: one Q-value per action, no activation.
        layers.append(nn.Linear(prev_size, action_size))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        return self.network(x)


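# Illustrative sanity check (hypothetical helper, not used by the agent below): a
# forward pass through QNetwork in eval mode. In eval mode the BatchNorm1d layers
# use their running statistics, so even a single-state batch is accepted; the
# example sizes (state_size=4, action_size=2) are arbitrary.
def _qnetwork_smoke_test(state_size=4, action_size=2):
    net = QNetwork(state_size, action_size)
    net.eval()
    q = net(torch.zeros(1, state_size))
    assert q.shape == (1, action_size)
    return q

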
class PriorityReplayMemory:
    """Proportional prioritized experience replay backed by two aligned deques."""

    def __init__(self, capacity, alpha=0.6, beta=0.4, beta_increment=0.001):
        self.capacity = capacity
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        # Priorities are stored index-for-index with the experiences so that both
        # deques evict the same (oldest) entry once capacity is reached.
        self.priorities = deque(maxlen=capacity)
        self.experiences = deque(maxlen=capacity)
        self.max_priority = 1.0

    def add(self, experience, error=None):
        if error is None:
            # New experiences get the current maximum priority so they are seen at least once.
            priority = self.max_priority
        else:
            priority = (abs(error) + 1e-5) ** self.alpha
        self.experiences.append(experience)
        self.priorities.append(priority)

    def sample(self, batch_size):
        if len(self.experiences) < batch_size:
            return None, None, None

        priorities = np.array(self.priorities)
        probs = priorities / priorities.sum()

        indices = np.random.choice(len(self.experiences), batch_size, p=probs, replace=False)
        samples = [self.experiences[idx] for idx in indices]

        # Importance-sampling weights correct for the bias of non-uniform sampling.
        weights = (len(self.experiences) * probs[indices]) ** (-self.beta)
        weights /= weights.max()

        self.beta = min(1.0, self.beta + self.beta_increment)
        return samples, indices, torch.FloatTensor(weights)

    def update_priorities(self, indices, errors):
        for idx, error in zip(indices, errors):
            priority = (abs(error) + 1e-5) ** self.alpha
            self.priorities[idx] = priority
            self.max_priority = max(self.max_priority, priority)

    def __len__(self):
        return len(self.experiences)


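# Illustration of the sampling scheme above (hypothetical helper, not used by the agent):
# a transition's sampling probability is P(i) = p_i / sum_k p_k with p_i = (|error_i| + 1e-5)^alpha,
# and its importance-sampling weight is w_i = (N * P(i))^(-beta), normalized by the batch maximum.
def _prioritized_replay_demo(num_items=100, batch_size=8):
    memory = PriorityReplayMemory(capacity=num_items)
    for i in range(num_items):
        # Dummy transitions; larger TD errors should be sampled more often.
        memory.add((np.zeros(4), 0, 0.0, np.zeros(4), False), error=i / num_items)
    samples, indices, weights = memory.sample(batch_size)
    return indices, weights

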
class RLAgent:
    """DQN agent with a target network, optional double-DQN targets, and prioritized replay."""

    def __init__(self, state_size, action_size,
                 lr=0.0005, gamma=0.99, epsilon=1.0, epsilon_decay=0.995, min_epsilon=0.01,
                 memory_capacity=10000, batch_size=64, target_update_freq=1000,
                 use_double_dqn=True, clip_grad_norm=1.0):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.min_epsilon = min_epsilon
        self.batch_size = batch_size
        self.use_double_dqn = use_double_dqn
        self.clip_grad_norm = clip_grad_norm

        # Online and target networks; the target network is only updated by periodic copies.
        self.policy_net = QNetwork(state_size, action_size)
        self.target_net = QNetwork(state_size, action_size)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()
        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr)
        # reduction='none' keeps per-sample losses so importance-sampling weights can be applied.
        self.criterion = nn.SmoothL1Loss(reduction='none')

        self.memory = PriorityReplayMemory(memory_capacity)
        self.steps = 0
        self.target_update_freq = target_update_freq

    def remember(self, state, action, reward, next_state, done):
        # Compute an initial TD error so the new transition gets a meaningful priority.
        state_tensor = torch.FloatTensor(state).unsqueeze(0)
        next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0)
        # eval() so the BatchNorm1d layers accept a batch of size 1.
        self.policy_net.eval()
        with torch.no_grad():
            q_value = self.policy_net(state_tensor)[0, action]
            next_q = self.target_net(next_state_tensor).max().item()
        self.policy_net.train()
        target = reward + (1 - done) * self.gamma * next_q
        error = abs(q_value.item() - target)
        self.memory.add((state, action, reward, next_state, done), error)

    def act(self, state):
        self.steps += 1
        # Epsilon-greedy exploration.
        if random.random() < self.epsilon:
            return random.randint(0, self.action_size - 1)
        state = torch.FloatTensor(state).unsqueeze(0)
        # eval() so the BatchNorm1d layers accept a batch of size 1.
        self.policy_net.eval()
        with torch.no_grad():
            action = torch.argmax(self.policy_net(state)).item()
        self.policy_net.train()
        return action

    def train(self):
        if len(self.memory) < self.batch_size:
            return

        batch, indices, weights = self.memory.sample(self.batch_size)
        if batch is None:
            return

        states, actions, rewards, next_states, dones = zip(*batch)
        states = torch.FloatTensor(np.array(states))
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(np.array(next_states))
        dones = torch.FloatTensor(np.array(dones, dtype=np.float32))

        # Q(s, a) for the actions actually taken.
        q_values = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Bootstrapped targets are constants for the loss, so compute them without gradients.
        with torch.no_grad():
            if self.use_double_dqn:
                # Double DQN: the online network selects the action, the target network evaluates it.
                next_actions = self.policy_net(next_states).argmax(1)
                next_q_values = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze(1)
            else:
                next_q_values = self.target_net(next_states).max(1)[0]
            targets = rewards + self.gamma * next_q_values * (1 - dones)

        # TD errors feed back into the replay priorities.
        td_errors = (q_values - targets).detach().cpu().numpy()

        # Per-sample Huber loss weighted by the importance-sampling weights.
        loss = (self.criterion(q_values, targets) * weights).mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), self.clip_grad_norm)
        self.optimizer.step()

        self.memory.update_priorities(indices, td_errors)

        # Periodically sync the target network with the online network.
        if self.steps % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        # Decay exploration after each training step.
        self.epsilon = max(self.min_epsilon, self.epsilon * self.epsilon_decay)
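

# Minimal end-to-end sketch of how the pieces fit together. This assumes the
# Gymnasium API and uses "CartPole-v1" purely as an example environment; gymnasium
# is not a dependency of the classes above, so the import is kept local to this
# hypothetical helper.
def _train_cartpole(num_episodes=500):
    import gymnasium as gym

    env = gym.make("CartPole-v1")
    agent = RLAgent(state_size=env.observation_space.shape[0],
                    action_size=env.action_space.n)
    for episode in range(num_episodes):
        state, _ = env.reset()
        done = False
        episode_reward = 0.0
        while not done:
            action = agent.act(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            agent.remember(state, action, reward, next_state, done)
            agent.train()
            state = next_state
            episode_reward += reward
        print(f"episode {episode}: reward {episode_reward:.1f}")
    env.close()
    return agent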