# lharri73's picture
# added dp agent.
# ea6b281
import numpy as np
import enum
from matplotlib import pyplot as plt
from tqdm import trange
from numba import njit, prange
from stationaryGrid import StationaryGrid
@enum.unique
class Action(enum.Enum):
    """The four moves an agent can make on the grid.

    ``DP.rewardFunc`` treats a state as a ``(row, col)`` pair and maps each
    member to a one-cell step along one of the two axes.
    """

    # Steps along the first (row) index.
    UP = 0      # row - 1
    DOWN = 1    # row + 1

    # Steps along the second (column) index.
    LEFT = 2    # col - 1
    RIGHT = 3   # col + 1
class DP:
    """Dynamic-programming agent for a square grid world.

    The agent keeps a state-value table ``V`` with one entry per grid cell.
    ``run`` fills the table by repeatedly applying the Bellman expectation
    update for the uniform-random policy, ``policy`` acts greedily with
    respect to that table, and ``find_path`` follows the greedy policy from
    the bottom-right corner to the goal in the top-left corner ``(0, 0)``.

    The grid is indexed as ``grid[row, col]``; a cell whose value is ``0`` is
    treated as impassable.  The code assumes the grid is square
    (``len(grid)`` is used for both dimensions).
    """

    def __init__(self, grid):
        """Store the grid and initialise an all-zero value table.

        Args:
            grid: 2-D array-like supporting ``len``, ``.shape`` and
                ``grid[row, col]`` indexing (e.g. a NumPy array).
        """
        self.grid = grid
        self.size = len(grid)
        self.V = np.zeros((self.size, self.size))
        self.gamma = 0.9  # discount factor applied to the next state's value
        self.actions = [Action.UP, Action.DOWN, Action.LEFT, Action.RIGHT]

    def rewardFunc(self, state, action):
        """Return the state reached and reward earned by taking ``action``.

        Args:
            state: ``(row, col)`` of the current cell.
            action: the ``Action`` to apply.

        Returns:
            ``(next_state, reward)`` where

            * a move off the grid or into a ``0`` cell leaves the agent in
              ``state`` and yields ``-1``;
            * a move onto ``(0, 0)`` yields ``10``;
            * any other move yields ``0``.

        Raises:
            ValueError: if ``action`` is not one of the four ``Action``
                members.
        """
        row, col = state[0], state[1]
        if action == Action.UP:
            finalPos = (row - 1, col)
        elif action == Action.DOWN:
            finalPos = (row + 1, col)
        elif action == Action.LEFT:
            finalPos = (row, col - 1)
        elif action == Action.RIGHT:
            finalPos = (row, col + 1)
        else:
            raise ValueError("Invalid action")

        # Leaving the grid: stay put and pay the penalty.
        if not (0 <= finalPos[0] < self.size and 0 <= finalPos[1] < self.size):
            return state, -1
        # Blocked cell: same outcome as hitting the border.
        if self.grid[finalPos[0], finalPos[1]] == 0:
            return state, -1
        # Stepping onto the goal cell.
        if finalPos == (0, 0):
            return finalPos, 10
        return finalPos, 0

    def run(self, num_iterations):
        """Run ``num_iterations`` synchronous sweeps of policy evaluation.

        Every sweep recomputes each cell's value as the average, over the
        four equally likely actions, of ``reward + gamma * V[next_state]``.
        All reads use the table from the previous sweep; the new table
        replaces ``self.V`` only once the sweep is complete.

        Args:
            num_iterations: number of full sweeps over the grid.
        """
        # Uniform-random policy: every action has the same probability.
        action_prob = 1 / len(self.actions)
        for _ in trange(num_iterations):
            next_V = np.copy(self.V)
            for state in np.ndindex(*self.grid.shape):
                expected_return = 0
                for action in self.actions:
                    finalPosition, reward = self.rewardFunc(state, action)
                    expected_return += action_prob * (
                        reward
                        + self.gamma * self.V[finalPosition[0], finalPosition[1]]
                    )
                next_V[state[0], state[1]] = expected_return
            self.V = next_V

    def policy(self, state):
        """Pick the greedy action for ``state`` under the current values.

        Each action is scored by its one-step look-ahead return
        ``reward + gamma * V[next_state]``; the first action with the
        highest score wins.  While ``V`` is still all zeros this reduces to
        choosing the action with the highest immediate reward.

        Args:
            state: ``(row, col)`` of the current cell.

        Returns:
            ``(action, next_state, reward)`` for the chosen action, where
            ``reward`` is the immediate reward of that action.  If
            ``self.actions`` is empty the result is ``(None, None, -inf)``.
        """
        best_score = -np.inf
        best = None
        bestPos = None
        best_reward = -np.inf
        for action in self.actions:
            finalPosition, reward = self.rewardFunc(state, action)
            # Rank by look-ahead value, not by immediate reward alone;
            # otherwise the learned value table would never be used.
            score = reward + self.gamma * self.V[finalPosition[0], finalPosition[1]]
            if score > best_score:
                best_score = score
                best = action
                bestPos = finalPosition
                best_reward = reward
        return best, bestPos, best_reward

    def find_path(self):
        """Follow the greedy policy from the bottom-right corner to ``(0, 0)``.

        The path is printed (as before) and also returned.

        Returns:
            The list of ``(row, col)`` cells visited, start and goal
            included.

        Raises:
            ValueError: if the greedy policy revisits a cell, since it would
                then cycle forever.  The partial path and the repeated cell
                are printed first.
        """
        cur = (self.size - 1, self.size - 1)
        path = [cur]
        visited = {cur}  # O(1) revisit check; ``path`` keeps the order
        while cur != (0, 0):
            _, cur, _ = self.policy(cur)
            if cur in visited:
                print(path, cur)
                raise ValueError("Infinite loop")
            path.append(cur)
            visited.add(cur)
        print(path)
        return path
if __name__ == "__main__":
    # Local import: only the script entry point needs it.
    import os

    # Build the environment.  NOTE(review): the meaning of the positional
    # argument ``4`` is defined by StationaryGrid, which is not visible
    # here -- confirm against that module.
    grid = StationaryGrid(4, size=20)
    grid.create_grid()

    # Evaluate the value table, then trace the greedy path to the goal.
    dp = DP(grid.grid)
    dp.run(10000)
    dp.find_path()

    # Make sure the output directory exists so the long computation above is
    # not lost to a missing-directory error when saving the figure.
    os.makedirs("imgs", exist_ok=True)
    plt.imshow(dp.V)
    plt.savefig("imgs/0.png")