lharri73 committed on
Commit
ea6b281
·
1 Parent(s): 1ac9ba4

added dp agent.

Browse files
scripts/dpAgent.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import enum
3
+ from matplotlib import pyplot as plt
4
+ from tqdm import trange
5
+ from numba import njit, prange
6
+
7
+ from stationaryGrid import StationaryGrid
8
+
9
class Action(enum.Enum):
    """The four single-cell moves available to the agent."""

    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3


class DP:
    """Dynamic-programming agent on a square occupancy grid.

    Cells whose grid value is 0 are blocked.  The goal cell is ``(0, 0)``.
    ``run`` evaluates the uniform random policy; ``policy`` and
    ``find_path`` then act greedily on the resulting value function.
    """

    # (row, column) offset applied by each action.
    _MOVES = {
        Action.UP: (-1, 0),
        Action.DOWN: (1, 0),
        Action.LEFT: (0, -1),
        Action.RIGHT: (0, 1),
    }

    def __init__(self, grid):
        """Store the grid and initialise an all-zero value function.

        Args:
            grid: 2-D array indexed as ``grid[row, col]``.  The value table
                is allocated as ``len(grid) x len(grid)``, so the grid is
                assumed to be square.
        """
        self.grid = grid
        self.size = len(grid)
        self.V = np.zeros((self.size, self.size))
        self.gamma = 0.9
        self.actions = [Action.UP, Action.DOWN, Action.LEFT, Action.RIGHT]

    def rewardFunc(self, state, action):
        """Return ``(next_state, reward)`` for taking ``action`` in ``state``.

        Leaving the grid or stepping onto a blocked cell keeps the agent in
        ``state`` with reward -1.  Entering ``(0, 0)`` gives reward 10; any
        other successful move gives reward 0.

        Raises:
            ValueError: if ``action`` is not a member of ``Action``.
        """
        try:
            d_row, d_col = self._MOVES[action]
        except (KeyError, TypeError):
            raise ValueError("Invalid action") from None

        row, col = state[0] + d_row, state[1] + d_col

        if not (0 <= row < self.size and 0 <= col < self.size):
            return state, -1
        if self.grid[row, col] == 0:
            return state, -1
        if row == 0 and col == 0:
            return (row, col), 10
        return (row, col), 0

    def run(self, num_iterations):
        """Run ``num_iterations`` sweeps of iterative policy evaluation.

        Every action is weighted equally (``1 / len(self.actions)``).  Each
        sweep reads only the previous sweep's values and writes into a copy,
        so the update is synchronous.
        """
        prob = 1 / len(self.actions)
        for _ in trange(num_iterations):
            updated = np.copy(self.V)
            for state in np.ndindex(*self.grid.shape):
                expected = 0
                for action in self.actions:
                    nxt, reward = self.rewardFunc(state, action)
                    expected += prob * (
                        reward + (self.gamma * self.V[nxt[0], nxt[1]]))
                updated[state[0], state[1]] = expected
            self.V = updated

    def policy(self, state):
        """Choose the action that is greedy w.r.t. the value function.

        Each action is scored as ``reward + gamma * V[next_state]``.  Scoring
        by the immediate reward alone cannot tell apart the many moves whose
        reward is 0, so the value term is what steers the agent to the goal.
        Ties are resolved in favour of the earlier entry in ``self.actions``.

        Returns:
            ``(best_action, next_state, best_score)``.
        """
        best_score = -np.inf
        best = None
        best_pos = None
        for action in self.actions:
            nxt, reward = self.rewardFunc(state, action)
            score = reward + self.gamma * self.V[nxt[0], nxt[1]]
            if score > best_score:
                best_score = score
                best = action
                best_pos = nxt
        return best, best_pos, best_score

    def find_path(self):
        """Follow the greedy policy from the bottom-right corner to (0, 0).

        The path is printed and also returned as a list of ``(row, col)``
        tuples, starting cell included.

        Raises:
            ValueError: if the policy leads back to an already visited cell
                (this includes staying in place after a blocked move).
        """
        cur = (self.size - 1, self.size - 1)
        path = [cur]
        visited = {cur}  # set for O(1) revisit checks
        while cur != (0, 0):
            _, cur, _ = self.policy(cur)
            if cur in visited:
                print(path, cur)
                raise ValueError("Infinite loop")
            path.append(cur)
            visited.add(cur)
        print(path)
        return path
90
+
91
+
92
if __name__ == "__main__":
    import os

    # Build a seeded 20x20 grid, evaluate it, and trace the greedy path.
    grid = StationaryGrid(4, size=20)
    grid.create_grid()
    dp = DP(grid.grid)
    dp.run(10000)
    dp.find_path()

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs('imgs', exist_ok=True)
    plt.imshow(dp.V)
    plt.savefig('imgs/0.png')
scripts/requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ numpy
2
+ astar
3
+ matplotlib
4
+ tqdm
scripts/stationaryGrid.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ from matplotlib import pyplot as plt
4
+ from astar import AStar
5
+ import pickle
6
+ import enum
7
+
8
class MazeSolver(AStar):
    """
    A* path finder over a grid array, adapted from
    https://github.com/jrialland/python-astar/blob/f11311b678522d90c1786e6b8d9393095a0b733f/tests/maze/test_maze.py#L58

    A node is an (x, y) tuple and is walkable when ``maze[y, x] == 1``.
    The side length is taken from ``maze.shape[0]`` and used for both axes.
    """

    def __init__(self, maze):
        self.world = maze
        self.size = maze.shape[0]

    def heuristic_cost_estimate(self, n1, n2):
        """Straight-line distance between two (x, y) nodes."""
        ax, ay = n1
        bx, by = n2
        return np.hypot(bx - ax, by - ay)

    def distance_between(self, n1, n2):
        """Cost of one step; neighbours are always adjacent, so it is 1."""
        return 1

    def neighbors(self, node):
        """Return the in-bounds, walkable nodes among the 4 adjacent cells.

        Candidates are examined in the order (x, y-1), (x, y+1), (x-1, y),
        (x+1, y), and the result keeps that order.
        """
        x, y = node
        limit = self.size
        reachable = []
        for nx, ny in ((x, y - 1), (x, y + 1), (x - 1, y), (x + 1, y)):
            if not (0 <= nx < limit and 0 <= ny < limit):
                continue
            if self.world[ny, nx] == 1:
                reachable.append((nx, ny))
        return reachable
37
+
38
class StationaryGrid:
    """Square occupancy grid with random rectangular obstacles.

    Free cells hold 1 and obstacle cells hold 0.  The start is ``(0, 0)``
    and the goal is the bottom-right corner.
    """

    def __init__(self, seed, size=20):
        """Seed both RNGs and allocate an obstacle-free ``size x size`` grid."""
        np.random.seed(seed)
        random.seed(seed)
        self.size = size
        self.grid = np.ones((size, size), dtype=np.uint8)

    def create_grid(self):
        """Place between 1 and 9 random rectangular obstacles on the grid.

        A candidate is rejected and redrawn when it is anchored on the start
        cell, when it extends past both far edges (and so would cover the
        goal corner), or when it disconnects the start from the goal.
        """
        n_obstacles = np.random.randint(1, 10)
        start = (0, 0)
        # The goal is the bottom-right corner: the cell kept free by the
        # rejection rule below and the one marked by plot().
        corner = self.grid.shape[0] - 1
        goal = (corner, corner)

        placed = 0
        while placed < n_obstacles:
            x = np.random.randint(0, self.grid.shape[0])
            y = np.random.randint(0, self.grid.shape[1])
            size = np.random.randint(2, high=self.size // 2, size=(2,))
            if x == 0 and y == 0:
                continue
            if (x + size[0]) >= self.grid.shape[0] and (y + size[1]) >= self.grid.shape[1]:
                continue

            region = self.grid[x:x + size[0], y:y + size[1]]
            # Remember what was there so that a rejected obstacle can be
            # undone without erasing earlier obstacles it overlaps.
            previous = region.copy()
            region[...] = 0
            # make sure there's still a path to the goal
            path = MazeSolver(self.grid).astar(start, goal)
            if path is None:
                # if not, restore the cells and generate another random one
                region[...] = previous
                continue

            placed += 1

    def plot(self, pth=None):
        """Draw the grid with the start (red) and goal (green) marked.

        Args:
            pth: file path to save the image to; when ``None`` the figure is
                shown instead.
        """
        # Use a dedicated figure and close it afterwards so that repeated
        # calls do not pile artists onto one shared figure.
        fig, ax = plt.subplots()
        ax.imshow(self.grid, cmap='gray')
        ax.plot(0, 0, marker='o', markersize=10, color="red")
        # plot() takes (x, y) = (column, row).
        ax.plot(self.grid.shape[1] - 1, self.grid.shape[0] - 1,
                marker='o', markersize=10, color="green")
        if pth is not None:
            fig.savefig(pth)
        else:
            plt.show()
        plt.close(fig)
78
+
79
+
80
if __name__ == '__main__':
    import os

    # savefig does not create missing directories, so make sure it exists.
    os.makedirs('imgs', exist_ok=True)

    # Render one 100x100 grid per seed.
    for i in range(100):
        grid = StationaryGrid(i, size=100)
        grid.create_grid()
        grid.plot(f'imgs/{i}.png')
        print(i)