Spaces:
Sleeping
Sleeping
import numpy as np | |
import enum | |
from matplotlib import pyplot as plt | |
from tqdm import trange | |
from numba import njit, prange | |
from stationaryGrid import StationaryGrid | |
class Action(enum.Enum):
    """The four axis-aligned moves an agent can attempt on the grid."""

    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3


class DP:
    """Tabular dynamic-programming solver for a 2-D grid world.

    The grid is indexed as ``grid[row, col]``. Cells equal to ``0`` are
    treated as impassable; cell ``(0, 0)`` is the goal and entering it
    yields a reward of 10. Bumping into the border or an impassable cell
    leaves the agent where it is and costs ``-1``; every other move is free.

    Attributes:
        grid: The 2-D array describing the world (must support ``.shape``
            and ``grid[row, col]`` indexing).
        size: ``len(grid)``, i.e. the number of rows.
        V: State-value table with the same shape as ``grid``.
        gamma: Discount factor applied to the value of the next state.
        actions: The actions considered in every state.
    """

    # (row, col) displacement produced by each action.
    _MOVES = {
        Action.UP: (-1, 0),
        Action.DOWN: (1, 0),
        Action.LEFT: (0, -1),
        Action.RIGHT: (0, 1),
    }

    def __init__(self, grid, gamma=0.9):
        """Create a solver for *grid* with an all-zero value table.

        Args:
            grid: 2-D array; ``0`` marks an impassable cell.
            gamma: Discount factor (defaults to 0.9).
        """
        self.grid = grid
        self.size = len(grid)
        # Shape V after the grid itself so non-square grids work as well.
        self.V = np.zeros(grid.shape)
        self.gamma = gamma
        self.actions = [Action.UP, Action.DOWN, Action.LEFT, Action.RIGHT]

    def rewardFunc(self, state, action):
        """Simulate taking *action* in *state*.

        Args:
            state: ``(row, col)`` of the current cell.
            action: The ``Action`` to attempt.

        Returns:
            A ``(next_state, reward)`` pair. ``next_state`` is *state*
            itself when the move is blocked by the border or an impassable
            cell (reward ``-1``); otherwise it is the target cell, with
            reward ``10`` for the goal ``(0, 0)`` and ``0`` elsewhere.

        Raises:
            ValueError: If *action* is not one of the four known actions.
        """
        try:
            d_row, d_col = self._MOVES[action]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments; both mean "not an Action".
            raise ValueError("Invalid action") from None

        row, col = state[0] + d_row, state[1] + d_col
        n_rows, n_cols = self.grid.shape

        if not (0 <= row < n_rows and 0 <= col < n_cols):
            return state, -1
        if self.grid[row, col] == 0:
            return state, -1
        if row == 0 and col == 0:
            return (row, col), 10
        return (row, col), 0

    def run(self, num_iterations):
        """Run *num_iterations* synchronous sweeps of policy evaluation.

        Each sweep replaces ``V[s]`` with the expected one-step return under
        a policy that picks every action in ``self.actions`` with equal
        probability. All updates in a sweep read the value table from the
        previous sweep.

        Args:
            num_iterations: Number of full sweeps over the grid.
        """
        # Hoisted: the action probability is the same for every state.
        action_prob = 1 / len(self.actions) if self.actions else 0.0

        for _ in trange(num_iterations):
            updated = np.copy(self.V)
            for state in np.ndindex(*self.grid.shape):
                expected = 0
                for action in self.actions:
                    next_state, reward = self.rewardFunc(state, action)
                    expected += action_prob * (
                        reward + (self.gamma * self.V[next_state[0], next_state[1]]))
                updated[state[0], state[1]] = expected
            self.V = updated

    def policy(self, state):
        """Pick the greedy action for *state* with respect to ``self.V``.

        Actions are ranked by their one-step lookahead value
        ``reward + gamma * V[next_state]``; ties keep the earliest action in
        ``self.actions``.

        Args:
            state: ``(row, col)`` of the current cell.

        Returns:
            ``(action, next_state, reward)`` where *reward* is the immediate
            reward of the chosen action. If ``self.actions`` is empty the
            result is ``(None, None, -inf)``.
        """
        best_value = -np.inf
        best_reward = -np.inf
        best = None
        bestPos = None
        for action in self.actions:
            next_state, reward = self.rewardFunc(state, action)
            # Rank by the discounted value of where the move leads, not just
            # by the immediate reward, so the learned V actually steers.
            value = reward + self.gamma * self.V[next_state[0], next_state[1]]
            if value > best_value:
                best_value = value
                best_reward = reward
                best = action
                bestPos = next_state
        return best, bestPos, best_reward

    def find_path(self):
        """Follow the greedy policy from the bottom-right cell to ``(0, 0)``.

        The visited cells are printed and returned.

        Returns:
            The list of ``(row, col)`` cells from the start to the goal,
            both included.

        Raises:
            ValueError: If the greedy policy revisits a cell (including
                staying in place), which would loop forever.
        """
        n_rows, n_cols = self.grid.shape
        cur = (n_rows - 1, n_cols - 1)
        path = [cur]
        visited = {cur}  # O(1) cycle detection alongside the ordered path
        while cur != (0, 0):
            _, cur, _ = self.policy(cur)
            if cur in visited:
                print(path, cur)
                raise ValueError("Infinite loop")
            path.append(cur)
            visited.add(cur)
        print(path)
        return path
if __name__ == "__main__":
    import os

    # Build the environment. NOTE(review): the meaning of the positional
    # argument `4` is defined by StationaryGrid, which is not visible here.
    grid = StationaryGrid(4, size=20)
    grid.create_grid()

    # Evaluate the value table, then trace and print the greedy path.
    dp = DP(grid.grid)
    dp.run(10000)
    dp.find_path()

    # savefig does not create missing directories, so make sure the output
    # folder exists instead of failing only after the long run has finished.
    os.makedirs("imgs", exist_ok=True)
    plt.imshow(dp.V)
    plt.savefig("imgs/0.png")