Upload 36 files
- Makefile +30 -0
- README.md +358 -0
- environment.atari.yml +153 -0
- environment.procgen-v2.yml +135 -0
- environment.procgen.yml +135 -0
- requirements-v1.txt +76 -0
- requirements.txt +42 -0
- src/airstriker-genesis/__init__.py +0 -0
- src/airstriker-genesis/agent.py +400 -0
- src/airstriker-genesis/cartpole.py +353 -0
- src/airstriker-genesis/procgen_agent.py +400 -0
- src/airstriker-genesis/replay.py +66 -0
- src/airstriker-genesis/run-airstriker-ddqn.py +120 -0
- src/airstriker-genesis/run-airstriker-dqn.py +115 -0
- src/airstriker-genesis/run-cartpole.py +120 -0
- src/airstriker-genesis/test.py +405 -0
- src/airstriker-genesis/utils.py +22 -0
- src/airstriker-genesis/wrappers.py +213 -0
- src/lunar-lander/agent.py +1104 -0
- src/lunar-lander/params.py +12 -0
- src/lunar-lander/replay.py +67 -0
- src/lunar-lander/run-lunar-ddqn.py +45 -0
- src/lunar-lander/run-lunar-dqn.py +46 -0
- src/lunar-lander/run-lunar-dueling-ddqn.py +47 -0
- src/lunar-lander/run-lunar-dueling-dqn.py +46 -0
- src/lunar-lander/train.py +84 -0
- src/lunar-lander/wrappers.py +193 -0
- src/procgen/agent.py +664 -0
- src/procgen/run-starpilot-ddqn.py +45 -0
- src/procgen/run-starpilot-dqn.py +45 -0
- src/procgen/run-starpilot-dueling-ddqn.py +45 -0
- src/procgen/run-starpilot-dueling-dqn.py +45 -0
- src/procgen/test-procgen.py +12 -0
- src/procgen/train.py +48 -0
- src/procgen/wrappers.py +187 -0
- troubleshooting.md +37 -0
Makefile
ADDED
@@ -0,0 +1,30 @@
.PHONY: create-atari-env
create-atari-env: ## Creates the Atari/retro conda environment
	conda env create -f environment.atari.yml --force

.PHONY: create-procgen-env
create-procgen-env: ## Creates the procgen conda environment
	conda env create -f environment.procgen.yml --force

.PHONY: setup-env
setup-env: ## Sets up the conda environment
	conda install pytorch torchvision numpy -c pytorch -y
	pip install gym-retro
	pip install "gym[atari]==0.21.0"
	pip install importlib-metadata==4.13.0

.PHONY: run-air-dqn
run-air-dqn: ## Runs the Airstriker Genesis DQN agent
	python ./src/airstriker-genesis/run-airstriker-dqn.py

.PHONY: run-air-ddqn
run-air-ddqn: ## Runs the Airstriker Genesis Double DQN agent
	python ./src/airstriker-genesis/run-airstriker-ddqn.py

.PHONY: run-starpilot-dqn
run-starpilot-dqn: ## Runs the Starpilot DQN agent
	python ./src/procgen/run-starpilot-dqn.py

.PHONY: run-starpilot-ddqn
run-starpilot-ddqn: ## Runs the Starpilot Double DQN agent
	python ./src/procgen/run-starpilot-ddqn.py
README.md
CHANGED
@@ -0,0 +1,358 @@
# **Abstract**

In 2013, DeepMind published a paper called "Playing Atari with Deep
Reinforcement Learning", introducing an algorithm called the Deep
Q-Network (DQN) that revolutionized the field of reinforcement
learning. For the first time, they brought together deep learning and
Q-learning and showed impressive results applying deep reinforcement
learning to Atari games, with their agents performing at or above
human-level expertise in many of the games they were trained on.
A Deep Q-Network uses a deep neural network to estimate the q-value of
each action, allowing the policy to select the action with the maximum
q-value. Using a deep neural network to obtain q-values was immensely
superior to q-table look-ups and widened the applicability of
q-learning to more complex reinforcement learning environments.
While revolutionary, the original version of DQN had a few problems,
especially its slow and inefficient learning process. Over the past 9
years, a few improved versions of DQN have become popular. This project
is an attempt to study a few of these DQN flavors, what problems they
solve, and to compare their performance in the same reinforcement
learning environment.

# Deep Q-Networks and their flavors

- **Vanilla DQN**

  The vanilla (original) DQN uses 2 neural networks: the **online**
  network and the **target** network. The online network is the main
  neural network that the agent uses to select the best action for a
  given state. The target network is usually a copy of the online
  network. It is used to get the "target" q-values for each action in a
  particular state, i.e. during the learning phase, since we don't have
  actual ground truths for future q-values, these q-values from the
  target network are used as labels to optimize the online network.

  The target network calculates the target q-values using the following
  Bellman equation: \[\begin{aligned}
  Q(s_t, a_t) = r_{t+1} + \gamma \max_{a_{t+1} \in A} Q(s_{t+1}, a_{t+1})
  \end{aligned}\] where,
  \(Q(s_t, a_t)\) = the target q-value (ground truth) for a past
  experience in the replay memory

  \(r_{t+1}\) = the reward that was obtained for taking the chosen
  action in that particular experience

  \(\gamma\) = the discount factor for future rewards

  \(Q(s_{t+1}, a_{t+1})\) = the q-value of the best action (based on the
  policy) in the next state of that particular experience

- **Double DQN**

  One of the problems with the vanilla DQN is the way it calculates its
  target values (ground truth). We can see from the Bellman equation
  above that the target network uses the **max** q-value directly in
  the equation. This almost always overestimates the q-value, because
  using the **max** function introduces maximization bias into our
  estimates: max returns the largest value even when that value is an
  outlier, thus skewing our estimates.
  The Double DQN solves this problem by changing the original algorithm
  as follows:

  1. Instead of using the **max** function, first use the online
     network to estimate the best action for the next state

  2. Calculate target q-values for the next state for each possible
     action using the target network

  3. From the q-values calculated by the target network, use the
     q-value of the action chosen in step 1.

  This can be represented by the following equation: \[\begin{aligned}
  Q(s_t, a_t) = r_{t+1} + \gamma \, Q_{target}(s_{t+1}, a'_{t+1})
  \end{aligned}\] where, \[\begin{aligned}
  a'_{t+1} = \operatorname{argmax}_{a} Q_{online}(s_{t+1}, a)
  \end{aligned}\]

- **Dueling DQN**

  The Dueling DQN algorithm was an attempt to improve upon the original
  DQN algorithm by changing the architecture of the neural network used
  in deep Q-learning. The Dueling DQN splits the last layer of the DQN
  into two parts, a **value stream** and an **advantage stream**, whose
  outputs are combined in an aggregating layer that gives the final
  q-value. One of the main problems with the original DQN algorithm was
  that the q-values of the actions were often very close, so selecting
  the action with the max q-value might not always be the best choice.
  The Dueling DQN attempts to mitigate this by using advantage, a
  measure of how much better an action is compared to the other actions
  in a given state. The value stream, on the other hand, learns how
  good or bad it is to be in a specific state, e.g. moving straight
  towards an obstacle in a racing game, or being in the path of a
  projectile in Space Invaders. Instead of learning to predict a single
  q-value directly, separating the value and advantage streams helps
  the network generalize better.

  ![image](./docs/dueling.png)
  Fig: The Dueling DQN architecture (image taken from the original
  paper by Wang et al.)

  The q-value in a Dueling DQN architecture is given by
  \[\begin{aligned}
  Q(s_t, a_t) = V(s_t) + A(s_t, a_t)
  \end{aligned}\] where,
  V(s_t) = the value of the current state (how advantageous it is to
  be in that state)

  A(s_t, a_t) = the advantage of taking action a in that state

  (A minimal code sketch of these three variants follows this list.)
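
The difference between these variants is easiest to see in code. Below is a minimal, illustrative PyTorch sketch of the target computations and the dueling head described above; it is not the repository's implementation (that lives in `src/*/agent.py`), and the function and tensor names here are hypothetical.

```python
import torch

# Assumed batch tensors (hypothetical names, not from this repo):
#   next_states: float tensor (B, C, H, W); rewards: float tensor (B,)
#   dones: bool tensor (B,); online_net / target_net: state -> Q-values (B, num_actions)

def vanilla_dqn_target(target_net, rewards, next_states, dones, gamma=0.99):
    # Bellman target: r + gamma * max_a' Q_target(s', a'); zeroed for terminal states.
    with torch.no_grad():
        next_q = target_net(next_states).max(dim=1).values
    return rewards + gamma * next_q * (~dones).float()

def double_dqn_target(online_net, target_net, rewards, next_states, dones, gamma=0.99):
    # Step 1: the online network picks the best next action.
    # Steps 2-3: the target network evaluates that chosen action.
    with torch.no_grad():
        best_actions = online_net(next_states).argmax(dim=1, keepdim=True)
        next_q = target_net(next_states).gather(1, best_actions).squeeze(1)
    return rewards + gamma * next_q * (~dones).float()

class DuelingHead(torch.nn.Module):
    # Dueling aggregation: Q(s, a) = V(s) + (A(s, a) - mean_a A(s, a)).
    # Subtracting the mean advantage is one of the aggregations proposed by
    # Wang et al.; the simpler Q = V + A above omits it for readability.
    def __init__(self, in_features, num_actions):
        super().__init__()
        self.value = torch.nn.Linear(in_features, 1)
        self.advantage = torch.nn.Linear(in_features, num_actions)

    def forward(self, features):
        v = self.value(features)                   # (B, 1)
        a = self.advantage(features)               # (B, num_actions)
        return v + a - a.mean(dim=1, keepdim=True) # (B, num_actions)
```
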
# About the project

My original goal for the project was to train an agent using DQN to
play **Airstriker Genesis**, a space-shooting game, and to evaluate the
same agent's performance on another similar game called **Starpilot**.
Unfortunately, I was unable to train a decent enough agent in the first
game, which made it meaningless to evaluate its performance on yet
another game.

Because I still want to do the original project some time in the
future, to prepare myself for it I thought it would be better to first
learn in depth how Deep Q-Networks work, what their shortcomings are,
and how they can be improved. For this reason, and because of time
constraints, I changed my project for this class to a comparison of
various DQN versions.

# Dataset

I used the excellent [Gym](https://github.com/openai/gym) library to
run my environments. A total of 9 agents were trained: 1 in Airstriker
Genesis, 4 in Starpilot and 4 in Lunar Lander.

| **Game** | **Observation Space** | **Action Space** |
| :--- | :--- | :--- |
| Airstriker Genesis | RGB values of each pixel of the game screen (255, 255, 3) | Discrete(12), representing each of the buttons on the Genesis controller. Since only three of those buttons are used in the game, the action space was reduced to 3 during training (Left, Right, Fire). |
| Starpilot | RGB values of each pixel of the game screen (64, 64, 3) | Discrete(15), representing each of the button combos (Left, Right, Up, Down, Up + Right, Up + Left, Down + Right, Down + Left, W, A, S, D, Q, E, Do nothing) |
| Lunar Lander | 8-dimensional vector: (X-coordinate, Y-coordinate, Linear velocity in X, Linear velocity in Y, Angle, Angular velocity, Boolean (Leg 1 in contact with ground), Boolean (Leg 2 in contact with ground)) | Discrete(4) (Do nothing, Fire left engine, Fire main engine, Fire right engine) |

**Environment/Libraries**:
Miniconda, Python 3.9, Gym, PyTorch, NumPy and TensorBoard on my
personal MacBook Pro (M1)
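
For reference, the three environments can be created roughly as follows with the pinned libraries (gym 0.21, gym-retro, procgen). This is an illustrative sketch only; the exact environment ids and options used for training are defined in the repository's `run-*.py` scripts.

```python
import gym    # gym==0.21.0
import retro  # gym-retro==0.8.0

# Airstriker Genesis ships with gym-retro itself.
airstriker = retro.make("Airstriker-Genesis")

# Starpilot comes from the procgen package, registered under the "procgen" namespace.
starpilot = gym.make("procgen:procgen-starpilot-v0")

# Lunar Lander needs box2d-py (pinned in requirements.txt).
lunar = gym.make("LunarLander-v2")

for name, env in [("airstriker", airstriker), ("starpilot", starpilot), ("lunar", lunar)]:
    print(name, env.observation_space, env.action_space)
```
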
# ML Methodology

Each agent was trained using DQN or one of its flavors. The agents for
a particular game were trained with the same hyperparameters, with only
the underlying algorithm differing. The following metrics were used to
evaluate each agent (a minimal logging sketch follows this list):

- **Epsilon value over each episode**: the exploration rate at the end
  of each episode.

- **Average Q-value for the last 100 episodes**: the average q-value
  (for the action chosen) over the last 100 episodes.

- **Average length for the last 100 episodes**: the average number of
  steps taken in each episode.

- **Average loss for the last 100 episodes**: the average loss during
  learning over the last 100 episodes (a Huber loss was used).

- **Average reward for the last 100 episodes**: the average reward the
  agent accumulated over the last 100 episodes.
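
Each of these metrics is a simple 100-episode rolling average written to TensorBoard. A hypothetical, stripped-down version of the logging (the repository's `MetricLogger` class in `src/airstriker-genesis/agent.py` is the full implementation) looks roughly like this:

```python
import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/example")  # hypothetical log directory
episode_rewards = []

def log_episode(episode, reward, epsilon):
    episode_rewards.append(reward)
    # Rolling mean over the most recent (up to) 100 episodes.
    mean_reward_100 = float(np.mean(episode_rewards[-100:]))
    writer.add_scalar("Mean reward last 100 episodes", mean_reward_100, episode)
    writer.add_scalar("Epsilon value", epsilon, episode)
```
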
## Preprocessing

For the Airstriker and the Starpilot games (a wrapper sketch follows at
the end of this subsection):

1. Changed each frame to grayscale.
   Since color shouldn't matter to the agent, I decided to change the
   RGB image to grayscale.

2. Changed the observation space shape from (height, width, channels)
   to (channels, height, width) to make it compatible with PyTorch.
   PyTorch uses a different layout than the direct output of the gym
   environment. For this reason, I had to reshape each observation to
   match PyTorch's scheme (this took me a very long time to figure out,
   but I had an "Aha!" moment when I remembered you saying something
   similar in class).

3. Framestacking.
   Instead of processing 1 frame at a time, process 4 frames at a time.
   This is because a single frame does not give the agent enough
   information to decide what action to take.

For Lunar Lander, since the reward changes are very drastic (sudden
+100, -100, +200 rewards), I experimented with reward clipping
(clipping the rewards to the \[-1, 1\] range), but this didn't seem to
make much difference in my agent's performance.
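
A minimal sketch of these preprocessing steps using standard gym 0.21 wrappers is shown below; the repository's own implementations live in `src/*/wrappers.py`, and the wrapper choices here are illustrative rather than a copy of that code.

```python
import gym
import numpy as np
from gym.wrappers import GrayScaleObservation, FrameStack

class ChannelsFirst(gym.ObservationWrapper):
    """Step 2: move an (H, W, C) observation to PyTorch's (C, H, W) layout."""
    def __init__(self, env):
        super().__init__(env)
        h, w, c = env.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(c, h, w), dtype=np.uint8)

    def observation(self, obs):
        return np.transpose(obs, (2, 0, 1))

class ClipReward(gym.RewardWrapper):
    """Clip rewards to [-1, 1] (the experiment tried for Lunar Lander's reward spikes)."""
    def reward(self, reward):
        return float(np.clip(reward, -1.0, 1.0))

def preprocess(env, grayscale=True):
    if grayscale:
        env = GrayScaleObservation(env)  # step 1: (H, W, 3) -> (H, W)
    else:
        env = ChannelsFirst(env)         # step 2: only needed while observations stay (H, W, C)
    env = FrameStack(env, num_stack=4)   # step 3: the 4 stacked frames act as the Conv2d channels
    return env
```
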
# Results

- **Airstriker Genesis**

  The loss went down until about 5200 episodes, but after that it
  stopped going down any further. Consequently, the average reward the
  agent accumulated over the last 100 episodes pretty much plateaued
  after about 5000 episodes. On analysis, I noticed that my exploration
  rate at the end of the 7000th episode was still about 0.65, which
  means that the agent was taking random actions more than half of the
  time. In hindsight, I feel like I should have trained longer, at
  least until the epsilon value (exploration rate) completely decayed
  to 5% (see the short decay-schedule calculation after this results
  list).
  ![image](./docs/air1.png) ![image](./docs/air2.png) ![image](./docs/air3.png)

- **Starpilot**

  I trained DQN, Double DQN, Dueling DQN and Dueling Double DQN
  versions for this game to compare the different algorithms.
  From the graph of mean q-values, we can tell that the vanilla DQN
  versions indeed give high q-values while their Double DQN
  counterparts give lower values, which makes me think that my
  implementation of the Double DQN algorithm was OK. I had expected the
  Double and Dueling versions to start accumulating higher rewards much
  earlier, but since the average rewards were almost the same for all
  the agents, I could not see any stark differences between their
  performance.

  ![image](./docs/star1.png)

  ![image](./docs/star2.png)

  | | |
  | :--- | :--- |
  | ![image](./docs/star3.png) | ![image](./docs/star4.png) |

- **Lunar Lander**

  Since I did not gain much insight from the agents in the Starpilot
  game, I thought I was not training long enough. So I tried training
  the same agents on Lunar Lander, which is a comparatively simpler
  game with a smaller observation space and one that a DQN algorithm
  should be able to converge on pretty quickly (based on comments by
  other people in the RL community).
  ![image](./docs/lunar1.png)

  ![image](./docs/lunar2.png)

  | | |
  | :--- | :--- |
  | ![image](./docs/lunar3.png) | ![image](./docs/lunar4.png) |

  The results for this were interesting. Although I did not find any
  vast difference between the different variations of the DQN
  algorithm, I found that the performance of my agents suddenly got
  worse at around 300 episodes. Upon researching why this may have
  happened, I learned that DQN agents can suffer from **catastrophic
  forgetting**, i.e. after training extensively, the network suddenly
  forgets what it has learned in the past and then starts performing
  worse. Initially, I thought this might have been the case, but since
  I hadn't trained for very long, and because all models started
  performing worse at almost exactly the same episode number, I think
  this might be a problem with my code or one of the hyperparameters I
  used.

  Upon checking what the agent was doing in the actual game, I found
  that it was playing it very safe and just constantly hovering in the
  air, not attempting to land the spaceship (the goal of the agent is
  to land between the yellow flags). I thought maybe penalizing the
  reward for taking too many steps in the episode would work, but that
  didn't help either.

  ![image](./docs/check.png)
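
As a rough check on the Airstriker exploration schedule mentioned above, assuming the per-step multiplicative decay of 0.9999999 and floor of 0.1 that `src/airstriker-genesis/agent.py` uses as defaults (the actual run configuration may have differed):

```python
import math

decay = 0.9999999  # per-step multiplicative epsilon decay (agent.py default)
steps_to = lambda eps: math.log(eps) / math.log(decay)

print(f"steps until epsilon = 0.65: {steps_to(0.65):,.0f}")  # ~4.3 million steps
print(f"steps until epsilon = 0.10: {steps_to(0.10):,.0f}")  # ~23 million steps
```

With that schedule the exploration floor is only reached after tens of millions of steps, far more than roughly 7000 Airstriker episodes provide, which is consistent with epsilon still sitting at about 0.65.
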
# Problems Faced

Here are a few of the problems that I faced while training my agents:

- Understanding the various hyperparameters in the algorithm. DQN has a
  lot of moving parts, and tuning each parameter was a difficult task.
  There were about 8 different hyperparameters (some correlated) that
  impacted the agent's training performance. I struggled with
  understanding how each parameter impacted the agent and with figuring
  out how to find optimal values for them. I ended up tuning them by
  trial and error.

- I got stuck for a long time figuring out why my convolutional layer
  was not working. I didn't realize that PyTorch has the channels in
  the first dimension, and because of that, I was passing huge numbers
  like 255 (the height of the image) into the input dimension of a
  Conv2D layer.

- I struggled with knowing how long is long enough to realize that a
  model is not working. I trained a model on Airstriker Genesis for 14
  hours just to realize later that I had set a parameter incorrectly
  and had to retrain all over again.

# What Next?

Although I didn't get a final working agent for any of the games I
tried, I feel like I have learned a lot about reinforcement learning,
especially about deep Q-learning. I plan to improve upon this further,
and hopefully get an agent to go far into at least one of the games.
Next time, I will start by debugging my current code to see if I have
any implementation mistakes. Then I will train the agents a lot longer
than I did this time and see if that works. While learning about the
different flavors of DQN, I also learned a little about NoisyNet DQN,
Rainbow DQN and Prioritized Experience Replay. I couldn't implement
these for this project, but I would like to try them out some time
soon.

# Lessons Learned

- Reinforcement learning is a very challenging problem. It takes a
  substantially large amount of time to train, it is hard to debug, and
  it is very difficult to tune the hyperparameters just right. It is a
  lot different from supervised learning in that there are no actual
  labels, which makes optimization very difficult.

- I tried training agents on Airstriker Genesis and the procgen
  Starpilot game using just the CPU, but this took a very long time.
  This is understandable because the inputs are images, so using a GPU
  would have been obviously better. Next time, I will definitely try
  using a GPU to make training faster.

- Upon being faced with the problem of my agent not learning, I went
  into research mode and got to learn a lot about DQN and its improved
  versions. I am not a master of the algorithms yet (I have yet to get
  an agent to perform well in a game), but I feel like I understand how
  each version works.

- Rather than just following someone's tutorial, reading the actual
  papers for a particular algorithm helped me understand the algorithm
  better and code it.

- Doing this project reinforced in me that I love the concept of
  reinforcement learning. It has made me even more interested in
  exploring the field further and learning more.

# References / Resources

- [Reinforcement Learning (DQN) Tutorial, Adam
  Paszke](https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html)

- [Train a mario-playing RL agent, Yuansong Feng, Suraj Subramanian,
  Howard Wang, Steven
  Guo](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html)

- [About Double DQN, Dueling
  DQN](https://horomary.hatenablog.com/entry/2021/02/06/013412)

- [Dueling Network Architectures for Deep Reinforcement Learning (Wang
  et al., 2015)](https://arxiv.org/abs/1511.06581)

*(Final source code for the project can be found
[here](https://github.com/00ber/ml-reinforcement-learning).)*
environment.atari.yml
ADDED
@@ -0,0 +1,153 @@
name: mlrl
channels:
  - pytorch
  - defaults
dependencies:
  - absl-py=1.3.0=py37hecd8cb5_0
  - aiohttp=3.8.3=py37h6c40b1e_0
  - aiosignal=1.2.0=pyhd3eb1b0_0
  - appnope=0.1.2=py37hecd8cb5_1001
  - async-timeout=4.0.2=py37hecd8cb5_0
  - asynctest=0.13.0=py_0
  - attrs=22.1.0=py37hecd8cb5_0
  - backcall=0.2.0=pyhd3eb1b0_0
  - blas=1.0=mkl
  - blinker=1.4=py37hecd8cb5_0
  - brotli=1.0.9=hca72f7f_7
  - brotli-bin=1.0.9=hca72f7f_7
  - brotlipy=0.7.0=py37h9ed2024_1003
  - bzip2=1.0.8=h1de35cc_0
  - c-ares=1.18.1=hca72f7f_0
  - ca-certificates=2022.10.11=hecd8cb5_0
  - cachetools=4.2.2=pyhd3eb1b0_0
  - cairo=1.14.12=hc4e6be7_4
  - certifi=2022.9.24=py37hecd8cb5_0
  - cffi=1.15.0=py37hca72f7f_0
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - click=8.0.4=py37hecd8cb5_0
  - cryptography=38.0.1=py37hf6deb26_0
  - cycler=0.11.0=pyhd3eb1b0_0
  - dataclasses=0.8=pyh6d0b6a4_7
  - decorator=5.1.1=pyhd3eb1b0_0
  - expat=2.4.9=he9d5cce_0
  - ffmpeg=4.0=h01ea3c9_0
  - flit-core=3.6.0=pyhd3eb1b0_0
  - fontconfig=2.14.1=hedf32ac_1
  - fonttools=4.25.0=pyhd3eb1b0_0
  - freetype=2.12.1=hd8bbffd_0
  - frozenlist=1.3.3=py37h6c40b1e_0
  - gettext=0.21.0=h7535e17_0
  - giflib=5.2.1=haf1e3a3_0
  - glib=2.63.1=hd977a24_0
  - google-auth=2.6.0=pyhd3eb1b0_0
  - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
  - graphite2=1.3.14=he9d5cce_1
  - grpcio=1.42.0=py37ha29bfda_0
  - harfbuzz=1.8.8=hb8d4a28_0
  - hdf5=1.10.2=hfa1e0ec_1
  - icu=58.2=h0a44026_3
  - idna=3.4=py37hecd8cb5_0
  - intel-openmp=2021.4.0=hecd8cb5_3538
  - ipython=7.31.1=py37hecd8cb5_1
  - jasper=2.0.14=h0129ec2_2
  - jedi=0.18.1=py37hecd8cb5_1
  - jpeg=9e=hca72f7f_0
  - kiwisolver=1.4.2=py37he9d5cce_0
  - lcms2=2.12=hf1fd2bf_0
  - lerc=3.0=he9d5cce_0
  - libbrotlicommon=1.0.9=hca72f7f_7
  - libbrotlidec=1.0.9=hca72f7f_7
  - libbrotlienc=1.0.9=hca72f7f_7
  - libcxx=14.0.6=h9765a3e_0
  - libdeflate=1.8=h9ed2024_5
  - libedit=3.1.20221030=h6c40b1e_0
  - libffi=3.2.1=h0a44026_1007
  - libgfortran=3.0.1=h93005f0_2
  - libiconv=1.16=hca72f7f_2
  - libopencv=3.4.2=h7c891bd_1
  - libopus=1.3.1=h1de35cc_0
  - libpng=1.6.37=ha441bb4_0
  - libprotobuf=3.20.1=h8346a28_0
  - libtiff=4.4.0=h2cd0358_2
  - libvpx=1.7.0=h378b8a2_0
  - libwebp=1.2.4=h56c3ce4_0
  - libwebp-base=1.2.4=hca72f7f_0
  - libxml2=2.9.14=hbf8cd5e_0
  - llvm-openmp=14.0.6=h0dcd299_0
  - lz4-c=1.9.4=hcec6c5f_0
  - markdown=3.3.4=py37hecd8cb5_0
  - matplotlib=3.1.2=py37h9aa3819_0
  - matplotlib-inline=0.1.6=py37hecd8cb5_0
  - mkl=2021.4.0=hecd8cb5_637
  - mkl-service=2.4.0=py37h9ed2024_0
  - mkl_fft=1.3.1=py37h4ab4a9b_0
  - mkl_random=1.2.2=py37hb2f4e1b_0
  - multidict=6.0.2=py37hca72f7f_0
  - munkres=1.1.4=py_0
  - ncurses=6.3=hca72f7f_3
  - numpy=1.21.5=py37h2e5f0a9_3
  - numpy-base=1.21.5=py37h3b1a694_3
  - oauthlib=3.2.1=py37hecd8cb5_0
  - olefile=0.46=py37_0
  - opencv=3.4.2=py37h6fd60c2_1
  - openssl=1.1.1s=hca72f7f_0
  - packaging=21.3=pyhd3eb1b0_0
  - parso=0.8.3=pyhd3eb1b0_0
  - pcre=8.45=h23ab428_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pillow=6.1.0=py37hb68e598_0
  - pip=22.3.1=py37hecd8cb5_0
  - pixman=0.40.0=h9ed2024_1
  - prompt-toolkit=3.0.20=pyhd3eb1b0_0
  - protobuf=3.20.1=py37he9d5cce_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - py-opencv=3.4.2=py37h7c891bd_1
  - pyasn1=0.4.8=pyhd3eb1b0_0
  - pyasn1-modules=0.2.8=py_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pygments=2.11.2=pyhd3eb1b0_0
  - pyjwt=2.4.0=py37hecd8cb5_0
  - pyopenssl=22.0.0=pyhd3eb1b0_0
  - pyparsing=3.0.9=py37hecd8cb5_0
  - pysocks=1.7.1=py37hecd8cb5_0
  - python=3.7.3=h359304d_0
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - pytorch=1.13.1=py3.7_0
  - readline=7.0=h1de35cc_5
  - requests=2.28.1=py37hecd8cb5_0
  - requests-oauthlib=1.3.0=py_0
  - rsa=4.7.2=pyhd3eb1b0_1
  - setuptools=65.5.0=py37hecd8cb5_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.33.0=hffcf06c_0
  - tensorboard=2.9.0=py37hecd8cb5_0
  - tensorboard-data-server=0.6.1=py37h7242b5c_0
  - tensorboard-plugin-wit=1.6.0=py_0
  - tk=8.6.12=h5d9f67b_0
  - torchvision=0.2.2=py_3
  - tornado=6.2=py37hca72f7f_0
  - tqdm=4.64.1=py37hecd8cb5_0
  - traitlets=5.7.1=py37hecd8cb5_0
  - typing-extensions=4.4.0=py37hecd8cb5_0
  - typing_extensions=4.4.0=py37hecd8cb5_0
  - urllib3=1.26.13=py37hecd8cb5_0
  - wcwidth=0.2.5=pyhd3eb1b0_0
  - werkzeug=2.0.3=pyhd3eb1b0_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - xz=5.2.8=h6c40b1e_0
  - yarl=1.8.1=py37hca72f7f_0
  - zlib=1.2.13=h4dc903c_0
  - zstd=1.5.2=hcb37349_0
  - pip:
    - ale-py==0.7.5
    - cloudpickle==2.2.0
    - gym==0.21.0
    - gym-notices==0.0.8
    - gym-retro==0.8.0
    - importlib-metadata==4.13.0
    - importlib-resources==5.10.1
    - pygame==2.1.0
    - pyglet==1.5.27
    - zipp==3.11.0
prefix: /Users/karkisushant/miniconda3/envs/mlrl
environment.procgen-v2.yml
ADDED
@@ -0,0 +1,135 @@
name: procgen
channels:
  - pytorch
  - defaults
dependencies:
  - absl-py=1.3.0=py39hecd8cb5_0
  - aiohttp=3.8.3=py39h6c40b1e_0
  - aiosignal=1.2.0=pyhd3eb1b0_0
  - async-timeout=4.0.2=py39hecd8cb5_0
  - attrs=22.1.0=py39hecd8cb5_0
  - blas=1.0=mkl
  - blinker=1.4=py39hecd8cb5_0
  - brotli=1.0.9=hca72f7f_7
  - brotli-bin=1.0.9=hca72f7f_7
  - brotlipy=0.7.0=py39h9ed2024_1003
  - bzip2=1.0.8=h1de35cc_0
  - c-ares=1.18.1=hca72f7f_0
  - ca-certificates=2022.10.11=hecd8cb5_0
  - cachetools=4.2.2=pyhd3eb1b0_0
  - certifi=2022.9.24=py39hecd8cb5_0
  - cffi=1.15.1=py39h6c40b1e_3
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - click=8.0.4=py39hecd8cb5_0
  - contourpy=1.0.5=py39haf03e11_0
  - cryptography=38.0.1=py39hf6deb26_0
  - cycler=0.11.0=pyhd3eb1b0_0
  - ffmpeg=4.3=h0a44026_0
  - flit-core=3.6.0=pyhd3eb1b0_0
  - fonttools=4.25.0=pyhd3eb1b0_0
  - freetype=2.12.1=hd8bbffd_0
  - frozenlist=1.3.3=py39h6c40b1e_0
  - gettext=0.21.0=h7535e17_0
  - giflib=5.2.1=haf1e3a3_0
  - gmp=6.2.1=he9d5cce_3
  - gnutls=3.6.15=hed9c0bf_0
  - google-auth=2.6.0=pyhd3eb1b0_0
  - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
  - grpcio=1.42.0=py39ha29bfda_0
  - icu=58.2=h0a44026_3
  - idna=3.4=py39hecd8cb5_0
  - importlib-metadata=4.11.3=py39hecd8cb5_0
  - intel-openmp=2021.4.0=hecd8cb5_3538
  - jpeg=9e=hca72f7f_0
  - kiwisolver=1.4.2=py39he9d5cce_0
  - lame=3.100=h1de35cc_0
  - lcms2=2.12=hf1fd2bf_0
  - lerc=3.0=he9d5cce_0
  - libbrotlicommon=1.0.9=hca72f7f_7
  - libbrotlidec=1.0.9=hca72f7f_7
  - libbrotlienc=1.0.9=hca72f7f_7
  - libcxx=14.0.6=h9765a3e_0
  - libdeflate=1.8=h9ed2024_5
  - libffi=3.4.2=hecd8cb5_6
  - libiconv=1.16=hca72f7f_2
  - libidn2=2.3.2=h9ed2024_0
  - libpng=1.6.37=ha441bb4_0
  - libprotobuf=3.20.1=h8346a28_0
  - libtasn1=4.16.0=h9ed2024_0
  - libtiff=4.4.0=h2cd0358_2
  - libunistring=0.9.10=h9ed2024_0
  - libwebp=1.2.4=h56c3ce4_0
  - libwebp-base=1.2.4=hca72f7f_0
  - libxml2=2.9.14=hbf8cd5e_0
  - llvm-openmp=14.0.6=h0dcd299_0
  - lz4-c=1.9.4=hcec6c5f_0
  - markdown=3.3.4=py39hecd8cb5_0
  - markupsafe=2.1.1=py39hca72f7f_0
  - matplotlib=3.6.2=py39hecd8cb5_0
  - matplotlib-base=3.6.2=py39h220de94_0
  - mkl=2021.4.0=hecd8cb5_637
  - mkl-service=2.4.0=py39h9ed2024_0
  - mkl_fft=1.3.1=py39h4ab4a9b_0
  - mkl_random=1.2.2=py39hb2f4e1b_0
  - multidict=6.0.2=py39hca72f7f_0
  - munkres=1.1.4=py_0
  - ncurses=6.3=hca72f7f_3
  - nettle=3.7.3=h230ac6f_1
  - numpy=1.23.4=py39he696674_0
  - numpy-base=1.23.4=py39h9cd3388_0
  - oauthlib=3.2.1=py39hecd8cb5_0
  - openh264=2.1.1=h8346a28_0
  - openssl=1.1.1s=hca72f7f_0
  - packaging=21.3=pyhd3eb1b0_0
  - pillow=9.2.0=py39hde71d04_1
  - pip=22.3.1=py39hecd8cb5_0
  - protobuf=3.20.1=py39he9d5cce_0
  - pyasn1=0.4.8=pyhd3eb1b0_0
  - pyasn1-modules=0.2.8=py_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyjwt=2.4.0=py39hecd8cb5_0
  - pyopenssl=22.0.0=pyhd3eb1b0_0
  - pyparsing=3.0.9=py39hecd8cb5_0
  - pysocks=1.7.1=py39hecd8cb5_0
  - python=3.9.15=h218abb5_2
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - pytorch=1.13.1=py3.9_0
  - readline=8.2=hca72f7f_0
  - requests=2.28.1=py39hecd8cb5_0
  - requests-oauthlib=1.3.0=py_0
  - rsa=4.7.2=pyhd3eb1b0_1
  - setuptools=65.5.0=py39hecd8cb5_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.40.0=h880c91c_0
  - tensorboard=2.9.0=py39hecd8cb5_0
  - tensorboard-data-server=0.6.1=py39h7242b5c_0
  - tensorboard-plugin-wit=1.6.0=py_0
  - tk=8.6.12=h5d9f67b_0
  - torchvision=0.14.1=py39_cpu
  - tornado=6.2=py39hca72f7f_0
  - tqdm=4.64.1=py39hecd8cb5_0
  - typing_extensions=4.4.0=py39hecd8cb5_0
  - tzdata=2022g=h04d1e81_0
  - urllib3=1.26.13=py39hecd8cb5_0
  - werkzeug=2.2.2=py39hecd8cb5_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - xz=5.2.8=h6c40b1e_0
  - yarl=1.8.1=py39hca72f7f_0
  - zipp=3.8.0=py39hecd8cb5_0
  - zlib=1.2.13=h4dc903c_0
  - zstd=1.5.2=hcb37349_0
  - pip:
    - cloudpickle==2.2.0
    - filelock==3.8.2
    - glcontext==2.3.7
    - glfw==1.12.0
    - gym==0.21.0
    - gym-notices==0.0.8
    - gym3==0.3.3
    - imageio==2.22.4
    - imageio-ffmpeg==0.3.0
    - moderngl==5.7.4
    - opencv-python==4.6.0.66
    - procgen==0.10.7
    - pyglet==1.5.27
prefix: /Users/karkisushant/miniconda3/envs/v2
environment.procgen.yml
ADDED
@@ -0,0 +1,135 @@
name: procgen
channels:
  - pytorch
  - defaults
dependencies:
  - absl-py=1.3.0=py39hecd8cb5_0
  - aiohttp=3.8.3=py39h6c40b1e_0
  - aiosignal=1.2.0=pyhd3eb1b0_0
  - async-timeout=4.0.2=py39hecd8cb5_0
  - attrs=22.1.0=py39hecd8cb5_0
  - blas=1.0=mkl
  - blinker=1.4=py39hecd8cb5_0
  - brotli=1.0.9=hca72f7f_7
  - brotli-bin=1.0.9=hca72f7f_7
  - brotlipy=0.7.0=py39h9ed2024_1003
  - bzip2=1.0.8=h1de35cc_0
  - c-ares=1.18.1=hca72f7f_0
  - ca-certificates=2022.10.11=hecd8cb5_0
  - cachetools=4.2.2=pyhd3eb1b0_0
  - certifi=2022.9.24=py39hecd8cb5_0
  - cffi=1.15.1=py39h6c40b1e_3
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - click=8.0.4=py39hecd8cb5_0
  - contourpy=1.0.5=py39haf03e11_0
  - cryptography=38.0.1=py39hf6deb26_0
  - cycler=0.11.0=pyhd3eb1b0_0
  - ffmpeg=4.3=h0a44026_0
  - flit-core=3.6.0=pyhd3eb1b0_0
  - fonttools=4.25.0=pyhd3eb1b0_0
  - freetype=2.12.1=hd8bbffd_0
  - frozenlist=1.3.3=py39h6c40b1e_0
  - gettext=0.21.0=h7535e17_0
  - giflib=5.2.1=haf1e3a3_0
  - gmp=6.2.1=he9d5cce_3
  - gnutls=3.6.15=hed9c0bf_0
  - google-auth=2.6.0=pyhd3eb1b0_0
  - google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
  - grpcio=1.42.0=py39ha29bfda_0
  - icu=58.2=h0a44026_3
  - idna=3.4=py39hecd8cb5_0
  - importlib-metadata=4.11.3=py39hecd8cb5_0
  - intel-openmp=2021.4.0=hecd8cb5_3538
  - jpeg=9e=hca72f7f_0
  - kiwisolver=1.4.2=py39he9d5cce_0
  - lame=3.100=h1de35cc_0
  - lcms2=2.12=hf1fd2bf_0
  - lerc=3.0=he9d5cce_0
  - libbrotlicommon=1.0.9=hca72f7f_7
  - libbrotlidec=1.0.9=hca72f7f_7
  - libbrotlienc=1.0.9=hca72f7f_7
  - libcxx=14.0.6=h9765a3e_0
  - libdeflate=1.8=h9ed2024_5
  - libffi=3.4.2=hecd8cb5_6
  - libiconv=1.16=hca72f7f_2
  - libidn2=2.3.2=h9ed2024_0
  - libpng=1.6.37=ha441bb4_0
  - libprotobuf=3.20.1=h8346a28_0
  - libtasn1=4.16.0=h9ed2024_0
  - libtiff=4.4.0=h2cd0358_2
  - libunistring=0.9.10=h9ed2024_0
  - libwebp=1.2.4=h56c3ce4_0
  - libwebp-base=1.2.4=hca72f7f_0
  - libxml2=2.9.14=hbf8cd5e_0
  - llvm-openmp=14.0.6=h0dcd299_0
  - lz4-c=1.9.4=hcec6c5f_0
  - markdown=3.3.4=py39hecd8cb5_0
  - markupsafe=2.1.1=py39hca72f7f_0
  - matplotlib=3.6.2=py39hecd8cb5_0
  - matplotlib-base=3.6.2=py39h220de94_0
  - mkl=2021.4.0=hecd8cb5_637
  - mkl-service=2.4.0=py39h9ed2024_0
  - mkl_fft=1.3.1=py39h4ab4a9b_0
  - mkl_random=1.2.2=py39hb2f4e1b_0
  - multidict=6.0.2=py39hca72f7f_0
  - munkres=1.1.4=py_0
  - ncurses=6.3=hca72f7f_3
  - nettle=3.7.3=h230ac6f_1
  - numpy=1.23.4=py39he696674_0
  - numpy-base=1.23.4=py39h9cd3388_0
  - oauthlib=3.2.1=py39hecd8cb5_0
  - openh264=2.1.1=h8346a28_0
  - openssl=1.1.1s=hca72f7f_0
  - packaging=21.3=pyhd3eb1b0_0
  - pillow=9.2.0=py39hde71d04_1
  - pip=22.3.1=py39hecd8cb5_0
  - protobuf=3.20.1=py39he9d5cce_0
  - pyasn1=0.4.8=pyhd3eb1b0_0
  - pyasn1-modules=0.2.8=py_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyjwt=2.4.0=py39hecd8cb5_0
  - pyopenssl=22.0.0=pyhd3eb1b0_0
  - pyparsing=3.0.9=py39hecd8cb5_0
  - pysocks=1.7.1=py39hecd8cb5_0
  - python=3.9.15=h218abb5_2
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - pytorch=1.13.1=py3.9_0
  - readline=8.2=hca72f7f_0
  - requests=2.28.1=py39hecd8cb5_0
  - requests-oauthlib=1.3.0=py_0
  - rsa=4.7.2=pyhd3eb1b0_1
  - setuptools=65.5.0=py39hecd8cb5_0
  - six=1.16.0=pyhd3eb1b0_1
  - sqlite=3.40.0=h880c91c_0
  - tensorboard=2.9.0=py39hecd8cb5_0
  - tensorboard-data-server=0.6.1=py39h7242b5c_0
  - tensorboard-plugin-wit=1.6.0=py_0
  - tk=8.6.12=h5d9f67b_0
  - torchvision=0.14.1=py39_cpu
  - tornado=6.2=py39hca72f7f_0
  - tqdm=4.64.1=py39hecd8cb5_0
  - typing_extensions=4.4.0=py39hecd8cb5_0
  - tzdata=2022g=h04d1e81_0
  - urllib3=1.26.13=py39hecd8cb5_0
  - werkzeug=2.2.2=py39hecd8cb5_0
  - wheel=0.37.1=pyhd3eb1b0_0
  - xz=5.2.8=h6c40b1e_0
  - yarl=1.8.1=py39hca72f7f_0
  - zipp=3.8.0=py39hecd8cb5_0
  - zlib=1.2.13=h4dc903c_0
  - zstd=1.5.2=hcb37349_0
  - pip:
    - cloudpickle==2.2.0
    - filelock==3.8.2
    - glcontext==2.3.7
    - glfw==1.12.0
    - gym==0.21.0
    - gym-notices==0.0.8
    - gym3==0.3.3
    - imageio==2.22.4
    - imageio-ffmpeg==0.3.0
    - moderngl==5.7.4
    - opencv-python==4.6.0.66
    - procgen==0.10.7
    - pyglet==1.5.27
prefix: /Users/karkisushant/miniconda3/envs/procgen
requirements-v1.txt
ADDED
@@ -0,0 +1,76 @@
absl-py==1.3.0
ale-py==0.7.5
astunparse==1.6.3
attrs==22.1.0
box2d-py==2.3.5
cachetools==5.2.0
certifi==2022.12.7
cffi==1.15.1
charset-normalizer==2.1.1
cloudpickle==2.2.0
cycler==0.11.0
Cython==0.29.32
fasteners==0.18
flatbuffers==22.12.6
fonttools==4.38.0
future==0.18.2
gast==0.4.0
glfw==2.5.5
google-auth==2.15.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.51.1
gym==0.21.0
gym-notices==0.0.8
gym-retro==0.8.0
h5py==3.7.0
idna==3.4
imageio==2.22.4
importlib-metadata==4.13.0
importlib-resources==5.10.1
iniconfig==1.1.1
keras==2.11.0
kiwisolver==1.4.4
libclang==14.0.6
lz4==4.0.2
Markdown==3.4.1
MarkupSafe==2.1.1
matplotlib==3.5.3
mujoco==2.2.0
mujoco-py==2.1.2.14
numpy==1.21.6
oauthlib==3.2.2
opencv-python==4.6.0.66
opt-einsum==3.3.0
packaging==22.0
Pillow==9.3.0
pluggy==1.0.0
protobuf==3.19.6
py==1.11.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pygame==2.1.0
pyglet==1.5.11
PyOpenGL==3.1.6
pyparsing==3.0.9
pytest==7.0.1
python-dateutil==2.8.2
requests==2.28.1
requests-oauthlib==1.3.1
rsa==4.9
six==1.16.0
swig==4.1.1
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-estimator==2.11.0
tensorflow-io-gcs-filesystem==0.28.0
termcolor==2.1.1
tomli==2.0.1
typing_extensions==4.4.0
urllib3==1.26.13
Werkzeug==2.2.2
wrapt==1.14.1
zipp==3.11.0
requirements.txt
ADDED
@@ -0,0 +1,42 @@
absl-py==1.3.0
ale-py==0.7.5
attrs==22.1.0
box2d-py==2.3.5
cffi==1.15.1
cloudpickle==2.2.0
cycler==0.11.0
Cython==0.29.32
fasteners==0.18
fonttools==4.38.0
future==0.18.2
glfw==2.5.5
gym==0.21.0
gym-notices==0.0.8
gym-retro==0.8.0
imageio==2.22.4
importlib-metadata==4.13.0
importlib-resources==5.10.1
iniconfig==1.1.1
kiwisolver==1.4.4
lz4==4.0.2
matplotlib==3.5.3
mujoco==2.2.0
mujoco-py==2.1.2.14
numpy==1.18.0
opencv-python==4.6.0.66
packaging==22.0
Pillow==9.3.0
pluggy==1.0.0
py==1.11.0
pycparser==2.21
pygame==2.1.0
pyglet==1.5.11
PyOpenGL==3.1.6
pyparsing==3.0.9
pytest==7.0.1
python-dateutil==2.8.2
six==1.16.0
swig==4.1.1
tomli==2.0.1
typing_extensions==4.4.0
zipp==3.11.0
src/airstriker-genesis/__init__.py
ADDED
File without changes
src/airstriker-genesis/agent.py
ADDED
@@ -0,0 +1,400 @@
import torch
import numpy as np
import random
import torch.nn as nn
import copy
import time, datetime
import matplotlib.pyplot as plt
from collections import deque
from torch.utils.tensorboard import SummaryWriter
import pickle


class DQNet(nn.Module):
    """mini cnn structure
    input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        print("#################################")
        print("#################################")
        print(input_dim)
        print(output_dim)
        print("#################################")
        print("#################################")
        c, h, w = input_dim

        # if h != 84:
        #     raise ValueError(f"Expecting input height: 84, got: {h}")
        # if w != 84:
        #     raise ValueError(f"Expecting input width: 84, got: {w}")

        self.online = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(17024, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )

        self.target = copy.deepcopy(self.online)

        # Q_target parameters are frozen.
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        if model == "online":
            return self.online(input)
        elif model == "target":
            return self.target(input)


class MetricLogger:
    def __init__(self, save_dir):
        self.writer = SummaryWriter(log_dir=save_dir)
        self.save_log = save_dir / "log"
        with open(self.save_log, "w") as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
        self.ep_rewards_plot = save_dir / "reward_plot.jpg"
        self.ep_lengths_plot = save_dir / "length_plot.jpg"
        self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
        self.ep_avg_qs_plot = save_dir / "q_plot.jpg"

        # History metrics
        self.ep_rewards = []
        self.ep_lengths = []
        self.ep_avg_losses = []
        self.ep_avg_qs = []

        # Moving averages, added for every call to record()
        self.moving_avg_ep_rewards = []
        self.moving_avg_ep_lengths = []
        self.moving_avg_ep_avg_losses = []
        self.moving_avg_ep_avg_qs = []

        # Current episode metric
        self.init_episode()

        # Timing
        self.record_time = time.time()

    def log_step(self, reward, loss, q):
        self.curr_ep_reward += reward
        self.curr_ep_length += 1
        if loss:
            self.curr_ep_loss += loss
            self.curr_ep_q += q
            self.curr_ep_loss_length += 1

    def log_episode(self, episode_number):
        "Mark end of episode"
        self.ep_rewards.append(self.curr_ep_reward)
        self.ep_lengths.append(self.curr_ep_length)
        if self.curr_ep_loss_length == 0:
            ep_avg_loss = 0
            ep_avg_q = 0
        else:
            ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
            ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
        self.ep_avg_losses.append(ep_avg_loss)
        self.ep_avg_qs.append(ep_avg_q)
        self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
        self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
        self.writer.flush()
        self.init_episode()

    def init_episode(self):
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

    def record(self, episode, epsilon, step):
        mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
        mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
        mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
        mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
        self.moving_avg_ep_rewards.append(mean_ep_reward)
        self.moving_avg_ep_lengths.append(mean_ep_length)
        self.moving_avg_ep_avg_losses.append(mean_ep_loss)
        self.moving_avg_ep_avg_qs.append(mean_ep_q)

        last_record_time = self.record_time
        self.record_time = time.time()
        time_since_last_record = np.round(self.record_time - last_record_time, 3)

        print(
            f"Episode {episode} - "
            f"Step {step} - "
            f"Epsilon {epsilon} - "
            f"Mean Reward {mean_ep_reward} - "
            f"Mean Length {mean_ep_length} - "
            f"Mean Loss {mean_ep_loss} - "
            f"Mean Q Value {mean_ep_q} - "
            f"Time Delta {time_since_last_record} - "
            f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
        )
        self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
        self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
        self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
        self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
        self.writer.add_scalar("Epsilon value", epsilon, episode)
        self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
        self.writer.flush()
        with open(self.save_log, "a") as f:
            f.write(
                f"{episode:8d}{step:8d}{epsilon:10.3f}"
                f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
                f"{time_since_last_record:15.3f}"
                f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
            )

        for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
            plt.plot(getattr(self, f"moving_avg_{metric}"))
            plt.savefig(getattr(self, f"{metric}_plot"))
            plt.clf()


class DQNAgent:
    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts  # min. experiences before training

        self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
        self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

        self.save_every = save_frequency  # no. of experiences between saving Mario Net
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        # Mario's DNN to predict the most optimal action - we implement this in the Learn section
        self.net = DQNet(self.state_dim, self.action_dim).float()
        if self.use_cuda:
            self.net = self.net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update value of step.

        Inputs:
        state(LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
        action_idx (int): An integer representing which action Mario will perform
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)
            action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """
        Store the experience to self.memory (replay buffer)

        Inputs:
        state (LazyFrame),
        next_state (LazyFrame),
        action (int),
        reward (float),
        done(bool))
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])

        self.memory.append((state, next_state, action, reward, done,))

    def recall(self):
        """
        Retrieve a batch of experiences from memory
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    # def td_estimate(self, state, action):
    #     current_Q = self.net(state, model='online')[np.arange(0, self.batch_size), action]  # Q_online(s,a)
    #     return current_Q

    # @torch.no_grad()
    # def td_target(self, reward, next_state, done):
    #     next_state_Q = self.net(next_state, model='online')
    #     best_action = torch.argmax(next_state_Q, axis=1)
    #     next_Q = self.net(next_state, model='target')[np.arange(0, self.batch_size), best_action]
    #     return (reward + (1 - done.float()) * self.gamma * next_Q).float()

    def td_estimate(self, states, actions):
        actions = actions.reshape(-1, 1)
        predicted_qs = self.net(states, model='online')  # Q_online(s,a)
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        target_qs = self.net(next_states, model='target')
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        return (rewards + (self.gamma * target_qs))

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.net.target.load_state_dict(self.net.online.state_dict())

    def learn(self):
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        # Sample from memory
        state, next_state, action, reward, done = self.recall()

        # Get TD Estimate
        td_est = self.td_estimate(state, action)

        # Get TD Target
|
342 |
+
td_tgt = self.td_target(reward, next_state, done)
|
343 |
+
|
344 |
+
# Backpropagate loss through Q_online
|
345 |
+
loss = self.update_Q_online(td_est, td_tgt)
|
346 |
+
|
347 |
+
return (td_est.mean().item(), loss)
|
348 |
+
|
349 |
+
|
350 |
+
def save(self):
|
351 |
+
save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
|
352 |
+
torch.save(
|
353 |
+
dict(
|
354 |
+
model=self.net.state_dict(),
|
355 |
+
exploration_rate=self.exploration_rate,
|
356 |
+
replay_memory=self.memory
|
357 |
+
),
|
358 |
+
save_path
|
359 |
+
)
|
360 |
+
|
361 |
+
print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
|
362 |
+
|
363 |
+
|
364 |
+
def load(self, load_path, reset_exploration_rate, load_replay_buffer):
|
365 |
+
if not load_path.exists():
|
366 |
+
raise ValueError(f"{load_path} does not exist")
|
367 |
+
|
368 |
+
ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
|
369 |
+
exploration_rate = ckp.get('exploration_rate')
|
370 |
+
state_dict = ckp.get('model')
|
371 |
+
|
372 |
+
|
373 |
+
print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
|
374 |
+
self.net.load_state_dict(state_dict)
|
375 |
+
|
376 |
+
if load_replay_buffer:
|
377 |
+
replay_memory = ckp.get('replay_memory')
|
378 |
+
print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
|
379 |
+
self.memory = replay_memory if replay_memory else self.memory
|
380 |
+
|
381 |
+
if reset_exploration_rate:
|
382 |
+
print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
|
383 |
+
else:
|
384 |
+
print(f"Setting exploration rate to {exploration_rate} not loaded.")
|
385 |
+
self.exploration_rate = exploration_rate
|
386 |
+
|
387 |
+
|
388 |
+
class DDQNAgent(DQNAgent):
|
389 |
+
@torch.no_grad()
|
390 |
+
def td_target(self, rewards, next_states, dones):
|
391 |
+
print("Double dqn -----------------------")
|
392 |
+
rewards = rewards.reshape(-1, 1)
|
393 |
+
dones = dones.reshape(-1, 1)
|
394 |
+
q_vals = self.net(next_states, model='online')
|
395 |
+
target_actions = torch.argmax(q_vals, axis=1)
|
396 |
+
target_actions = target_actions.reshape(-1, 1)
|
397 |
+
target_qs = self.net(next_states, model='target').gather(target_actions, 1)
|
398 |
+
target_qs = target_qs.reshape(-1, 1)
|
399 |
+
target_qs[dones] = 0.0
|
400 |
+
return (rewards + (self.gamma * target_qs))
|
src/airstriker-genesis/cartpole.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
import torch.nn as nn
|
6 |
+
import copy
|
7 |
+
import time, datetime
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
from collections import deque
|
10 |
+
from torch.utils.tensorboard import SummaryWriter
|
11 |
+
import pickle
|
12 |
+
|
13 |
+
|
14 |
+
class MyDQN(nn.Module):
|
15 |
+
"""mini cnn structure
|
16 |
+
input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, input_dim, output_dim):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
self.online = nn.Sequential(
|
23 |
+
nn.Linear(input_dim, 128),
|
24 |
+
nn.ReLU(),
|
25 |
+
nn.Linear(128, 128),
|
26 |
+
nn.ReLU(),
|
27 |
+
nn.Linear(128, output_dim)
|
28 |
+
)
|
29 |
+
|
30 |
+
|
31 |
+
self.target = copy.deepcopy(self.online)
|
32 |
+
|
33 |
+
# Q_target parameters are frozen.
|
34 |
+
for p in self.target.parameters():
|
35 |
+
p.requires_grad = False
|
36 |
+
|
37 |
+
def forward(self, input, model):
|
38 |
+
if model == "online":
|
39 |
+
return self.online(input)
|
40 |
+
elif model == "target":
|
41 |
+
return self.target(input)
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
class MetricLogger:
|
46 |
+
def __init__(self, save_dir):
|
47 |
+
self.writer = SummaryWriter(log_dir=save_dir)
|
48 |
+
self.save_log = save_dir / "log"
|
49 |
+
with open(self.save_log, "w") as f:
|
50 |
+
f.write(
|
51 |
+
f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
|
52 |
+
f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
|
53 |
+
f"{'TimeDelta':>15}{'Time':>20}\n"
|
54 |
+
)
|
55 |
+
self.ep_rewards_plot = save_dir / "reward_plot.jpg"
|
56 |
+
self.ep_lengths_plot = save_dir / "length_plot.jpg"
|
57 |
+
self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
|
58 |
+
self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
|
59 |
+
|
60 |
+
# History metrics
|
61 |
+
self.ep_rewards = []
|
62 |
+
self.ep_lengths = []
|
63 |
+
self.ep_avg_losses = []
|
64 |
+
self.ep_avg_qs = []
|
65 |
+
|
66 |
+
# Moving averages, added for every call to record()
|
67 |
+
self.moving_avg_ep_rewards = []
|
68 |
+
self.moving_avg_ep_lengths = []
|
69 |
+
self.moving_avg_ep_avg_losses = []
|
70 |
+
self.moving_avg_ep_avg_qs = []
|
71 |
+
|
72 |
+
# Current episode metric
|
73 |
+
self.init_episode()
|
74 |
+
|
75 |
+
# Timing
|
76 |
+
self.record_time = time.time()
|
77 |
+
|
78 |
+
def log_step(self, reward, loss, q):
|
79 |
+
self.curr_ep_reward += reward
|
80 |
+
self.curr_ep_length += 1
|
81 |
+
if loss:
|
82 |
+
self.curr_ep_loss += loss
|
83 |
+
self.curr_ep_q += q
|
84 |
+
self.curr_ep_loss_length += 1
|
85 |
+
|
86 |
+
def log_episode(self, episode_number):
|
87 |
+
"Mark end of episode"
|
88 |
+
self.ep_rewards.append(self.curr_ep_reward)
|
89 |
+
self.ep_lengths.append(self.curr_ep_length)
|
90 |
+
if self.curr_ep_loss_length == 0:
|
91 |
+
ep_avg_loss = 0
|
92 |
+
ep_avg_q = 0
|
93 |
+
else:
|
94 |
+
ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
|
95 |
+
ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
|
96 |
+
self.ep_avg_losses.append(ep_avg_loss)
|
97 |
+
self.ep_avg_qs.append(ep_avg_q)
|
98 |
+
self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
|
99 |
+
self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
|
100 |
+
self.writer.flush()
|
101 |
+
self.init_episode()
|
102 |
+
|
103 |
+
def init_episode(self):
|
104 |
+
self.curr_ep_reward = 0.0
|
105 |
+
self.curr_ep_length = 0
|
106 |
+
self.curr_ep_loss = 0.0
|
107 |
+
self.curr_ep_q = 0.0
|
108 |
+
self.curr_ep_loss_length = 0
|
109 |
+
|
110 |
+
def record(self, episode, epsilon, step):
|
111 |
+
mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
|
112 |
+
mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
|
113 |
+
mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
|
114 |
+
mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
|
115 |
+
self.moving_avg_ep_rewards.append(mean_ep_reward)
|
116 |
+
self.moving_avg_ep_lengths.append(mean_ep_length)
|
117 |
+
self.moving_avg_ep_avg_losses.append(mean_ep_loss)
|
118 |
+
self.moving_avg_ep_avg_qs.append(mean_ep_q)
|
119 |
+
|
120 |
+
last_record_time = self.record_time
|
121 |
+
self.record_time = time.time()
|
122 |
+
time_since_last_record = np.round(self.record_time - last_record_time, 3)
|
123 |
+
|
124 |
+
print(
|
125 |
+
f"Episode {episode} - "
|
126 |
+
f"Step {step} - "
|
127 |
+
f"Epsilon {epsilon} - "
|
128 |
+
f"Mean Reward {mean_ep_reward} - "
|
129 |
+
f"Mean Length {mean_ep_length} - "
|
130 |
+
f"Mean Loss {mean_ep_loss} - "
|
131 |
+
f"Mean Q Value {mean_ep_q} - "
|
132 |
+
f"Time Delta {time_since_last_record} - "
|
133 |
+
f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
|
134 |
+
)
|
135 |
+
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
|
136 |
+
self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
|
137 |
+
self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
|
138 |
+
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
|
139 |
+
self.writer.add_scalar("Epsilon value", epsilon, episode)
|
140 |
+
self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
|
141 |
+
self.writer.flush()
|
142 |
+
with open(self.save_log, "a") as f:
|
143 |
+
f.write(
|
144 |
+
f"{episode:8d}{step:8d}{epsilon:10.3f}"
|
145 |
+
f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
|
146 |
+
f"{time_since_last_record:15.3f}"
|
147 |
+
f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
|
148 |
+
)
|
149 |
+
|
150 |
+
for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
|
151 |
+
plt.plot(getattr(self, f"moving_avg_{metric}"))
|
152 |
+
plt.savefig(getattr(self, f"{metric}_plot"))
|
153 |
+
plt.clf()
|
154 |
+
|
155 |
+
|
156 |
+
class MyAgent:
|
157 |
+
def __init__(self, state_dim, action_dim, save_dir, checkpoint=None, reset_exploration_rate=False, max_memory_size=100000):
|
158 |
+
self.state_dim = state_dim
|
159 |
+
self.action_dim = action_dim
|
160 |
+
self.max_memory_size = max_memory_size
|
161 |
+
self.memory = deque(maxlen=max_memory_size)
|
162 |
+
# self.batch_size = 32
|
163 |
+
self.batch_size = 512
|
164 |
+
|
165 |
+
self.exploration_rate = 1
|
166 |
+
# self.exploration_rate_decay = 0.99999975
|
167 |
+
self.exploration_rate_decay = 0.9999999
|
168 |
+
self.exploration_rate_min = 0.1
|
169 |
+
self.gamma = 0.9
|
170 |
+
|
171 |
+
self.curr_step = 0
|
172 |
+
self.learning_start_threshold = 10000 # min. experiences before training
|
173 |
+
|
174 |
+
self.learn_every = 5 # no. of experiences between updates to Q_online
|
175 |
+
self.sync_every = 200 # no. of experiences between Q_target & Q_online sync
|
176 |
+
|
177 |
+
self.save_every = 200000 # no. of experiences between saving Mario Net
|
178 |
+
self.save_dir = save_dir
|
179 |
+
|
180 |
+
self.use_cuda = torch.cuda.is_available()
|
181 |
+
|
182 |
+
# Mario's DNN to predict the most optimal action - we implement this in the Learn section
|
183 |
+
self.net = MyDQN(self.state_dim, self.action_dim).float()
|
184 |
+
if self.use_cuda:
|
185 |
+
self.net = self.net.to(device='cuda')
|
186 |
+
if checkpoint:
|
187 |
+
self.load(checkpoint, reset_exploration_rate)
|
188 |
+
|
189 |
+
# self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)
|
190 |
+
self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=0.00025, amsgrad=True)
|
191 |
+
self.loss_fn = torch.nn.SmoothL1Loss()
|
192 |
+
|
193 |
+
|
194 |
+
def act(self, state):
|
195 |
+
"""
|
196 |
+
Given a state, choose an epsilon-greedy action and update value of step.
|
197 |
+
|
198 |
+
Inputs:
|
199 |
+
state(LazyFrame): A single observation of the current state, dimension is (state_dim)
|
200 |
+
Outputs:
|
201 |
+
action_idx (int): An integer representing which action Mario will perform
|
202 |
+
"""
|
203 |
+
# EXPLORE
|
204 |
+
if np.random.rand() < self.exploration_rate:
|
205 |
+
action_idx = np.random.randint(self.action_dim)
|
206 |
+
|
207 |
+
# EXPLOIT
|
208 |
+
else:
|
209 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
210 |
+
state = state.unsqueeze(0)
|
211 |
+
action_values = self.net(state, model='online')
|
212 |
+
action_idx = torch.argmax(action_values, axis=1).item()
|
213 |
+
|
214 |
+
# decrease exploration_rate
|
215 |
+
self.exploration_rate *= self.exploration_rate_decay
|
216 |
+
self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
|
217 |
+
|
218 |
+
# increment step
|
219 |
+
self.curr_step += 1
|
220 |
+
return action_idx
|
221 |
+
|
222 |
+
def cache(self, state, next_state, action, reward, done):
|
223 |
+
"""
|
224 |
+
Store the experience to self.memory (replay buffer)
|
225 |
+
|
226 |
+
Inputs:
|
227 |
+
state (LazyFrame),
|
228 |
+
next_state (LazyFrame),
|
229 |
+
action (int),
|
230 |
+
reward (float),
|
231 |
+
done(bool))
|
232 |
+
"""
|
233 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
234 |
+
next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
|
235 |
+
action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
|
236 |
+
reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
|
237 |
+
done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
|
238 |
+
|
239 |
+
self.memory.append( (state, next_state, action, reward, done,) )
|
240 |
+
|
241 |
+
|
242 |
+
def recall(self):
|
243 |
+
"""
|
244 |
+
Retrieve a batch of experiences from memory
|
245 |
+
"""
|
246 |
+
batch = random.sample(self.memory, self.batch_size)
|
247 |
+
state, next_state, action, reward, done = map(torch.stack, zip(*batch))
|
248 |
+
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
|
249 |
+
|
250 |
+
|
251 |
+
# def td_estimate(self, state, action):
|
252 |
+
# current_Q = self.net(state, model='online')[np.arange(0, self.batch_size), action] # Q_online(s,a)
|
253 |
+
# return current_Q
|
254 |
+
|
255 |
+
|
256 |
+
# @torch.no_grad()
|
257 |
+
# def td_target(self, reward, next_state, done):
|
258 |
+
# next_state_Q = self.net(next_state, model='online')
|
259 |
+
# best_action = torch.argmax(next_state_Q, axis=1)
|
260 |
+
# next_Q = self.net(next_state, model='target')[np.arange(0, self.batch_size), best_action]
|
261 |
+
# return (reward + (1 - done.float()) * self.gamma * next_Q).float()
|
262 |
+
|
263 |
+
def td_estimate(self, states, actions):
|
264 |
+
actions = actions.reshape(-1, 1)
|
265 |
+
predicted_qs = self.net(states, model='online')# Q_online(s,a)
|
266 |
+
predicted_qs = predicted_qs.gather(1, actions)
|
267 |
+
return predicted_qs
|
268 |
+
|
269 |
+
|
270 |
+
@torch.no_grad()
|
271 |
+
def td_target(self, rewards, next_states, dones):
|
272 |
+
rewards = rewards.reshape(-1, 1)
|
273 |
+
dones = dones.reshape(-1, 1)
|
274 |
+
target_qs = self.net(next_states, model='target')
|
275 |
+
target_qs = torch.max(target_qs, dim=1).values
|
276 |
+
target_qs = target_qs.reshape(-1, 1)
|
277 |
+
target_qs[dones] = 0.0
|
278 |
+
return (rewards + (self.gamma * target_qs))
|
279 |
+
|
280 |
+
def update_Q_online(self, td_estimate, td_target) :
|
281 |
+
loss = self.loss_fn(td_estimate, td_target)
|
282 |
+
self.optimizer.zero_grad()
|
283 |
+
loss.backward()
|
284 |
+
self.optimizer.step()
|
285 |
+
return loss.item()
|
286 |
+
|
287 |
+
|
288 |
+
def sync_Q_target(self):
|
289 |
+
self.net.target.load_state_dict(self.net.online.state_dict())
|
290 |
+
|
291 |
+
|
292 |
+
def learn(self):
|
293 |
+
if self.curr_step % self.sync_every == 0:
|
294 |
+
self.sync_Q_target()
|
295 |
+
|
296 |
+
if self.curr_step % self.save_every == 0:
|
297 |
+
self.save()
|
298 |
+
|
299 |
+
if self.curr_step < self.learning_start_threshold:
|
300 |
+
return None, None
|
301 |
+
|
302 |
+
if self.curr_step % self.learn_every != 0:
|
303 |
+
return None, None
|
304 |
+
|
305 |
+
# Sample from memory
|
306 |
+
state, next_state, action, reward, done = self.recall()
|
307 |
+
|
308 |
+
# Get TD Estimate
|
309 |
+
td_est = self.td_estimate(state, action)
|
310 |
+
|
311 |
+
# Get TD Target
|
312 |
+
td_tgt = self.td_target(reward, next_state, done)
|
313 |
+
|
314 |
+
# Backpropagate loss through Q_online
|
315 |
+
loss = self.update_Q_online(td_est, td_tgt)
|
316 |
+
|
317 |
+
return (td_est.mean().item(), loss)
|
318 |
+
|
319 |
+
|
320 |
+
def save(self):
|
321 |
+
save_path = self.save_dir / f"cartpole_net_{int(self.curr_step // self.save_every)}.chkpt"
|
322 |
+
torch.save(
|
323 |
+
dict(
|
324 |
+
model=self.net.state_dict(),
|
325 |
+
exploration_rate=self.exploration_rate,
|
326 |
+
replay_memory=self.memory
|
327 |
+
),
|
328 |
+
save_path
|
329 |
+
)
|
330 |
+
|
331 |
+
print(f"Cartpole Net saved to {save_path} at step {self.curr_step}")
|
332 |
+
|
333 |
+
|
334 |
+
def load(self, load_path, reset_exploration_rate=False):
|
335 |
+
if not load_path.exists():
|
336 |
+
raise ValueError(f"{load_path} does not exist")
|
337 |
+
|
338 |
+
ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
|
339 |
+
exploration_rate = ckp.get('exploration_rate')
|
340 |
+
state_dict = ckp.get('model')
|
341 |
+
replay_memory = ckp.get('replay_memory')
|
342 |
+
|
343 |
+
print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
|
344 |
+
self.net.load_state_dict(state_dict)
|
345 |
+
|
346 |
+
print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
|
347 |
+
self.memory = replay_memory if replay_memory else self.memory
|
348 |
+
|
349 |
+
if reset_exploration_rate:
|
350 |
+
print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
|
351 |
+
else:
|
352 |
+
print(f"Setting exploration rate to {exploration_rate} not loaded.")
|
353 |
+
self.exploration_rate = exploration_rate
|
src/airstriker-genesis/procgen_agent.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import torch.nn as nn
|
5 |
+
import copy
|
6 |
+
import time, datetime
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from collections import deque
|
9 |
+
from torch.utils.tensorboard import SummaryWriter
|
10 |
+
import pickle
|
11 |
+
|
12 |
+
|
13 |
+
class DQNet(nn.Module):
|
14 |
+
"""mini cnn structure
|
15 |
+
input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, input_dim, output_dim):
|
19 |
+
super().__init__()
|
20 |
+
print("#################################")
|
21 |
+
print("#################################")
|
22 |
+
print(input_dim)
|
23 |
+
print(output_dim)
|
24 |
+
print("#################################")
|
25 |
+
print("#################################")
|
26 |
+
c, h, w = input_dim
|
27 |
+
|
28 |
+
# if h != 84:
|
29 |
+
# raise ValueError(f"Expecting input height: 84, got: {h}")
|
30 |
+
# if w != 84:
|
31 |
+
# raise ValueError(f"Expecting input width: 84, got: {w}")
|
32 |
+
|
33 |
+
self.online = nn.Sequential(
|
34 |
+
nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
|
35 |
+
nn.ReLU(),
|
36 |
+
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
|
37 |
+
nn.ReLU(),
|
38 |
+
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
|
39 |
+
nn.ReLU(),
|
40 |
+
nn.Flatten(),
|
41 |
+
nn.Linear(7168, 512),
|
42 |
+
nn.ReLU(),
|
43 |
+
nn.Linear(512, output_dim),
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
self.target = copy.deepcopy(self.online)
|
48 |
+
|
49 |
+
# Q_target parameters are frozen.
|
50 |
+
for p in self.target.parameters():
|
51 |
+
p.requires_grad = False
|
52 |
+
|
53 |
+
def forward(self, input, model):
|
54 |
+
if model == "online":
|
55 |
+
return self.online(input)
|
56 |
+
elif model == "target":
|
57 |
+
return self.target(input)
|
58 |
+
|
59 |
+
|
60 |
+
|
61 |
+
class MetricLogger:
|
62 |
+
def __init__(self, save_dir):
|
63 |
+
self.writer = SummaryWriter(log_dir=save_dir)
|
64 |
+
self.save_log = save_dir / "log"
|
65 |
+
with open(self.save_log, "w") as f:
|
66 |
+
f.write(
|
67 |
+
f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
|
68 |
+
f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
|
69 |
+
f"{'TimeDelta':>15}{'Time':>20}\n"
|
70 |
+
)
|
71 |
+
self.ep_rewards_plot = save_dir / "reward_plot.jpg"
|
72 |
+
self.ep_lengths_plot = save_dir / "length_plot.jpg"
|
73 |
+
self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
|
74 |
+
self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
|
75 |
+
|
76 |
+
# History metrics
|
77 |
+
self.ep_rewards = []
|
78 |
+
self.ep_lengths = []
|
79 |
+
self.ep_avg_losses = []
|
80 |
+
self.ep_avg_qs = []
|
81 |
+
|
82 |
+
# Moving averages, added for every call to record()
|
83 |
+
self.moving_avg_ep_rewards = []
|
84 |
+
self.moving_avg_ep_lengths = []
|
85 |
+
self.moving_avg_ep_avg_losses = []
|
86 |
+
self.moving_avg_ep_avg_qs = []
|
87 |
+
|
88 |
+
# Current episode metric
|
89 |
+
self.init_episode()
|
90 |
+
|
91 |
+
# Timing
|
92 |
+
self.record_time = time.time()
|
93 |
+
|
94 |
+
def log_step(self, reward, loss, q):
|
95 |
+
self.curr_ep_reward += reward
|
96 |
+
self.curr_ep_length += 1
|
97 |
+
if loss:
|
98 |
+
self.curr_ep_loss += loss
|
99 |
+
self.curr_ep_q += q
|
100 |
+
self.curr_ep_loss_length += 1
|
101 |
+
|
102 |
+
def log_episode(self, episode_number):
|
103 |
+
"Mark end of episode"
|
104 |
+
self.ep_rewards.append(self.curr_ep_reward)
|
105 |
+
self.ep_lengths.append(self.curr_ep_length)
|
106 |
+
if self.curr_ep_loss_length == 0:
|
107 |
+
ep_avg_loss = 0
|
108 |
+
ep_avg_q = 0
|
109 |
+
else:
|
110 |
+
ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
|
111 |
+
ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
|
112 |
+
self.ep_avg_losses.append(ep_avg_loss)
|
113 |
+
self.ep_avg_qs.append(ep_avg_q)
|
114 |
+
self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
|
115 |
+
self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
|
116 |
+
self.writer.flush()
|
117 |
+
self.init_episode()
|
118 |
+
|
119 |
+
def init_episode(self):
|
120 |
+
self.curr_ep_reward = 0.0
|
121 |
+
self.curr_ep_length = 0
|
122 |
+
self.curr_ep_loss = 0.0
|
123 |
+
self.curr_ep_q = 0.0
|
124 |
+
self.curr_ep_loss_length = 0
|
125 |
+
|
126 |
+
def record(self, episode, epsilon, step):
|
127 |
+
mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
|
128 |
+
mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
|
129 |
+
mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
|
130 |
+
mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
|
131 |
+
self.moving_avg_ep_rewards.append(mean_ep_reward)
|
132 |
+
self.moving_avg_ep_lengths.append(mean_ep_length)
|
133 |
+
self.moving_avg_ep_avg_losses.append(mean_ep_loss)
|
134 |
+
self.moving_avg_ep_avg_qs.append(mean_ep_q)
|
135 |
+
|
136 |
+
last_record_time = self.record_time
|
137 |
+
self.record_time = time.time()
|
138 |
+
time_since_last_record = np.round(self.record_time - last_record_time, 3)
|
139 |
+
|
140 |
+
print(
|
141 |
+
f"Episode {episode} - "
|
142 |
+
f"Step {step} - "
|
143 |
+
f"Epsilon {epsilon} - "
|
144 |
+
f"Mean Reward {mean_ep_reward} - "
|
145 |
+
f"Mean Length {mean_ep_length} - "
|
146 |
+
f"Mean Loss {mean_ep_loss} - "
|
147 |
+
f"Mean Q Value {mean_ep_q} - "
|
148 |
+
f"Time Delta {time_since_last_record} - "
|
149 |
+
f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
|
150 |
+
)
|
151 |
+
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
|
152 |
+
self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
|
153 |
+
self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
|
154 |
+
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
|
155 |
+
self.writer.add_scalar("Epsilon value", epsilon, episode)
|
156 |
+
self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
|
157 |
+
self.writer.flush()
|
158 |
+
with open(self.save_log, "a") as f:
|
159 |
+
f.write(
|
160 |
+
f"{episode:8d}{step:8d}{epsilon:10.3f}"
|
161 |
+
f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
|
162 |
+
f"{time_since_last_record:15.3f}"
|
163 |
+
f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
|
164 |
+
)
|
165 |
+
|
166 |
+
for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
|
167 |
+
plt.plot(getattr(self, f"moving_avg_{metric}"))
|
168 |
+
plt.savefig(getattr(self, f"{metric}_plot"))
|
169 |
+
plt.clf()
|
170 |
+
|
171 |
+
|
172 |
+
class DQNAgent:
|
173 |
+
def __init__(self,
|
174 |
+
state_dim,
|
175 |
+
action_dim,
|
176 |
+
save_dir,
|
177 |
+
checkpoint=None,
|
178 |
+
learning_rate=0.00025,
|
179 |
+
max_memory_size=100000,
|
180 |
+
batch_size=32,
|
181 |
+
exploration_rate=1,
|
182 |
+
exploration_rate_decay=0.9999999,
|
183 |
+
exploration_rate_min=0.1,
|
184 |
+
training_frequency=1,
|
185 |
+
learning_starts=1000,
|
186 |
+
target_network_sync_frequency=500,
|
187 |
+
reset_exploration_rate=False,
|
188 |
+
save_frequency=100000,
|
189 |
+
gamma=0.9,
|
190 |
+
load_replay_buffer=True):
|
191 |
+
self.state_dim = state_dim
|
192 |
+
self.action_dim = action_dim
|
193 |
+
self.max_memory_size = max_memory_size
|
194 |
+
self.memory = deque(maxlen=max_memory_size)
|
195 |
+
self.batch_size = batch_size
|
196 |
+
|
197 |
+
self.exploration_rate = exploration_rate
|
198 |
+
self.exploration_rate_decay = exploration_rate_decay
|
199 |
+
self.exploration_rate_min = exploration_rate_min
|
200 |
+
self.gamma = gamma
|
201 |
+
|
202 |
+
self.curr_step = 0
|
203 |
+
self.learning_starts = learning_starts # min. experiences before training
|
204 |
+
|
205 |
+
self.training_frequency = training_frequency # no. of experiences between updates to Q_online
|
206 |
+
self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
|
207 |
+
|
208 |
+
self.save_every = save_frequency # no. of experiences between saving Mario Net
|
209 |
+
self.save_dir = save_dir
|
210 |
+
|
211 |
+
self.use_cuda = torch.cuda.is_available()
|
212 |
+
|
213 |
+
# Mario's DNN to predict the most optimal action - we implement this in the Learn section
|
214 |
+
self.net = DQNet(self.state_dim, self.action_dim).float()
|
215 |
+
if self.use_cuda:
|
216 |
+
self.net = self.net.to(device='cuda')
|
217 |
+
if checkpoint:
|
218 |
+
self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
|
219 |
+
|
220 |
+
self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
|
221 |
+
self.loss_fn = torch.nn.SmoothL1Loss()
|
222 |
+
|
223 |
+
|
224 |
+
def act(self, state):
|
225 |
+
"""
|
226 |
+
Given a state, choose an epsilon-greedy action and update value of step.
|
227 |
+
|
228 |
+
Inputs:
|
229 |
+
state(LazyFrame): A single observation of the current state, dimension is (state_dim)
|
230 |
+
Outputs:
|
231 |
+
action_idx (int): An integer representing which action Mario will perform
|
232 |
+
"""
|
233 |
+
# EXPLORE
|
234 |
+
if np.random.rand() < self.exploration_rate:
|
235 |
+
action_idx = np.random.randint(self.action_dim)
|
236 |
+
|
237 |
+
# EXPLOIT
|
238 |
+
else:
|
239 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
240 |
+
state = state.unsqueeze(0)
|
241 |
+
action_values = self.net(state, model='online')
|
242 |
+
action_idx = torch.argmax(action_values, axis=1).item()
|
243 |
+
|
244 |
+
# decrease exploration_rate
|
245 |
+
self.exploration_rate *= self.exploration_rate_decay
|
246 |
+
self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
|
247 |
+
|
248 |
+
# increment step
|
249 |
+
self.curr_step += 1
|
250 |
+
return action_idx
|
251 |
+
|
252 |
+
def cache(self, state, next_state, action, reward, done):
|
253 |
+
"""
|
254 |
+
Store the experience to self.memory (replay buffer)
|
255 |
+
|
256 |
+
Inputs:
|
257 |
+
state (LazyFrame),
|
258 |
+
next_state (LazyFrame),
|
259 |
+
action (int),
|
260 |
+
reward (float),
|
261 |
+
done(bool))
|
262 |
+
"""
|
263 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
264 |
+
next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
|
265 |
+
action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
|
266 |
+
reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
|
267 |
+
done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
|
268 |
+
|
269 |
+
self.memory.append( (state, next_state, action, reward, done,) )
|
270 |
+
|
271 |
+
|
272 |
+
def recall(self):
|
273 |
+
"""
|
274 |
+
Retrieve a batch of experiences from memory
|
275 |
+
"""
|
276 |
+
batch = random.sample(self.memory, self.batch_size)
|
277 |
+
state, next_state, action, reward, done = map(torch.stack, zip(*batch))
|
278 |
+
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
|
279 |
+
|
280 |
+
|
281 |
+
# def td_estimate(self, state, action):
|
282 |
+
# current_Q = self.net(state, model='online')[np.arange(0, self.batch_size), action] # Q_online(s,a)
|
283 |
+
# return current_Q
|
284 |
+
|
285 |
+
|
286 |
+
# @torch.no_grad()
|
287 |
+
# def td_target(self, reward, next_state, done):
|
288 |
+
# next_state_Q = self.net(next_state, model='online')
|
289 |
+
# best_action = torch.argmax(next_state_Q, axis=1)
|
290 |
+
# next_Q = self.net(next_state, model='target')[np.arange(0, self.batch_size), best_action]
|
291 |
+
# return (reward + (1 - done.float()) * self.gamma * next_Q).float()
|
292 |
+
|
293 |
+
def td_estimate(self, states, actions):
|
294 |
+
actions = actions.reshape(-1, 1)
|
295 |
+
predicted_qs = self.net(states, model='online')# Q_online(s,a)
|
296 |
+
predicted_qs = predicted_qs.gather(1, actions)
|
297 |
+
return predicted_qs
|
298 |
+
|
299 |
+
|
300 |
+
@torch.no_grad()
|
301 |
+
def td_target(self, rewards, next_states, dones):
|
302 |
+
rewards = rewards.reshape(-1, 1)
|
303 |
+
dones = dones.reshape(-1, 1)
|
304 |
+
target_qs = self.net(next_states, model='target')
|
305 |
+
target_qs = torch.max(target_qs, dim=1).values
|
306 |
+
target_qs = target_qs.reshape(-1, 1)
|
307 |
+
target_qs[dones] = 0.0
|
308 |
+
return (rewards + (self.gamma * target_qs))
|
309 |
+
|
310 |
+
def update_Q_online(self, td_estimate, td_target) :
|
311 |
+
loss = self.loss_fn(td_estimate, td_target)
|
312 |
+
self.optimizer.zero_grad()
|
313 |
+
loss.backward()
|
314 |
+
self.optimizer.step()
|
315 |
+
return loss.item()
|
316 |
+
|
317 |
+
|
318 |
+
def sync_Q_target(self):
|
319 |
+
self.net.target.load_state_dict(self.net.online.state_dict())
|
320 |
+
|
321 |
+
|
322 |
+
def learn(self):
|
323 |
+
if self.curr_step % self.target_network_sync_frequency == 0:
|
324 |
+
self.sync_Q_target()
|
325 |
+
|
326 |
+
if self.curr_step % self.save_every == 0:
|
327 |
+
self.save()
|
328 |
+
|
329 |
+
if self.curr_step < self.learning_starts:
|
330 |
+
return None, None
|
331 |
+
|
332 |
+
if self.curr_step % self.training_frequency != 0:
|
333 |
+
return None, None
|
334 |
+
|
335 |
+
# Sample from memory
|
336 |
+
state, next_state, action, reward, done = self.recall()
|
337 |
+
|
338 |
+
# Get TD Estimate
|
339 |
+
td_est = self.td_estimate(state, action)
|
340 |
+
|
341 |
+
# Get TD Target
|
342 |
+
td_tgt = self.td_target(reward, next_state, done)
|
343 |
+
|
344 |
+
# Backpropagate loss through Q_online
|
345 |
+
loss = self.update_Q_online(td_est, td_tgt)
|
346 |
+
|
347 |
+
return (td_est.mean().item(), loss)
|
348 |
+
|
349 |
+
|
350 |
+
def save(self):
|
351 |
+
save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
|
352 |
+
torch.save(
|
353 |
+
dict(
|
354 |
+
model=self.net.state_dict(),
|
355 |
+
exploration_rate=self.exploration_rate,
|
356 |
+
replay_memory=self.memory
|
357 |
+
),
|
358 |
+
save_path
|
359 |
+
)
|
360 |
+
|
361 |
+
print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
|
362 |
+
|
363 |
+
|
364 |
+
def load(self, load_path, reset_exploration_rate, load_replay_buffer):
|
365 |
+
if not load_path.exists():
|
366 |
+
raise ValueError(f"{load_path} does not exist")
|
367 |
+
|
368 |
+
ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
|
369 |
+
exploration_rate = ckp.get('exploration_rate')
|
370 |
+
state_dict = ckp.get('model')
|
371 |
+
|
372 |
+
|
373 |
+
print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
|
374 |
+
self.net.load_state_dict(state_dict)
|
375 |
+
|
376 |
+
if load_replay_buffer:
|
377 |
+
replay_memory = ckp.get('replay_memory')
|
378 |
+
print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
|
379 |
+
self.memory = replay_memory if replay_memory else self.memory
|
380 |
+
|
381 |
+
if reset_exploration_rate:
|
382 |
+
print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
|
383 |
+
else:
|
384 |
+
print(f"Setting exploration rate to {exploration_rate} not loaded.")
|
385 |
+
self.exploration_rate = exploration_rate
|
386 |
+
|
387 |
+
|
388 |
+
class DDQNAgent(DQNAgent):
|
389 |
+
@torch.no_grad()
|
390 |
+
def td_target(self, rewards, next_states, dones):
|
391 |
+
print("Double dqn -----------------------")
|
392 |
+
rewards = rewards.reshape(-1, 1)
|
393 |
+
dones = dones.reshape(-1, 1)
|
394 |
+
q_vals = self.net(next_states, model='online')
|
395 |
+
target_actions = torch.argmax(q_vals, axis=1)
|
396 |
+
target_actions = target_actions.reshape(-1, 1)
|
397 |
+
target_qs = self.net(next_states, model='target').gather(target_actions, 1)
|
398 |
+
target_qs = target_qs.reshape(-1, 1)
|
399 |
+
target_qs[dones] = 0.0
|
400 |
+
return (rewards + (self.gamma * target_qs))
|
src/airstriker-genesis/replay.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import datetime
|
2 |
+
from pathlib import Path
|
3 |
+
from itertools import count
|
4 |
+
from agent import DQNAgent, MetricLogger
|
5 |
+
from wrappers import make_env, make_starpilot
|
6 |
+
|
7 |
+
|
8 |
+
env = make_starpilot()
|
9 |
+
|
10 |
+
env.reset()
|
11 |
+
|
12 |
+
save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
13 |
+
save_dir.mkdir(parents=True)
|
14 |
+
|
15 |
+
checkpoint = Path('checkpoints/procgen-starpilot-dqn/airstriker_net_3.chkpt')
|
16 |
+
|
17 |
+
agent = DQNAgent(
|
18 |
+
state_dim=(1, 64, 64),
|
19 |
+
action_dim=env.action_space.n,
|
20 |
+
save_dir=save_dir,
|
21 |
+
batch_size=256,
|
22 |
+
checkpoint=checkpoint,
|
23 |
+
reset_exploration_rate=True,
|
24 |
+
exploration_rate_decay=0.999999,
|
25 |
+
training_frequency=10,
|
26 |
+
target_network_sync_frequency=200,
|
27 |
+
max_memory_size=3000,
|
28 |
+
learning_rate=0.001,
|
29 |
+
save_frequency=2000
|
30 |
+
|
31 |
+
)
|
32 |
+
agent.exploration_rate = agent.exploration_rate_min
|
33 |
+
|
34 |
+
# logger = MetricLogger(save_dir)
|
35 |
+
|
36 |
+
episodes = 100
|
37 |
+
|
38 |
+
for e in range(episodes):
|
39 |
+
|
40 |
+
state = env.reset()
|
41 |
+
|
42 |
+
while True:
|
43 |
+
|
44 |
+
env.render()
|
45 |
+
|
46 |
+
action = agent.act(state)
|
47 |
+
|
48 |
+
next_state, reward, done, info = env.step(action)
|
49 |
+
|
50 |
+
agent.cache(state, next_state, action, reward, done)
|
51 |
+
|
52 |
+
# logger.log_step(reward, None, None)
|
53 |
+
|
54 |
+
state = next_state
|
55 |
+
|
56 |
+
if done:
|
57 |
+
break
|
58 |
+
|
59 |
+
# logger.log_episode()
|
60 |
+
|
61 |
+
# if e % 20 == 0:
|
62 |
+
# logger.record(
|
63 |
+
# episode=e,
|
64 |
+
# epsilon=agent.exploration_rate,
|
65 |
+
# step=agent.curr_step
|
66 |
+
# )
|
src/airstriker-genesis/run-airstriker-ddqn.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import matplotlib
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
from pathlib import Path
|
7 |
+
from tqdm import trange
|
8 |
+
from agent import DQNAgent, DDQNAgent, MetricLogger
|
9 |
+
from wrappers import make_env
|
10 |
+
|
11 |
+
|
12 |
+
# set up matplotlib
|
13 |
+
is_ipython = 'inline' in matplotlib.get_backend()
|
14 |
+
if is_ipython:
|
15 |
+
from IPython import display
|
16 |
+
|
17 |
+
plt.ion()
|
18 |
+
|
19 |
+
|
20 |
+
env = make_env()
|
21 |
+
|
22 |
+
use_cuda = torch.cuda.is_available()
|
23 |
+
print(f"Using CUDA: {use_cuda}\n")
|
24 |
+
|
25 |
+
|
26 |
+
checkpoint = None
|
27 |
+
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
|
28 |
+
|
29 |
+
path = "checkpoints/airstriker-ddqn"
|
30 |
+
save_dir = Path(path)
|
31 |
+
|
32 |
+
isExist = os.path.exists(path)
|
33 |
+
if not isExist:
|
34 |
+
os.makedirs(path)
|
35 |
+
|
36 |
+
# Vanilla DQN
|
37 |
+
print("Training Vanilla DQN Agent!")
|
38 |
+
# agent = DQNAgent(
|
39 |
+
# state_dim=(1, 84, 84),
|
40 |
+
# action_dim=env.action_space.n,
|
41 |
+
# save_dir=save_dir,
|
42 |
+
# batch_size=128,
|
43 |
+
# checkpoint=checkpoint,
|
44 |
+
# exploration_rate_decay=0.995,
|
45 |
+
# exploration_rate_min=0.05,
|
46 |
+
# training_frequency=1,
|
47 |
+
# target_network_sync_frequency=500,
|
48 |
+
# max_memory_size=50000,
|
49 |
+
# learning_rate=0.0005,
|
50 |
+
|
51 |
+
# )
|
52 |
+
|
53 |
+
# Double DQN
|
54 |
+
print("Training DDQN Agent!")
|
55 |
+
agent = DDQNAgent(
|
56 |
+
state_dim=(1, 84, 84),
|
57 |
+
action_dim=env.action_space.n,
|
58 |
+
save_dir=save_dir,
|
59 |
+
batch_size=128,
|
60 |
+
checkpoint=checkpoint,
|
61 |
+
exploration_rate_decay=0.995,
|
62 |
+
exploration_rate_min=0.05,
|
63 |
+
training_frequency=1,
|
64 |
+
target_network_sync_frequency=500,
|
65 |
+
max_memory_size=50000,
|
66 |
+
learning_rate=0.0005,
|
67 |
+
)
|
68 |
+
|
69 |
+
logger = MetricLogger(save_dir)
|
70 |
+
|
71 |
+
def fill_memory(agent: DQNAgent, num_episodes=1000):
|
72 |
+
print("Filling up memory....")
|
73 |
+
for _ in trange(num_episodes):
|
74 |
+
state = env.reset()
|
75 |
+
done = False
|
76 |
+
while not done:
|
77 |
+
action = agent.act(state)
|
78 |
+
next_state, reward, done, _ = env.step(action)
|
79 |
+
agent.cache(state, next_state, action, reward, done)
|
80 |
+
state = next_state
|
81 |
+
|
82 |
+
|
83 |
+
def train(agent: DQNAgent):
|
84 |
+
episodes = 10000000
|
85 |
+
for e in range(episodes):
|
86 |
+
|
87 |
+
state = env.reset()
|
88 |
+
# Play the game!
|
89 |
+
while True:
|
90 |
+
|
91 |
+
# print(state.shape)
|
92 |
+
# Run agent on the state
|
93 |
+
action = agent.act(state)
|
94 |
+
|
95 |
+
# Agent performs action
|
96 |
+
next_state, reward, done, info = env.step(action)
|
97 |
+
|
98 |
+
# Remember
|
99 |
+
agent.cache(state, next_state, action, reward, done)
|
100 |
+
|
101 |
+
# Learn
|
102 |
+
q, loss = agent.learn()
|
103 |
+
|
104 |
+
# Logging
|
105 |
+
logger.log_step(reward, loss, q)
|
106 |
+
|
107 |
+
# Update state
|
108 |
+
state = next_state
|
109 |
+
|
110 |
+
# Check if end of game
|
111 |
+
if done or info["gameover"] == 1:
|
112 |
+
break
|
113 |
+
|
114 |
+
logger.log_episode(e)
|
115 |
+
|
116 |
+
if e % 20 == 0:
|
117 |
+
logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
|
118 |
+
|
119 |
+
fill_memory(agent)
|
120 |
+
train(agent)
|
src/airstriker-genesis/run-airstriker-dqn.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import matplotlib
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
|
6 |
+
from pathlib import Path
|
7 |
+
from tqdm import trange
|
8 |
+
from agent import DQNAgent, DDQNAgent, MetricLogger
|
9 |
+
from wrappers import make_env
|
10 |
+
|
11 |
+
|
12 |
+
# set up matplotlib
|
13 |
+
is_ipython = 'inline' in matplotlib.get_backend()
|
14 |
+
if is_ipython:
|
15 |
+
from IPython import display
|
16 |
+
|
17 |
+
plt.ion()
|
18 |
+
|
19 |
+
|
20 |
+
env = make_env()
|
21 |
+
|
22 |
+
use_cuda = torch.cuda.is_available()
|
23 |
+
print(f"Using CUDA: {use_cuda}\n")
|
24 |
+
|
25 |
+
|
26 |
+
checkpoint = None
|
27 |
+
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
|
28 |
+
|
29 |
+
path = "checkpoints/airstriker-dqn-new"
|
30 |
+
save_dir = Path(path)
|
31 |
+
|
32 |
+
isExist = os.path.exists(path)
|
33 |
+
if not isExist:
|
34 |
+
os.makedirs(path)
|
35 |
+
|
36 |
+
# Vanilla DQN
|
37 |
+
print("Training Vanilla DQN Agent!")
|
38 |
+
agent = DQNAgent(
|
39 |
+
state_dim=(1, 84, 84),
|
40 |
+
action_dim=env.action_space.n,
|
41 |
+
save_dir=save_dir,
|
42 |
+
batch_size=128,
|
43 |
+
checkpoint=checkpoint,
|
44 |
+
exploration_rate_decay=0.995,
|
45 |
+
exploration_rate_min=0.05,
|
46 |
+
training_frequency=1,
|
47 |
+
target_network_sync_frequency=500,
|
48 |
+
max_memory_size=50000,
|
49 |
+
learning_rate=0.0005,
|
50 |
+
|
51 |
+
)
|
52 |
+
|
53 |
+
# Double DQN
|
54 |
+
# print("Training DDQN Agent!")
|
55 |
+
# agent = DDQNAgent(
|
56 |
+
# state_dim=(1, 84, 84),
|
57 |
+
# action_dim=env.action_space.n,
|
58 |
+
# save_dir=save_dir,
|
59 |
+
# checkpoint=checkpoint,
|
60 |
+
# reset_exploration_rate=True,
|
61 |
+
# max_memory_size=max_memory_size
|
62 |
+
# )
|
63 |
+
|
64 |
+
logger = MetricLogger(save_dir)
|
65 |
+
|
66 |
+
def fill_memory(agent: DQNAgent, num_episodes=1000):
|
67 |
+
print("Filling up memory....")
|
68 |
+
for _ in trange(num_episodes):
|
69 |
+
state = env.reset()
|
70 |
+
done = False
|
71 |
+
while not done:
|
72 |
+
action = agent.act(state)
|
73 |
+
next_state, reward, done, _ = env.step(action)
|
74 |
+
agent.cache(state, next_state, action, reward, done)
|
75 |
+
state = next_state
|
76 |
+
|
77 |
+
|
78 |
+
def train(agent: DQNAgent):
|
79 |
+
episodes = 10000000
|
80 |
+
for e in range(episodes):
|
81 |
+
|
82 |
+
state = env.reset()
|
83 |
+
# Play the game!
|
84 |
+
while True:
|
85 |
+
|
86 |
+
# print(state.shape)
|
87 |
+
# Run agent on the state
|
88 |
+
action = agent.act(state)
|
89 |
+
|
90 |
+
# Agent performs action
|
91 |
+
next_state, reward, done, info = env.step(action)
|
92 |
+
|
93 |
+
# Remember
|
94 |
+
agent.cache(state, next_state, action, reward, done)
|
95 |
+
|
96 |
+
# Learn
|
97 |
+
q, loss = agent.learn()
|
98 |
+
|
99 |
+
# Logging
|
100 |
+
logger.log_step(reward, loss, q)
|
101 |
+
|
102 |
+
# Update state
|
103 |
+
state = next_state
|
104 |
+
|
105 |
+
# Check if end of game
|
106 |
+
if done or info["gameover"] == 1:
|
107 |
+
break
|
108 |
+
|
109 |
+
logger.log_episode(e)
|
110 |
+
|
111 |
+
if e % 20 == 0:
|
112 |
+
logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
|
113 |
+
|
114 |
+
fill_memory(agent)
|
115 |
+
train(agent)
|
src/airstriker-genesis/run-cartpole.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random, datetime
|
3 |
+
from pathlib import Path
|
4 |
+
import retro as gym
|
5 |
+
from collections import namedtuple, deque
|
6 |
+
from itertools import count
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import matplotlib
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
# from agent import MyAgent, MyDQN, MetricLogger
|
12 |
+
from cartpole import MyAgent, MetricLogger
|
13 |
+
from wrappers import make_env
|
14 |
+
import pickle
|
15 |
+
import gym
|
16 |
+
from tqdm import trange
|
17 |
+
|
18 |
+
# set up matplotlib
|
19 |
+
is_ipython = 'inline' in matplotlib.get_backend()
|
20 |
+
if is_ipython:
|
21 |
+
from IPython import display
|
22 |
+
|
23 |
+
plt.ion()
|
24 |
+
|
25 |
+
|
26 |
+
# env = make_env()
|
27 |
+
env = gym.make('CartPole-v1')
|
28 |
+
|
29 |
+
use_cuda = torch.cuda.is_available()
|
30 |
+
print(f"Using CUDA: {use_cuda}")
|
31 |
+
print()
|
32 |
+
|
33 |
+
path = "checkpoints/cartpole/latest"
|
34 |
+
save_dir = Path(path)
|
35 |
+
|
36 |
+
isExist = os.path.exists(path)
|
37 |
+
if not isExist:
|
38 |
+
os.makedirs(path)
|
39 |
+
|
40 |
+
# save_dir.mkdir(parents=True)
|
41 |
+
|
42 |
+
|
43 |
+
checkpoint = None
|
44 |
+
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')
|
45 |
+
|
46 |
+
# For cartpole
|
47 |
+
n_actions = env.action_space.n
|
48 |
+
state = env.reset()
|
49 |
+
n_observations = len(state)
|
50 |
+
max_memory_size=100000
|
51 |
+
agent = MyAgent(
|
52 |
+
state_dim=n_observations,
|
53 |
+
action_dim=n_actions,
|
54 |
+
save_dir=save_dir,
|
55 |
+
checkpoint=checkpoint,
|
56 |
+
reset_exploration_rate=True,
|
57 |
+
max_memory_size=max_memory_size
|
58 |
+
)
|
59 |
+
|
60 |
+
# For airstriker
|
61 |
+
# agent = MyAgent(state_dim=(1, 84, 84), action_dim=env.action_space.n, save_dir=save_dir, checkpoint=checkpoint, reset_exploration_rate=True)
|
62 |
+
|
63 |
+
|
64 |
+
logger = MetricLogger(save_dir)
|
65 |
+
|
66 |
+
|
67 |
+
|
68 |
+
def fill_memory(agent: MyAgent):
|
69 |
+
print("Filling up memory....")
|
70 |
+
for _ in trange(max_memory_size):
|
71 |
+
state = env.reset()
|
72 |
+
done = False
|
73 |
+
while not done:
|
74 |
+
action = agent.act(state)
|
75 |
+
next_state, reward, done, info = env.step(action)
|
76 |
+
agent.cache(state, next_state, action, reward, done)
|
77 |
+
state = next_state
|
78 |
+
|
79 |
+
def train(agent: MyAgent):
|
80 |
+
episodes = 10000000
|
81 |
+
for e in range(episodes):
|
82 |
+
|
83 |
+
state = env.reset()
|
84 |
+
# Play the game!
|
85 |
+
while True:
|
86 |
+
|
87 |
+
# print(state.shape)
|
88 |
+
# Run agent on the state
|
89 |
+
action = agent.act(state)
|
90 |
+
|
91 |
+
# Agent performs action
|
92 |
+
next_state, reward, done, info = env.step(action)
|
93 |
+
|
94 |
+
# Remember
|
95 |
+
agent.cache(state, next_state, action, reward, done)
|
96 |
+
|
97 |
+
# Learn
|
98 |
+
q, loss = agent.learn()
|
99 |
+
|
100 |
+
# Logging
|
101 |
+
logger.log_step(reward, loss, q)
|
102 |
+
|
103 |
+
# Update state
|
104 |
+
state = next_state
|
105 |
+
|
106 |
+
# # Check if end of game (for airstriker)
|
107 |
+
# if done or info["gameover"] == 1:
|
108 |
+
# break
|
109 |
+
# Check if end of game (for cartpole)
|
110 |
+
if done:
|
111 |
+
break
|
112 |
+
|
113 |
+
logger.log_episode(e)
|
114 |
+
|
115 |
+
if e % 20 == 0:
|
116 |
+
logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
|
117 |
+
|
118 |
+
|
119 |
+
fill_memory(agent)
|
120 |
+
train(agent)
|
src/airstriker-genesis/test.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import retro
|
2 |
+
import gym
|
3 |
+
import math
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
import matplotlib
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from collections import namedtuple, deque
|
9 |
+
from itertools import count
|
10 |
+
from gym import spaces
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.optim as optim
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import cv2
|
17 |
+
import torch
|
18 |
+
from torch.utils.tensorboard import SummaryWriter
|
19 |
+
|
20 |
+
|
21 |
+
class MaxAndSkipEnv(gym.Wrapper):
|
22 |
+
def __init__(self, env, skip=4):
|
23 |
+
"""Return only every `skip`-th frame"""
|
24 |
+
gym.Wrapper.__init__(self, env)
|
25 |
+
# most recent raw observations (for max pooling across time steps)
|
26 |
+
self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
|
27 |
+
self._skip = skip
|
28 |
+
|
29 |
+
def step(self, action):
|
30 |
+
"""Repeat action, sum reward, and max over last observations."""
|
31 |
+
total_reward = 0.0
|
32 |
+
done = None
|
33 |
+
for i in range(self._skip):
|
34 |
+
obs, reward, done, info = self.env.step(action)
|
35 |
+
if i == self._skip - 2: self._obs_buffer[0] = obs
|
36 |
+
if i == self._skip - 1: self._obs_buffer[1] = obs
|
37 |
+
total_reward += reward
|
38 |
+
if done:
|
39 |
+
break
|
40 |
+
# Note that the observation on the done=True frame
|
41 |
+
# doesn't matter
|
42 |
+
max_frame = self._obs_buffer.max(axis=0)
|
43 |
+
|
44 |
+
return max_frame, total_reward, done, info
|
45 |
+
|
46 |
+
def reset(self, **kwargs):
|
47 |
+
return self.env.reset(**kwargs)
|
48 |
+
|
49 |
+
|
50 |
+
class LazyFrames(object):
|
51 |
+
def __init__(self, frames):
|
52 |
+
"""This object ensures that common frames between the observations are only stored once.
|
53 |
+
It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
|
54 |
+
buffers.
|
55 |
+
This object should only be converted to numpy array before being passed to the model.
|
56 |
+
You'd not believe how complex the previous solution was."""
|
57 |
+
self._frames = frames
|
58 |
+
self._out = None
|
59 |
+
|
60 |
+
def _force(self):
|
61 |
+
if self._out is None:
|
62 |
+
self._out = np.concatenate(self._frames, axis=2)
|
63 |
+
self._frames = None
|
64 |
+
return self._out
|
65 |
+
|
66 |
+
def __array__(self, dtype=None):
|
67 |
+
out = self._force()
|
68 |
+
if dtype is not None:
|
69 |
+
out = out.astype(dtype)
|
70 |
+
return out
|
71 |
+
|
72 |
+
def __len__(self):
|
73 |
+
return len(self._force())
|
74 |
+
|
75 |
+
def __getitem__(self, i):
|
76 |
+
return self._force()[i]
|
77 |
+
|
78 |
+
|
79 |
+
class FrameStack(gym.Wrapper):
|
80 |
+
def __init__(self, env, k):
|
81 |
+
"""Stack k last frames.
|
82 |
+
Returns lazy array, which is much more memory efficient.
|
83 |
+
See Also
|
84 |
+
--------
|
85 |
+
baselines.common.atari_wrappers.LazyFrames
|
86 |
+
"""
|
87 |
+
gym.Wrapper.__init__(self, env)
|
88 |
+
self.k = k
|
89 |
+
self.frames = deque([], maxlen=k)
|
90 |
+
shp = env.observation_space.shape
|
91 |
+
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
|
92 |
+
|
93 |
+
def reset(self):
|
94 |
+
ob = self.env.reset()
|
95 |
+
for _ in range(self.k):
|
96 |
+
self.frames.append(ob)
|
97 |
+
return self._get_ob()
|
98 |
+
|
99 |
+
def step(self, action):
|
100 |
+
ob, reward, done, info = self.env.step(action)
|
101 |
+
self.frames.append(ob)
|
102 |
+
return self._get_ob(), reward, done, info
|
103 |
+
|
104 |
+
def _get_ob(self):
|
105 |
+
assert len(self.frames) == self.k
|
106 |
+
return LazyFrames(list(self.frames))
|
107 |
+
|
108 |
+
class ClipRewardEnv(gym.RewardWrapper):
|
109 |
+
def __init__(self, env):
|
110 |
+
gym.RewardWrapper.__init__(self, env)
|
111 |
+
|
112 |
+
def reward(self, reward):
|
113 |
+
"""Bin reward to {+1, 0, -1} by its sign."""
|
114 |
+
return np.sign(reward)
|
115 |
+
|
116 |
+
|
117 |
+
class ImageToPyTorch(gym.ObservationWrapper):
|
118 |
+
def __init__(self, env):
|
119 |
+
super(ImageToPyTorch, self).__init__(env)
|
120 |
+
old_shape = self.observation_space.shape
|
121 |
+
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)
|
122 |
+
|
123 |
+
def observation(self, observation):
|
124 |
+
return np.moveaxis(observation, 2, 0)
|
125 |
+
|
126 |
+
|
127 |
+
class WarpFrame(gym.ObservationWrapper):
|
128 |
+
def __init__(self, env):
|
129 |
+
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
130 |
+
gym.ObservationWrapper.__init__(self, env)
|
131 |
+
self.width = 84
|
132 |
+
self.height = 84
|
133 |
+
self.observation_space = spaces.Box(low=0, high=255,
|
134 |
+
shape=(self.height, self.width, 1), dtype=np.uint8)
|
135 |
+
|
136 |
+
def observation(self, frame):
|
137 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
138 |
+
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
139 |
+
return frame[:, :, None]
|
140 |
+
|
141 |
+
class AirstrikerDiscretizer(gym.ActionWrapper):
|
142 |
+
# 初期化
|
143 |
+
def __init__(self, env):
|
144 |
+
super(AirstrikerDiscretizer, self).__init__(env)
|
145 |
+
buttons = ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y', 'X', 'Z']
|
146 |
+
actions = [['LEFT'], ['RIGHT'], ['B']]
|
147 |
+
self._actions = []
|
148 |
+
for action in actions:
|
149 |
+
arr = np.array([False] * 12)
|
150 |
+
for button in action:
|
151 |
+
arr[buttons.index(button)] = True
|
152 |
+
self._actions.append(arr)
|
153 |
+
self.action_space = gym.spaces.Discrete(len(self._actions))
|
154 |
+
|
155 |
+
# 行動の取得
|
156 |
+
def action(self, a):
|
157 |
+
return self._actions[a].copy()
|
158 |
+
|
159 |
+
|
160 |
+
env = retro.make(game='Airstriker-Genesis')
|
161 |
+
env = MaxAndSkipEnv(env) ## Return only every `skip`-th frame
|
162 |
+
env = WarpFrame(env) ## Reshape image
|
163 |
+
env = ImageToPyTorch(env) ## Invert shape
|
164 |
+
env = FrameStack(env, 4) ## Stack last 4 frames
|
165 |
+
# env = ScaledFloatFrame(env) ## Scale frames
|
166 |
+
env = AirstrikerDiscretizer(env)
|
167 |
+
env = ClipRewardEnv(env)
|
168 |
+
|
169 |
+
# set up matplotlib
|
170 |
+
is_ipython = 'inline' in matplotlib.get_backend()
|
171 |
+
if is_ipython:
|
172 |
+
from IPython import display
|
173 |
+
|
174 |
+
plt.ion()
|
175 |
+
|
176 |
+
# if gpu is to be used
|
177 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
178 |
+
|
179 |
+
Transition = namedtuple('Transition',
|
180 |
+
('state', 'action', 'next_state', 'reward'))
|
181 |
+
|
182 |
+
|
183 |
+
class ReplayMemory(object):
|
184 |
+
|
185 |
+
def __init__(self, capacity):
|
186 |
+
self.memory = deque([],maxlen=capacity)
|
187 |
+
|
188 |
+
def push(self, *args):
|
189 |
+
"""Save a transition"""
|
190 |
+
self.memory.append(Transition(*args))
|
191 |
+
|
192 |
+
def sample(self, batch_size):
|
193 |
+
return random.sample(self.memory, batch_size)
|
194 |
+
|
195 |
+
def __len__(self):
|
196 |
+
return len(self.memory)
|
197 |
+
|
198 |
+
|
199 |
+
class DQN(nn.Module):
|
200 |
+
|
201 |
+
def __init__(self, n_observations, n_actions):
|
202 |
+
super(DQN, self).__init__()
|
203 |
+
# self.layer1 = nn.Linear(n_observations, 128)
|
204 |
+
# self.layer2 = nn.Linear(128, 128)
|
205 |
+
# self.layer3 = nn.Linear(128, n_actions)
|
206 |
+
|
207 |
+
self.layer1 = nn.Conv2d(in_channels=n_observations, out_channels=32, kernel_size=8, stride=4)
|
208 |
+
self.layer2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2)
|
209 |
+
self.layer3 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1), nn.ReLU(), nn.Flatten())
|
210 |
+
self.layer4 = nn.Linear(17024, 512)
|
211 |
+
self.layer5 = nn.Linear(512, n_actions)
|
212 |
+
|
213 |
+
# Called with either one element to determine next action, or a batch
|
214 |
+
# during optimization. Returns tensor([[left0exp,right0exp]...]).
|
215 |
+
def forward(self, x):
|
216 |
+
x = F.relu(self.layer1(x))
|
217 |
+
x = F.relu(self.layer2(x))
|
218 |
+
x = F.relu(self.layer3(x))
|
219 |
+
x = F.relu(self.layer4(x))
|
220 |
+
return self.layer5(x)
|
221 |
+
|
222 |
+
|
223 |
+
# BATCH_SIZE is the number of transitions sampled from the replay buffer
|
224 |
+
# GAMMA is the discount factor as mentioned in the previous section
|
225 |
+
# EPS_START is the starting value of epsilon
|
226 |
+
# EPS_END is the final value of epsilon
|
227 |
+
# EPS_DECAY controls the rate of exponential decay of epsilon, higher means a slower decay
|
228 |
+
# TAU is the update rate of the target network
|
229 |
+
# LR is the learning rate of the AdamW optimizer
|
230 |
+
BATCH_SIZE = 512
|
231 |
+
GAMMA = 0.99
|
232 |
+
EPS_START = 1
|
233 |
+
EPS_END = 0.01
|
234 |
+
EPS_DECAY = 10000
|
235 |
+
TAU = 0.005
|
236 |
+
# LR = 1e-4
|
237 |
+
LR = 0.00025
|
238 |
+
|
239 |
+
# Get number of actions from gym action space
|
240 |
+
n_actions = env.action_space.n
|
241 |
+
state = env.reset()
|
242 |
+
n_observations = len(state)
|
243 |
+
|
244 |
+
policy_net = DQN(n_observations, n_actions).to(device)
|
245 |
+
target_net = DQN(n_observations, n_actions).to(device)
|
246 |
+
target_net.load_state_dict(policy_net.state_dict())
|
247 |
+
|
248 |
+
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
|
249 |
+
memory = ReplayMemory(10000)
|
250 |
+
|
251 |
+
|
252 |
+
steps_done = 0
|
253 |
+
|
254 |
+
|
255 |
+
def select_action(state):
|
256 |
+
global steps_done
|
257 |
+
sample = random.random()
|
258 |
+
eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
|
259 |
+
steps_done += 1
|
260 |
+
if sample > eps_threshold:
|
261 |
+
with torch.no_grad():
|
262 |
+
# t.max(1) will return largest column value of each row.
|
263 |
+
# second column on max result is index of where max element was
|
264 |
+
# found, so we pick action with the larger expected reward.
|
265 |
+
return policy_net(state).max(1)[1].view(1, 1), eps_threshold
|
266 |
+
else:
|
267 |
+
return torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long), eps_threshold
|
268 |
+
|
269 |
+
|
270 |
+
episode_durations = []
|
271 |
+
|
272 |
+
|
273 |
+
def plot_durations(show_result=False):
|
274 |
+
plt.figure(1)
|
275 |
+
durations_t = torch.tensor(episode_durations, dtype=torch.float)
|
276 |
+
if show_result:
|
277 |
+
plt.title('Result')
|
278 |
+
else:
|
279 |
+
plt.clf()
|
280 |
+
plt.title('Training...')
|
281 |
+
plt.xlabel('Episode')
|
282 |
+
plt.ylabel('Duration')
|
283 |
+
plt.plot(durations_t.numpy())
|
284 |
+
# Take 100 episode averages and plot them too
|
285 |
+
if len(durations_t) >= 100:
|
286 |
+
means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
|
287 |
+
means = torch.cat((torch.zeros(99), means))
|
288 |
+
plt.plot(means.numpy())
|
289 |
+
|
290 |
+
plt.pause(0.001) # pause a bit so that plots are updated
|
291 |
+
if is_ipython:
|
292 |
+
if not show_result:
|
293 |
+
display.display(plt.gcf())
|
294 |
+
display.clear_output(wait=True)
|
295 |
+
else:
|
296 |
+
display.display(plt.gcf())
|
297 |
+
|
298 |
+
|
299 |
+
|
300 |
+
def optimize_model():
|
301 |
+
if len(memory) < BATCH_SIZE:
|
302 |
+
return
|
303 |
+
transitions = memory.sample(BATCH_SIZE)
|
304 |
+
# Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
|
305 |
+
# detailed explanation). This converts batch-array of Transitions
|
306 |
+
# to Transition of batch-arrays.
|
307 |
+
batch = Transition(*zip(*transitions))
|
308 |
+
|
309 |
+
# Compute a mask of non-final states and concatenate the batch elements
|
310 |
+
# (a final state would've been the one after which simulation ended)
|
311 |
+
non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
|
312 |
+
batch.next_state)), device=device, dtype=torch.bool)
|
313 |
+
non_final_next_states = torch.cat([s for s in batch.next_state
|
314 |
+
if s is not None])
|
315 |
+
state_batch = torch.cat(batch.state)
|
316 |
+
action_batch = torch.cat(batch.action)
|
317 |
+
reward_batch = torch.cat(batch.reward)
|
318 |
+
|
319 |
+
# Compute Q(s_t, a) - the model computes Q(s_t), then we select the
|
320 |
+
# columns of actions taken. These are the actions which would've been taken
|
321 |
+
# for each batch state according to policy_net
|
322 |
+
state_action_values = policy_net(state_batch).gather(1, action_batch)
|
323 |
+
|
324 |
+
# Compute V(s_{t+1}) for all next states.
|
325 |
+
# Expected values of actions for non_final_next_states are computed based
|
326 |
+
# on the "older" target_net; selecting their best reward with max(1)[0].
|
327 |
+
# This is merged based on the mask, such that we'll have either the expected
|
328 |
+
# state value or 0 in case the state was final.
|
329 |
+
next_state_values = torch.zeros(BATCH_SIZE, device=device)
|
330 |
+
with torch.no_grad():
|
331 |
+
next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
|
332 |
+
# Compute the expected Q values
|
333 |
+
expected_state_action_values = (next_state_values * GAMMA) + reward_batch
|
334 |
+
|
335 |
+
# Compute Huber loss
|
336 |
+
criterion = nn.SmoothL1Loss()
|
337 |
+
loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
|
338 |
+
|
339 |
+
# Optimize the model
|
340 |
+
optimizer.zero_grad()
|
341 |
+
loss.backward()
|
342 |
+
# In-place gradient clipping
|
343 |
+
torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
|
344 |
+
optimizer.step()
|
345 |
+
|
346 |
+
|
347 |
+
with SummaryWriter() as writer:
|
348 |
+
if torch.cuda.is_available():
|
349 |
+
num_episodes = 600
|
350 |
+
else:
|
351 |
+
num_episodes = 50
|
352 |
+
epsilon = 1
|
353 |
+
episode_rewards = []
|
354 |
+
for i_episode in range(num_episodes):
|
355 |
+
|
356 |
+
# Initialize the environment and get it's state
|
357 |
+
state = env.reset()
|
358 |
+
state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
|
359 |
+
episode_reward = 0
|
360 |
+
for t in count():
|
361 |
+
action, epsilon = select_action(state)
|
362 |
+
observation, reward, done, info = env.step(action.item())
|
363 |
+
reward = torch.tensor([reward], device=device)
|
364 |
+
|
365 |
+
done = done or info["gameover"] == 1
|
366 |
+
if done:
|
367 |
+
episode_durations.append(t + 1)
|
368 |
+
print(f"Episode {i_episode} done")
|
369 |
+
# plot_durations()
|
370 |
+
break
|
371 |
+
# if done:
|
372 |
+
# next_state = None
|
373 |
+
# else:
|
374 |
+
# next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
|
375 |
+
|
376 |
+
next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)
|
377 |
+
|
378 |
+
# Store the transition in memory
|
379 |
+
memory.push(state, action, next_state, reward)
|
380 |
+
episode_reward += reward
|
381 |
+
# Move to the next state
|
382 |
+
state = next_state
|
383 |
+
|
384 |
+
# Perform one step of the optimization (on the policy network)
|
385 |
+
optimize_model()
|
386 |
+
|
387 |
+
# Soft update of the target network's weights
|
388 |
+
# θ′ ← τ θ + (1 −τ )θ′
|
389 |
+
target_net_state_dict = target_net.state_dict()
|
390 |
+
policy_net_state_dict = policy_net.state_dict()
|
391 |
+
for key in policy_net_state_dict:
|
392 |
+
target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
|
393 |
+
target_net.load_state_dict(target_net_state_dict)
|
394 |
+
# if done:
|
395 |
+
# episode_durations.append(t + 1)
|
396 |
+
# # plot_durations()
|
397 |
+
# break
|
398 |
+
# episode_rewards.append(episode_reward)
|
399 |
+
writer.add_scalar("Rewards/Episode", episode_reward, i_episode)
|
400 |
+
writer.add_scalar("Epsilon", epsilon, i_episode)
|
401 |
+
writer.flush()
|
402 |
+
print('Complete')
|
403 |
+
plot_durations(show_result=True)
|
404 |
+
plt.ioff()
|
405 |
+
plt.show()
|
src/airstriker-genesis/utils.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gym
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
|
5 |
+
# Airstrikerラッパー
|
6 |
+
class AirstrikerDiscretizer(gym.ActionWrapper):
|
7 |
+
# 初期化
|
8 |
+
def __init__(self, env):
|
9 |
+
super(AirstrikerDiscretizer, self).__init__(env)
|
10 |
+
buttons = ['B', 'A', 'MODE', 'START', 'UP', 'DOWN', 'LEFT', 'RIGHT', 'C', 'Y', 'X', 'Z']
|
11 |
+
actions = [['LEFT'], ['RIGHT'], ['B']]
|
12 |
+
self._actions = []
|
13 |
+
for action in actions:
|
14 |
+
arr = np.array([False] * 12)
|
15 |
+
for button in action:
|
16 |
+
arr[buttons.index(button)] = True
|
17 |
+
self._actions.append(arr)
|
18 |
+
self.action_space = gym.spaces.Discrete(len(self._actions))
|
19 |
+
|
20 |
+
# 行動の取得
|
21 |
+
def action(self, a):
|
22 |
+
return self._actions[a].copy()
|
src/airstriker-genesis/wrappers.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import os
|
3 |
+
from collections import deque
|
4 |
+
import gym
|
5 |
+
from gym import spaces
|
6 |
+
import cv2
|
7 |
+
import retro
|
8 |
+
from utils import AirstrikerDiscretizer
|
9 |
+
|
10 |
+
|
11 |
+
'''
|
12 |
+
Atari Wrapper copied from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
|
13 |
+
'''
|
14 |
+
|
15 |
+
|
16 |
+
class LazyFrames(object):
|
17 |
+
def __init__(self, frames):
|
18 |
+
"""This object ensures that common frames between the observations are only stored once.
|
19 |
+
It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
|
20 |
+
buffers.
|
21 |
+
This object should only be converted to numpy array before being passed to the model.
|
22 |
+
You'd not believe how complex the previous solution was."""
|
23 |
+
self._frames = frames
|
24 |
+
self._out = None
|
25 |
+
|
26 |
+
def _force(self):
|
27 |
+
if self._out is None:
|
28 |
+
self._out = np.concatenate(self._frames, axis=2)
|
29 |
+
self._frames = None
|
30 |
+
return self._out
|
31 |
+
|
32 |
+
def __array__(self, dtype=None):
|
33 |
+
out = self._force()
|
34 |
+
if dtype is not None:
|
35 |
+
out = out.astype(dtype)
|
36 |
+
return out
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self._force())
|
40 |
+
|
41 |
+
def __getitem__(self, i):
|
42 |
+
return self._force()[i]
|
43 |
+
|
44 |
+
class FireResetEnv(gym.Wrapper):
|
45 |
+
def __init__(self, env):
|
46 |
+
"""Take action on reset for environments that are fixed until firing."""
|
47 |
+
gym.Wrapper.__init__(self, env)
|
48 |
+
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
|
49 |
+
assert len(env.unwrapped.get_action_meanings()) >= 3
|
50 |
+
|
51 |
+
def reset(self, **kwargs):
|
52 |
+
self.env.reset(**kwargs)
|
53 |
+
obs, _, done, _ = self.env.step(1)
|
54 |
+
if done:
|
55 |
+
self.env.reset(**kwargs)
|
56 |
+
obs, _, done, _ = self.env.step(2)
|
57 |
+
if done:
|
58 |
+
self.env.reset(**kwargs)
|
59 |
+
return obs
|
60 |
+
|
61 |
+
def step(self, ac):
|
62 |
+
return self.env.step(ac)
|
63 |
+
|
64 |
+
|
65 |
+
class MaxAndSkipEnv(gym.Wrapper):
|
66 |
+
def __init__(self, env, skip=4):
|
67 |
+
"""Return only every `skip`-th frame"""
|
68 |
+
gym.Wrapper.__init__(self, env)
|
69 |
+
# most recent raw observations (for max pooling across time steps)
|
70 |
+
self._obs_buffer = np.zeros((2,)+env.observation_space.shape, dtype=np.uint8)
|
71 |
+
self._skip = skip
|
72 |
+
|
73 |
+
def step(self, action):
|
74 |
+
"""Repeat action, sum reward, and max over last observations."""
|
75 |
+
total_reward = 0.0
|
76 |
+
done = None
|
77 |
+
for i in range(self._skip):
|
78 |
+
obs, reward, done, info = self.env.step(action)
|
79 |
+
if i == self._skip - 2: self._obs_buffer[0] = obs
|
80 |
+
if i == self._skip - 1: self._obs_buffer[1] = obs
|
81 |
+
total_reward += reward
|
82 |
+
if done:
|
83 |
+
break
|
84 |
+
# Note that the observation on the done=True frame
|
85 |
+
# doesn't matter
|
86 |
+
max_frame = self._obs_buffer.max(axis=0)
|
87 |
+
|
88 |
+
return max_frame, total_reward, done, info
|
89 |
+
|
90 |
+
def reset(self, **kwargs):
|
91 |
+
return self.env.reset(**kwargs)
|
92 |
+
|
93 |
+
|
94 |
+
|
95 |
+
class WarpFrame(gym.ObservationWrapper):
|
96 |
+
def __init__(self, env):
|
97 |
+
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
98 |
+
gym.ObservationWrapper.__init__(self, env)
|
99 |
+
self.width = 84
|
100 |
+
self.height = 84
|
101 |
+
self.observation_space = spaces.Box(low=0, high=255,
|
102 |
+
shape=(self.height, self.width, 1), dtype=np.uint8)
|
103 |
+
|
104 |
+
def observation(self, frame):
|
105 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
106 |
+
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
107 |
+
return frame[:, :, None]
|
108 |
+
|
109 |
+
class WarpFrameNoResize(gym.ObservationWrapper):
|
110 |
+
def __init__(self, env):
|
111 |
+
"""Warp frames to 84x84 as done in the Nature paper and later work."""
|
112 |
+
gym.ObservationWrapper.__init__(self, env)
|
113 |
+
|
114 |
+
def observation(self, frame):
|
115 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
116 |
+
# frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
117 |
+
return frame[:, :, None]
|
118 |
+
|
119 |
+
|
120 |
+
|
121 |
+
class FrameStack(gym.Wrapper):
|
122 |
+
def __init__(self, env, k):
|
123 |
+
"""Stack k last frames.
|
124 |
+
Returns lazy array, which is much more memory efficient.
|
125 |
+
See Also
|
126 |
+
--------
|
127 |
+
baselines.common.atari_wrappers.LazyFrames
|
128 |
+
"""
|
129 |
+
gym.Wrapper.__init__(self, env)
|
130 |
+
self.k = k
|
131 |
+
self.frames = deque([], maxlen=k)
|
132 |
+
shp = env.observation_space.shape
|
133 |
+
self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)
|
134 |
+
|
135 |
+
def reset(self):
|
136 |
+
ob = self.env.reset()
|
137 |
+
for _ in range(self.k):
|
138 |
+
self.frames.append(ob)
|
139 |
+
return self._get_ob()
|
140 |
+
|
141 |
+
def step(self, action):
|
142 |
+
ob, reward, done, info = self.env.step(action)
|
143 |
+
self.frames.append(ob)
|
144 |
+
return self._get_ob(), reward, done, info
|
145 |
+
|
146 |
+
def _get_ob(self):
|
147 |
+
assert len(self.frames) == self.k
|
148 |
+
return LazyFrames(list(self.frames))
|
149 |
+
|
150 |
+
|
151 |
+
class ImageToPyTorch(gym.ObservationWrapper):
|
152 |
+
def __init__(self, env):
|
153 |
+
super(ImageToPyTorch, self).__init__(env)
|
154 |
+
old_shape = self.observation_space.shape
|
155 |
+
self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)
|
156 |
+
|
157 |
+
def observation(self, observation):
|
158 |
+
return np.moveaxis(observation, 2, 0)
|
159 |
+
|
160 |
+
|
161 |
+
# class ImageToPyTorch(gym.ObservationWrapper):
|
162 |
+
# def __init__(self, env):
|
163 |
+
# super(ImageToPyTorch, self).__init__(env)
|
164 |
+
# old_shape = self.observation_space.shape
|
165 |
+
# new_shape = (old_shape[-1], old_shape[0], old_shape[1])
|
166 |
+
# print("Old: ", old_shape)
|
167 |
+
# print("New: ", new_shape)
|
168 |
+
# self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=new_shape, dtype=np.float32)
|
169 |
+
|
170 |
+
# def observation(self, observation):
|
171 |
+
# return np.moveaxis(observation, 2, 0)
|
172 |
+
|
173 |
+
|
174 |
+
class ScaledFloatFrame(gym.ObservationWrapper):
|
175 |
+
def __init__(self, env):
|
176 |
+
gym.ObservationWrapper.__init__(self, env)
|
177 |
+
self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)
|
178 |
+
|
179 |
+
def observation(self, observation):
|
180 |
+
# careful! This undoes the memory optimization, use
|
181 |
+
# with smaller replay buffers only.
|
182 |
+
return np.array(observation).astype(np.float32) / 255.0
|
183 |
+
|
184 |
+
class ClipRewardEnv(gym.RewardWrapper):
|
185 |
+
def __init__(self, env):
|
186 |
+
gym.RewardWrapper.__init__(self, env)
|
187 |
+
|
188 |
+
def reward(self, reward):
|
189 |
+
"""Bin reward to {+1, 0, -1} by its sign."""
|
190 |
+
return np.sign(reward)
|
191 |
+
|
192 |
+
|
193 |
+
def make_env():
|
194 |
+
|
195 |
+
env = retro.make(game='Airstriker-Genesis')
|
196 |
+
env = MaxAndSkipEnv(env) ## Return only every `skip`-th frame
|
197 |
+
env = WarpFrame(env) ## Reshape image
|
198 |
+
env = ImageToPyTorch(env) ## Invert shape
|
199 |
+
env = FrameStack(env, 4) ## Stack last 4 frames
|
200 |
+
env = ScaledFloatFrame(env) ## Scale frames
|
201 |
+
env = AirstrikerDiscretizer(env)
|
202 |
+
env = ClipRewardEnv(env)
|
203 |
+
return env
|
204 |
+
|
205 |
+
def make_starpilot(render=False):
|
206 |
+
if render:
|
207 |
+
env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy", render_mode="human")
|
208 |
+
else:
|
209 |
+
env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy")
|
210 |
+
env = WarpFrameNoResize(env) ## Reshape image
|
211 |
+
env = ImageToPyTorch(env) ## Invert shape
|
212 |
+
env = FrameStack(env, 4) ## Stack last 4 frames
|
213 |
+
return env
|
src/lunar-lander/agent.py
ADDED
@@ -0,0 +1,1104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import torch.nn as nn
|
5 |
+
import copy
|
6 |
+
import time, datetime
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from collections import deque
|
9 |
+
from torch.utils.tensorboard import SummaryWriter
|
10 |
+
|
11 |
+
|
12 |
+
class DQNet(nn.Module):
|
13 |
+
"""mini cnn structure"""
|
14 |
+
|
15 |
+
def __init__(self, input_dim, output_dim):
|
16 |
+
super().__init__()
|
17 |
+
|
18 |
+
self.online = nn.Sequential(
|
19 |
+
nn.Linear(input_dim, 150),
|
20 |
+
nn.ReLU(),
|
21 |
+
nn.Linear(150, 120),
|
22 |
+
nn.ReLU(),
|
23 |
+
nn.Linear(120, output_dim),
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
self.target = copy.deepcopy(self.online)
|
28 |
+
|
29 |
+
# Q_target parameters are frozen.
|
30 |
+
for p in self.target.parameters():
|
31 |
+
p.requires_grad = False
|
32 |
+
|
33 |
+
def forward(self, input, model):
|
34 |
+
if model == "online":
|
35 |
+
return self.online(input)
|
36 |
+
elif model == "target":
|
37 |
+
return self.target(input)
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
class MetricLogger:
|
42 |
+
def __init__(self, save_dir):
|
43 |
+
self.writer = SummaryWriter(log_dir=save_dir)
|
44 |
+
self.save_log = save_dir / "log"
|
45 |
+
with open(self.save_log, "w") as f:
|
46 |
+
f.write(
|
47 |
+
f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
|
48 |
+
f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
|
49 |
+
f"{'TimeDelta':>15}{'Time':>20}\n"
|
50 |
+
)
|
51 |
+
self.ep_rewards_plot = save_dir / "reward_plot.jpg"
|
52 |
+
self.ep_lengths_plot = save_dir / "length_plot.jpg"
|
53 |
+
self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
|
54 |
+
self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
|
55 |
+
|
56 |
+
# History metrics
|
57 |
+
self.ep_rewards = []
|
58 |
+
self.ep_lengths = []
|
59 |
+
self.ep_avg_losses = []
|
60 |
+
self.ep_avg_qs = []
|
61 |
+
|
62 |
+
# Moving averages, added for every call to record()
|
63 |
+
self.moving_avg_ep_rewards = []
|
64 |
+
self.moving_avg_ep_lengths = []
|
65 |
+
self.moving_avg_ep_avg_losses = []
|
66 |
+
self.moving_avg_ep_avg_qs = []
|
67 |
+
|
68 |
+
# Current episode metric
|
69 |
+
self.init_episode()
|
70 |
+
|
71 |
+
# Timing
|
72 |
+
self.record_time = time.time()
|
73 |
+
|
74 |
+
def log_step(self, reward, loss, q):
|
75 |
+
self.curr_ep_reward += reward
|
76 |
+
self.curr_ep_length += 1
|
77 |
+
if loss:
|
78 |
+
self.curr_ep_loss += loss
|
79 |
+
self.curr_ep_q += q
|
80 |
+
self.curr_ep_loss_length += 1
|
81 |
+
|
82 |
+
def log_episode(self, episode_number):
|
83 |
+
"Mark end of episode"
|
84 |
+
self.ep_rewards.append(self.curr_ep_reward)
|
85 |
+
self.ep_lengths.append(self.curr_ep_length)
|
86 |
+
if self.curr_ep_loss_length == 0:
|
87 |
+
ep_avg_loss = 0
|
88 |
+
ep_avg_q = 0
|
89 |
+
else:
|
90 |
+
ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
|
91 |
+
ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
|
92 |
+
self.ep_avg_losses.append(ep_avg_loss)
|
93 |
+
self.ep_avg_qs.append(ep_avg_q)
|
94 |
+
self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
|
95 |
+
self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
|
96 |
+
self.writer.flush()
|
97 |
+
self.init_episode()
|
98 |
+
|
99 |
+
def init_episode(self):
|
100 |
+
self.curr_ep_reward = 0.0
|
101 |
+
self.curr_ep_length = 0
|
102 |
+
self.curr_ep_loss = 0.0
|
103 |
+
self.curr_ep_q = 0.0
|
104 |
+
self.curr_ep_loss_length = 0
|
105 |
+
|
106 |
+
def record(self, episode, epsilon, step):
|
107 |
+
mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
|
108 |
+
mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
|
109 |
+
mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
|
110 |
+
mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
|
111 |
+
self.moving_avg_ep_rewards.append(mean_ep_reward)
|
112 |
+
self.moving_avg_ep_lengths.append(mean_ep_length)
|
113 |
+
self.moving_avg_ep_avg_losses.append(mean_ep_loss)
|
114 |
+
self.moving_avg_ep_avg_qs.append(mean_ep_q)
|
115 |
+
|
116 |
+
last_record_time = self.record_time
|
117 |
+
self.record_time = time.time()
|
118 |
+
time_since_last_record = np.round(self.record_time - last_record_time, 3)
|
119 |
+
|
120 |
+
print(
|
121 |
+
f"Episode {episode} - "
|
122 |
+
f"Step {step} - "
|
123 |
+
f"Epsilon {epsilon} - "
|
124 |
+
f"Mean Reward {mean_ep_reward} - "
|
125 |
+
f"Mean Length {mean_ep_length} - "
|
126 |
+
f"Mean Loss {mean_ep_loss} - "
|
127 |
+
f"Mean Q Value {mean_ep_q} - "
|
128 |
+
f"Time Delta {time_since_last_record} - "
|
129 |
+
f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
|
130 |
+
)
|
131 |
+
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
|
132 |
+
self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
|
133 |
+
self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
|
134 |
+
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
|
135 |
+
self.writer.add_scalar("Epsilon value", epsilon, episode)
|
136 |
+
self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
|
137 |
+
self.writer.flush()
|
138 |
+
with open(self.save_log, "a") as f:
|
139 |
+
f.write(
|
140 |
+
f"{episode:8d}{step:8d}{epsilon:10.3f}"
|
141 |
+
f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
|
142 |
+
f"{time_since_last_record:15.3f}"
|
143 |
+
f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
|
144 |
+
)
|
145 |
+
|
146 |
+
for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
|
147 |
+
plt.plot(getattr(self, f"moving_avg_{metric}"))
|
148 |
+
plt.savefig(getattr(self, f"{metric}_plot"))
|
149 |
+
plt.clf()
|
150 |
+
|
151 |
+
|
152 |
+
class DQNAgent:
|
153 |
+
def __init__(self,
|
154 |
+
state_dim,
|
155 |
+
action_dim,
|
156 |
+
save_dir,
|
157 |
+
checkpoint=None,
|
158 |
+
learning_rate=0.00025,
|
159 |
+
max_memory_size=100000,
|
160 |
+
batch_size=32,
|
161 |
+
exploration_rate=1,
|
162 |
+
exploration_rate_decay=0.9999999,
|
163 |
+
exploration_rate_min=0.1,
|
164 |
+
training_frequency=1,
|
165 |
+
learning_starts=1000,
|
166 |
+
target_network_sync_frequency=500,
|
167 |
+
reset_exploration_rate=False,
|
168 |
+
save_frequency=100000,
|
169 |
+
gamma=0.9,
|
170 |
+
load_replay_buffer=True):
|
171 |
+
self.state_dim = state_dim
|
172 |
+
self.action_dim = action_dim
|
173 |
+
self.max_memory_size = max_memory_size
|
174 |
+
self.memory = deque(maxlen=max_memory_size)
|
175 |
+
self.batch_size = batch_size
|
176 |
+
|
177 |
+
self.exploration_rate = exploration_rate
|
178 |
+
self.exploration_rate_decay = exploration_rate_decay
|
179 |
+
self.exploration_rate_min = exploration_rate_min
|
180 |
+
self.gamma = gamma
|
181 |
+
|
182 |
+
self.curr_step = 0
|
183 |
+
self.learning_starts = learning_starts # min. experiences before training
|
184 |
+
|
185 |
+
self.training_frequency = training_frequency # no. of experiences between updates to Q_online
|
186 |
+
self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
|
187 |
+
|
188 |
+
self.save_every = save_frequency # no. of experiences between saving the network
|
189 |
+
self.save_dir = save_dir
|
190 |
+
|
191 |
+
self.use_cuda = torch.cuda.is_available()
|
192 |
+
|
193 |
+
self.net = DQNet(self.state_dim, self.action_dim).float()
|
194 |
+
if self.use_cuda:
|
195 |
+
self.net = self.net.to(device='cuda')
|
196 |
+
if checkpoint:
|
197 |
+
self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
|
198 |
+
|
199 |
+
self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
|
200 |
+
self.loss_fn = torch.nn.SmoothL1Loss()
|
201 |
+
# self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate)
|
202 |
+
# self.loss_fn = torch.nn.MSELoss()
|
203 |
+
|
204 |
+
|
205 |
+
def act(self, state):
|
206 |
+
"""
|
207 |
+
Given a state, choose an epsilon-greedy action and update value of step.
|
208 |
+
|
209 |
+
Inputs:
|
210 |
+
state(LazyFrame): A single observation of the current state, dimension is (state_dim)
|
211 |
+
Outputs:
|
212 |
+
action_idx (int): An integer representing which action the agent will perform
|
213 |
+
"""
|
214 |
+
# EXPLORE
|
215 |
+
if np.random.rand() < self.exploration_rate:
|
216 |
+
action_idx = np.random.randint(self.action_dim)
|
217 |
+
|
218 |
+
# EXPLOIT
|
219 |
+
else:
|
220 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
221 |
+
state = state.unsqueeze(0)
|
222 |
+
action_values = self.net(state, model='online')
|
223 |
+
action_idx = torch.argmax(action_values, axis=1).item()
|
224 |
+
|
225 |
+
# decrease exploration_rate
|
226 |
+
|
227 |
+
self.exploration_rate *= self.exploration_rate_decay
|
228 |
+
self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
|
229 |
+
|
230 |
+
# increment step
|
231 |
+
self.curr_step += 1
|
232 |
+
return action_idx
|
233 |
+
|
234 |
+
def cache(self, state, next_state, action, reward, done):
|
235 |
+
"""
|
236 |
+
Store the experience to self.memory (replay buffer)
|
237 |
+
|
238 |
+
Inputs:
|
239 |
+
state (LazyFrame),
|
240 |
+
next_state (LazyFrame),
|
241 |
+
action (int),
|
242 |
+
reward (float),
|
243 |
+
done(bool))
|
244 |
+
"""
|
245 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
246 |
+
next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
|
247 |
+
action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
|
248 |
+
reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
|
249 |
+
done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
|
250 |
+
|
251 |
+
self.memory.append( (state, next_state, action, reward, done,) )
|
252 |
+
|
253 |
+
|
254 |
+
def recall(self):
|
255 |
+
"""
|
256 |
+
Retrieve a batch of experiences from memory
|
257 |
+
"""
|
258 |
+
batch = random.sample(self.memory, self.batch_size)
|
259 |
+
state, next_state, action, reward, done = map(torch.stack, zip(*batch))
|
260 |
+
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
|
261 |
+
|
262 |
+
|
263 |
+
def td_estimate(self, states, actions):
|
264 |
+
actions = actions.reshape(-1, 1)
|
265 |
+
predicted_qs = self.net(states, model='online')# Q_online(s,a)
|
266 |
+
predicted_qs = predicted_qs.gather(1, actions)
|
267 |
+
return predicted_qs
|
268 |
+
|
269 |
+
|
270 |
+
@torch.no_grad()
|
271 |
+
def td_target(self, rewards, next_states, dones):
|
272 |
+
rewards = rewards.reshape(-1, 1)
|
273 |
+
dones = dones.reshape(-1, 1)
|
274 |
+
target_qs = self.net(next_states, model='target')
|
275 |
+
target_qs = torch.max(target_qs, dim=1).values
|
276 |
+
target_qs = target_qs.reshape(-1, 1)
|
277 |
+
target_qs[dones] = 0.0
|
278 |
+
return (rewards + (self.gamma * target_qs))
|
279 |
+
|
280 |
+
def update_Q_online(self, td_estimate, td_target) :
|
281 |
+
loss = self.loss_fn(td_estimate.float(), td_target.float())
|
282 |
+
self.optimizer.zero_grad()
|
283 |
+
loss.backward()
|
284 |
+
self.optimizer.step()
|
285 |
+
return loss.item()
|
286 |
+
|
287 |
+
|
288 |
+
def sync_Q_target(self):
|
289 |
+
self.net.target.load_state_dict(self.net.online.state_dict())
|
290 |
+
|
291 |
+
|
292 |
+
def learn(self):
|
293 |
+
if self.curr_step % self.target_network_sync_frequency == 0:
|
294 |
+
self.sync_Q_target()
|
295 |
+
|
296 |
+
if self.curr_step % self.save_every == 0:
|
297 |
+
self.save()
|
298 |
+
|
299 |
+
if self.curr_step < self.learning_starts:
|
300 |
+
return None, None
|
301 |
+
|
302 |
+
if self.curr_step % self.training_frequency != 0:
|
303 |
+
return None, None
|
304 |
+
|
305 |
+
# Sample from memory
|
306 |
+
state, next_state, action, reward, done = self.recall()
|
307 |
+
|
308 |
+
# Get TD Estimate
|
309 |
+
td_est = self.td_estimate(state, action)
|
310 |
+
|
311 |
+
# Get TD Target
|
312 |
+
td_tgt = self.td_target(reward, next_state, done)
|
313 |
+
|
314 |
+
# Backpropagate loss through Q_online
|
315 |
+
|
316 |
+
loss = self.update_Q_online(td_est, td_tgt)
|
317 |
+
|
318 |
+
return (td_est.mean().item(), loss)
|
319 |
+
|
320 |
+
|
321 |
+
def save(self):
|
322 |
+
save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
|
323 |
+
torch.save(
|
324 |
+
dict(
|
325 |
+
model=self.net.state_dict(),
|
326 |
+
exploration_rate=self.exploration_rate,
|
327 |
+
replay_memory=self.memory
|
328 |
+
),
|
329 |
+
save_path
|
330 |
+
)
|
331 |
+
|
332 |
+
print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
|
333 |
+
|
334 |
+
|
335 |
+
def load(self, load_path, reset_exploration_rate, load_replay_buffer):
|
336 |
+
if not load_path.exists():
|
337 |
+
raise ValueError(f"{load_path} does not exist")
|
338 |
+
|
339 |
+
ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
|
340 |
+
exploration_rate = ckp.get('exploration_rate')
|
341 |
+
state_dict = ckp.get('model')
|
342 |
+
|
343 |
+
|
344 |
+
print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
|
345 |
+
self.net.load_state_dict(state_dict)
|
346 |
+
|
347 |
+
if load_replay_buffer:
|
348 |
+
replay_memory = ckp.get('replay_memory')
|
349 |
+
print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
|
350 |
+
self.memory = replay_memory if replay_memory else self.memory
|
351 |
+
|
352 |
+
if reset_exploration_rate:
|
353 |
+
print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
|
354 |
+
else:
|
355 |
+
print(f"Setting exploration rate to {exploration_rate} not loaded.")
|
356 |
+
self.exploration_rate = exploration_rate
|
357 |
+
|
358 |
+
|
359 |
+
class DDQNAgent(DQNAgent):
|
360 |
+
@torch.no_grad()
|
361 |
+
def td_target(self, rewards, next_states, dones):
|
362 |
+
rewards = rewards.reshape(-1, 1)
|
363 |
+
dones = dones.reshape(-1, 1)
|
364 |
+
q_vals = self.net(next_states, model='online')
|
365 |
+
target_actions = torch.argmax(q_vals, axis=1)
|
366 |
+
target_actions = target_actions.reshape(-1, 1)
|
367 |
+
|
368 |
+
target_qs = self.net(next_states, model='target')
|
369 |
+
target_qs = target_qs.gather(1, target_actions)
|
370 |
+
target_qs = target_qs.reshape(-1, 1)
|
371 |
+
target_qs[dones] = 0.0
|
372 |
+
return (rewards + (self.gamma * target_qs))
|
373 |
+
|
374 |
+
|
375 |
+
class DuelingDQNet(nn.Module):
|
376 |
+
def __init__(self, input_dim, output_dim):
|
377 |
+
super().__init__()
|
378 |
+
self.feature_layer = nn.Sequential(
|
379 |
+
nn.Linear(input_dim, 150),
|
380 |
+
nn.ReLU(),
|
381 |
+
nn.Linear(150, 120),
|
382 |
+
nn.ReLU()
|
383 |
+
)
|
384 |
+
|
385 |
+
self.value_layer = nn.Sequential(
|
386 |
+
nn.Linear(120, 120),
|
387 |
+
nn.ReLU(),
|
388 |
+
nn.Linear(120, 1)
|
389 |
+
)
|
390 |
+
|
391 |
+
self.advantage_layer = nn.Sequential(
|
392 |
+
nn.Linear(120, 120),
|
393 |
+
nn.ReLU(),
|
394 |
+
nn.Linear(120, output_dim)
|
395 |
+
)
|
396 |
+
|
397 |
+
def forward(self, state):
|
398 |
+
feature_output = self.feature_layer(state)
|
399 |
+
# feature_output = feature_output.view(feature_output.size(0), -1)
|
400 |
+
value = self.value_layer(feature_output)
|
401 |
+
advantage = self.advantage_layer(feature_output)
|
402 |
+
q_value = value + (advantage - advantage.mean())
|
403 |
+
|
404 |
+
return q_value
|
405 |
+
|
406 |
+
|
407 |
+
class DuelingDQNAgent:
|
408 |
+
def __init__(self,
|
409 |
+
state_dim,
|
410 |
+
action_dim,
|
411 |
+
save_dir,
|
412 |
+
checkpoint=None,
|
413 |
+
learning_rate=0.00025,
|
414 |
+
max_memory_size=100000,
|
415 |
+
batch_size=32,
|
416 |
+
exploration_rate=1,
|
417 |
+
exploration_rate_decay=0.9999999,
|
418 |
+
exploration_rate_min=0.1,
|
419 |
+
training_frequency=1,
|
420 |
+
learning_starts=1000,
|
421 |
+
target_network_sync_frequency=500,
|
422 |
+
reset_exploration_rate=False,
|
423 |
+
save_frequency=100000,
|
424 |
+
gamma=0.9,
|
425 |
+
load_replay_buffer=True):
|
426 |
+
self.state_dim = state_dim
|
427 |
+
self.action_dim = action_dim
|
428 |
+
self.max_memory_size = max_memory_size
|
429 |
+
self.memory = deque(maxlen=max_memory_size)
|
430 |
+
self.batch_size = batch_size
|
431 |
+
|
432 |
+
self.exploration_rate = exploration_rate
|
433 |
+
self.exploration_rate_decay = exploration_rate_decay
|
434 |
+
self.exploration_rate_min = exploration_rate_min
|
435 |
+
self.gamma = gamma
|
436 |
+
|
437 |
+
self.curr_step = 0
|
438 |
+
self.learning_starts = learning_starts # min. experiences before training
|
439 |
+
|
440 |
+
self.training_frequency = training_frequency # no. of experiences between updates to Q_online
|
441 |
+
self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
|
442 |
+
|
443 |
+
self.save_every = save_frequency # no. of experiences between saving the network
|
444 |
+
self.save_dir = save_dir
|
445 |
+
|
446 |
+
self.use_cuda = torch.cuda.is_available()
|
447 |
+
|
448 |
+
|
449 |
+
self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
|
450 |
+
self.target_net = copy.deepcopy(self.online_net)
|
451 |
+
# Q_target parameters are frozen.
|
452 |
+
for p in self.target_net.parameters():
|
453 |
+
p.requires_grad = False
|
454 |
+
|
455 |
+
if self.use_cuda:
|
456 |
+
self.online_net = self.online_net(device='cuda')
|
457 |
+
self.target_net = self.target_net(device='cuda')
|
458 |
+
if checkpoint:
|
459 |
+
self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
|
460 |
+
|
461 |
+
self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
|
462 |
+
self.loss_fn = torch.nn.SmoothL1Loss()
|
463 |
+
# self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=learning_rate)
|
464 |
+
# self.loss_fn = torch.nn.MSELoss()
|
465 |
+
|
466 |
+
|
467 |
+
def act(self, state):
|
468 |
+
"""
|
469 |
+
Given a state, choose an epsilon-greedy action and update value of step.
|
470 |
+
|
471 |
+
Inputs:
|
472 |
+
state(LazyFrame): A single observation of the current state, dimension is (state_dim)
|
473 |
+
Outputs:
|
474 |
+
action_idx (int): An integer representing which action the agent will perform
|
475 |
+
"""
|
476 |
+
# EXPLORE
|
477 |
+
if np.random.rand() < self.exploration_rate:
|
478 |
+
action_idx = np.random.randint(self.action_dim)
|
479 |
+
|
480 |
+
# EXPLOIT
|
481 |
+
else:
|
482 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
483 |
+
state = state.unsqueeze(0)
|
484 |
+
action_values = self.online_net(state)
|
485 |
+
action_idx = torch.argmax(action_values, axis=1).item()
|
486 |
+
|
487 |
+
# decrease exploration_rate
|
488 |
+
self.exploration_rate *= self.exploration_rate_decay
|
489 |
+
self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
|
490 |
+
|
491 |
+
# increment step
|
492 |
+
self.curr_step += 1
|
493 |
+
return action_idx
|
494 |
+
|
495 |
+
def cache(self, state, next_state, action, reward, done):
|
496 |
+
"""
|
497 |
+
Store the experience to self.memory (replay buffer)
|
498 |
+
|
499 |
+
Inputs:
|
500 |
+
state (LazyFrame),
|
501 |
+
next_state (LazyFrame),
|
502 |
+
action (int),
|
503 |
+
reward (float),
|
504 |
+
done(bool))
|
505 |
+
"""
|
506 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
507 |
+
next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
|
508 |
+
action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
|
509 |
+
reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
|
510 |
+
done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
|
511 |
+
|
512 |
+
self.memory.append( (state, next_state, action, reward, done,) )
|
513 |
+
|
514 |
+
|
515 |
+
def recall(self):
|
516 |
+
"""
|
517 |
+
Retrieve a batch of experiences from memory
|
518 |
+
"""
|
519 |
+
batch = random.sample(self.memory, self.batch_size)
|
520 |
+
state, next_state, action, reward, done = map(torch.stack, zip(*batch))
|
521 |
+
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
|
522 |
+
|
523 |
+
|
524 |
+
def td_estimate(self, states, actions):
|
525 |
+
actions = actions.reshape(-1, 1)
|
526 |
+
predicted_qs = self.online_net(states)# Q_online(s,a)
|
527 |
+
predicted_qs = predicted_qs.gather(1, actions)
|
528 |
+
return predicted_qs
|
529 |
+
|
530 |
+
|
531 |
+
@torch.no_grad()
|
532 |
+
def td_target(self, rewards, next_states, dones):
|
533 |
+
rewards = rewards.reshape(-1, 1)
|
534 |
+
dones = dones.reshape(-1, 1)
|
535 |
+
target_qs = self.target_net.forward(next_states)
|
536 |
+
target_qs = torch.max(target_qs, dim=1).values
|
537 |
+
target_qs = target_qs.reshape(-1, 1)
|
538 |
+
target_qs[dones] = 0.0
|
539 |
+
return (rewards + (self.gamma * target_qs))
|
540 |
+
|
541 |
+
def update_Q_online(self, td_estimate, td_target) :
|
542 |
+
loss = self.loss_fn(td_estimate.float(), td_target.float())
|
543 |
+
self.optimizer.zero_grad()
|
544 |
+
loss.backward()
|
545 |
+
self.optimizer.step()
|
546 |
+
return loss.item()
|
547 |
+
|
548 |
+
|
549 |
+
def sync_Q_target(self):
|
550 |
+
self.target_net.load_state_dict(self.online_net.state_dict())
|
551 |
+
|
552 |
+
|
553 |
+
def learn(self):
|
554 |
+
if self.curr_step % self.target_network_sync_frequency == 0:
|
555 |
+
self.sync_Q_target()
|
556 |
+
|
557 |
+
if self.curr_step % self.save_every == 0:
|
558 |
+
self.save()
|
559 |
+
|
560 |
+
if self.curr_step < self.learning_starts:
|
561 |
+
return None, None
|
562 |
+
|
563 |
+
if self.curr_step % self.training_frequency != 0:
|
564 |
+
return None, None
|
565 |
+
|
566 |
+
# Sample from memory
|
567 |
+
state, next_state, action, reward, done = self.recall()
|
568 |
+
|
569 |
+
# Get TD Estimate
|
570 |
+
td_est = self.td_estimate(state, action)
|
571 |
+
|
572 |
+
# Get TD Target
|
573 |
+
td_tgt = self.td_target(reward, next_state, done)
|
574 |
+
|
575 |
+
# Backpropagate loss through Q_online
|
576 |
+
loss = self.update_Q_online(td_est, td_tgt)
|
577 |
+
|
578 |
+
return (td_est.mean().item(), loss)
|
579 |
+
|
580 |
+
|
581 |
+
def save(self):
|
582 |
+
save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
|
583 |
+
torch.save(
|
584 |
+
dict(
|
585 |
+
model=self.online_net.state_dict(),
|
586 |
+
exploration_rate=self.exploration_rate,
|
587 |
+
replay_memory=self.memory
|
588 |
+
),
|
589 |
+
save_path
|
590 |
+
)
|
591 |
+
|
592 |
+
print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
|
593 |
+
|
594 |
+
|
595 |
+
def load(self, load_path, reset_exploration_rate, load_replay_buffer):
|
596 |
+
if not load_path.exists():
|
597 |
+
raise ValueError(f"{load_path} does not exist")
|
598 |
+
|
599 |
+
ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
|
600 |
+
exploration_rate = ckp.get('exploration_rate')
|
601 |
+
state_dict = ckp.get('model')
|
602 |
+
|
603 |
+
|
604 |
+
print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
|
605 |
+
self.online_net.load_state_dict(state_dict)
|
606 |
+
self.target_net = copy.deepcopy(self.online_net)
|
607 |
+
self.sync_Q_target()
|
608 |
+
|
609 |
+
if load_replay_buffer:
|
610 |
+
replay_memory = ckp.get('replay_memory')
|
611 |
+
print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
|
612 |
+
self.memory = replay_memory if replay_memory else self.memory
|
613 |
+
|
614 |
+
if reset_exploration_rate:
|
615 |
+
print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
|
616 |
+
else:
|
617 |
+
print(f"Setting exploration rate to {exploration_rate} not loaded.")
|
618 |
+
self.exploration_rate = exploration_rate
|
619 |
+
|
620 |
+
|
621 |
+
|
622 |
+
|
623 |
+
class DuelingDDQNAgent(DuelingDQNAgent):
|
624 |
+
@torch.no_grad()
|
625 |
+
def td_target(self, rewards, next_states, dones):
|
626 |
+
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        q_vals = self.online_net.forward(next_states)
        target_actions = torch.argmax(q_vals, axis=1)
        target_actions = target_actions.reshape(-1, 1)

        target_qs = self.target_net.forward(next_states)
        target_qs = target_qs.gather(1, target_actions)
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        return (rewards + (self.gamma * target_qs))


class DQNAgentWithStepDecay:
    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts  # min. experiences before training

        self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
        self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

        self.save_every = save_frequency  # no. of experiences between saving the network
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.net = DQNet(self.state_dim, self.action_dim).float()
        if self.use_cuda:
            self.net = self.net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()
        # self.optimizer = torch.optim.Adam(self.net.parameters(), lr=learning_rate)
        # self.loss_fn = torch.nn.MSELoss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update value of step.

        Inputs:
        state (LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
        action_idx (int): An integer representing which action the agent will perform
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)
            action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done, stepnumber):
        """
        Store the experience to self.memory (replay buffer)

        Inputs:
        state (LazyFrame),
        next_state (LazyFrame),
        action (int),
        reward (float),
        done (bool)
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
        stepnumber = torch.LongTensor([stepnumber]).cuda() if self.use_cuda else torch.LongTensor([stepnumber])

        self.memory.append((state, next_state, action, reward, done, stepnumber))

    def recall(self):
        """
        Retrieve a batch of experiences from memory
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done, stepnumber = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze(), stepnumber.squeeze()

    def td_estimate(self, states, actions):
        actions = actions.reshape(-1, 1)
        predicted_qs = self.net(states, model='online')  # Q_online(s,a)
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones, stepnumbers):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        stepnumbers = stepnumbers.reshape(-1, 1)
        target_qs = self.net(next_states, model='target')
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        discount = ((200 - stepnumbers) / 200)
        # elementwise min kept as a torch op so it also works on CUDA tensors
        val = torch.minimum(discount, self.gamma * target_qs)
        return (rewards + val)

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate.float(), td_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.net.target.load_state_dict(self.net.online.state_dict())

    def learn(self):
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        # Sample from memory
        state, next_state, action, reward, done, stepnumber = self.recall()

        # Get TD Estimate
        td_est = self.td_estimate(state, action)

        # Get TD Target
        td_tgt = self.td_target(reward, next_state, done, stepnumber)

        # Backpropagate loss through Q_online
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.net.load_state_dict(state_dict)

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate


class DDQNAgentWithStepDecay(DQNAgentWithStepDecay):
    @torch.no_grad()
    def td_target(self, rewards, next_states, dones, stepnumbers):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        stepnumbers = stepnumbers.reshape(-1, 1)
        q_vals = self.net(next_states, model='online')
        target_actions = torch.argmax(q_vals, axis=1)
        target_actions = target_actions.reshape(-1, 1)

        target_qs = self.net(next_states, model='target')
        target_qs = target_qs.gather(1, target_actions)
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        discount = ((200 - stepnumbers) / 200)
        val = torch.minimum(discount, self.gamma * target_qs)
        return (rewards + val)


class DuelingDQNAgentWithStepDecay:
    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts  # min. experiences before training

        self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
        self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

        self.save_every = save_frequency  # no. of experiences between saving the network
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
        self.target_net = copy.deepcopy(self.online_net)
        # Q_target parameters are frozen.
        for p in self.target_net.parameters():
            p.requires_grad = False

        if self.use_cuda:
            # move both networks to the GPU
            self.online_net = self.online_net.to(device='cuda')
            self.target_net = self.target_net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()
        # self.optimizer = torch.optim.Adam(self.online_net.parameters(), lr=learning_rate)
        # self.loss_fn = torch.nn.MSELoss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update value of step.

        Inputs:
        state (LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
        action_idx (int): An integer representing which action the agent will perform
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)
            action_values = self.online_net(state)
            action_idx = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done, stepnumber):
        """
        Store the experience to self.memory (replay buffer)

        Inputs:
        state (LazyFrame),
        next_state (LazyFrame),
        action (int),
        reward (float),
        done (bool)
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
        stepnumber = torch.LongTensor([stepnumber]).cuda() if self.use_cuda else torch.LongTensor([stepnumber])

        self.memory.append((state, next_state, action, reward, done, stepnumber))

    def recall(self):
        """
        Retrieve a batch of experiences from memory
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done, stepnumber = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze(), stepnumber.squeeze()

    def td_estimate(self, states, actions):
        actions = actions.reshape(-1, 1)
        predicted_qs = self.online_net(states)  # Q_online(s,a)
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones, stepnumbers):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        stepnumbers = stepnumbers.reshape(-1, 1)
        target_qs = self.target_net.forward(next_states)
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        discount = ((200 - stepnumbers) / 200)
        val = torch.minimum(discount, self.gamma * target_qs)
        return (rewards + val)

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate.float(), td_target.float())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def learn(self):
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        # Sample from memory
        state, next_state, action, reward, done, stepnumbers = self.recall()

        # Get TD Estimate
        td_est = self.td_estimate(state, action)

        # Get TD Target
        td_tgt = self.td_target(reward, next_state, done, stepnumbers)

        # Backpropagate loss through Q_online
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.online_net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Airstriker model saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.online_net.load_state_dict(state_dict)
        self.target_net = copy.deepcopy(self.online_net)
        self.sync_Q_target()

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate


class DuelingDDQNAgentWithStepDecay(DuelingDQNAgentWithStepDecay):
    @torch.no_grad()
    def td_target(self, rewards, next_states, dones, stepnumbers):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        stepnumbers = stepnumbers.reshape(-1, 1)
        q_vals = self.online_net.forward(next_states)
        target_actions = torch.argmax(q_vals, axis=1)
        target_actions = target_actions.reshape(-1, 1)

        target_qs = self.target_net.forward(next_states)
        target_qs = target_qs.gather(1, target_actions)
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        discount = ((200 - stepnumbers) / 200)
        val = torch.minimum(discount, self.gamma * target_qs)
        return (rewards + val)
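The `*WithStepDecay` agents above differ from the plain agents only in the extra clamp on the bootstrap term: the target is `reward + min((200 - stepnumber)/200, gamma * max_a Q_target(s', a))`. A minimal sketch of that computation on toy numbers (values invented for illustration, not taken from a training run):

# Illustration only: how the step-decay target is formed for a batch of 3 transitions.
import torch

gamma = 0.9
rewards = torch.tensor([[1.0], [0.5], [2.0]])
target_qs = torch.tensor([[10.0], [4.0], [6.0]])   # max_a Q_target(s', a)
dones = torch.tensor([[False], [True], [False]])
stepnumbers = torch.tensor([[10], [150], [400]])

target_qs[dones] = 0.0                              # no bootstrap past terminal states
discount = (200 - stepnumbers) / 200                # 0.95, 0.25, -1.0
val = torch.minimum(discount, gamma * target_qs)    # cap the bootstrapped term
td_target = rewards + val                           # tensor([[1.95], [0.5], [1.0]])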
src/lunar-lander/params.py
ADDED
@@ -0,0 +1,12 @@
hyperparams = dict(
    batch_size=128,
    exploration_rate=1,
    exploration_rate_decay=0.99999,
    exploration_rate_min=0.01,
    training_frequency=1,
    target_network_sync_frequency=20,
    max_memory_size=1000000,
    learning_rate=0.001,
    learning_starts=128,
    save_frequency=100000
)
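The keys in `hyperparams` mirror the keyword arguments of the agent constructors, which is why the runner scripts below can unpack the dict directly. A small sketch of overriding a single value at call time; the `env` and `save_dir` names are assumed to exist as in those scripts:

# Hypothetical usage: override one hyperparameter without editing params.py.
overrides = dict(hyperparams, learning_rate=0.0005)
agent = DQNAgentWithStepDecay(
    state_dim=8,
    action_dim=env.action_space.n,   # env as created by make_lunar()
    save_dir=save_dir,               # any existing Path
    **overrides,
)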
src/lunar-lander/replay.py
ADDED
@@ -0,0 +1,67 @@
import datetime
from pathlib import Path
from agent import DQNAgent, DDQNAgent, MetricLogger
from wrappers import make_lunar


env = make_lunar()

env.reset()

save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)

# checkpoint = Path('checkpoints/lunar-lander-dueling-ddqn/airstriker_net_2.chkpt')
checkpoint = Path('checkpoints/lunar-lander-dqn-rc/airstriker_net_1.chkpt')

logger = MetricLogger(save_dir)

print("Testing Double DQN Agent!")
agent = DDQNAgent(
    state_dim=8,
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=512,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
    load_replay_buffer=False
)
agent.exploration_rate = agent.exploration_rate_min

episodes = 100

for e in range(episodes):
    state = env.reset()

    while True:
        env.render()

        action = agent.act(state)

        next_state, reward, done, info = env.step(action)

        # agent.cache(state, next_state, action, reward, done)
        # logger.log_step(reward, None, None)

        state = next_state

        if done:
            break

    # logger.log_episode()

    # if e % 20 == 0:
    #     logger.record(
    #         episode=e,
    #         epsilon=agent.exploration_rate,
    #         step=agent.curr_step
    #     )
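Because caching, learning and logging are commented out, replay.py only renders the restored policy. If a score is wanted from such an evaluation run, a small accumulator around the episode loop is enough; a sketch reusing the names defined above:

# Sketch: track the undiscounted return of each evaluation episode.
returns = []
for e in range(episodes):
    state = env.reset()
    episode_return = 0.0
    done = False
    while not done:
        action = agent.act(state)
        state, reward, done, info = env.step(action)
        episode_return += reward
    returns.append(episode_return)
print(f"mean return over {len(returns)} episodes: {sum(returns) / len(returns):.2f}")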
src/lunar-lander/run-lunar-ddqn.py
ADDED
@@ -0,0 +1,45 @@
import os
import torch
from pathlib import Path

from agent import DDQNAgent, DDQNAgentWithStepDecay, MetricLogger
from wrappers import make_lunar
from train import train, fill_memory
from params import hyperparams

env = make_lunar()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/lunar-lander-ddqn-rc"
save_dir = Path(path)

isExist = os.path.exists(path)
if not isExist:
    os.makedirs(path)

logger = MetricLogger(save_dir)

print("Training DDQN Agent!")
agent = DDQNAgentWithStepDecay(
    state_dim=8,
    action_dim=env.action_space.n,
    save_dir=save_dir,
    checkpoint=checkpoint,
    **hyperparams
)
# agent = DDQNAgent(
#     state_dim=8,
#     action_dim=env.action_space.n,
#     save_dir=save_dir,
#     checkpoint=checkpoint,
#     **hyperparams
# )

# fill_memory(agent, env, 5000)
train(agent, env, logger)
src/lunar-lander/run-lunar-dqn.py
ADDED
@@ -0,0 +1,46 @@
import os
import torch
from pathlib import Path

from agent import DQNAgent, DQNAgentWithStepDecay, MetricLogger
from wrappers import make_lunar
from train import train, fill_memory
from params import hyperparams

env = make_lunar()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/lunar-lander-dqn-rc"
save_dir = Path(path)

isExist = os.path.exists(path)
if not isExist:
    os.makedirs(path)

logger = MetricLogger(save_dir)

print("Training Vanilla DQN Agent with decay!")
agent = DQNAgentWithStepDecay(
    state_dim=8,
    action_dim=env.action_space.n,
    save_dir=save_dir,
    checkpoint=checkpoint,
    **hyperparams
)
# print("Training Vanilla DQN Agent!")
# agent = DQNAgent(
#     state_dim=8,
#     action_dim=env.action_space.n,
#     save_dir=save_dir,
#     checkpoint=checkpoint,
#     **hyperparams
# )

# fill_memory(agent, env, 5000)
train(agent, env, logger)
src/lunar-lander/run-lunar-dueling-ddqn.py
ADDED
@@ -0,0 +1,47 @@
import os
import torch
from pathlib import Path

from agent import DuelingDDQNAgent, DuelingDDQNAgentWithStepDecay, MetricLogger
from wrappers import make_lunar
from train import train, fill_memory
from params import hyperparams


env = make_lunar()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/lunar-lander-dueling-ddqn-rc"
save_dir = Path(path)

isExist = os.path.exists(path)
if not isExist:
    os.makedirs(path)

logger = MetricLogger(save_dir)

print("Training Dueling DDQN Agent with step decay!")
agent = DuelingDDQNAgentWithStepDecay(
    state_dim=8,
    action_dim=env.action_space.n,
    save_dir=save_dir,
    checkpoint=checkpoint,
    **hyperparams
)
# print("Training Dueling DDQN Agent!")
# agent = DuelingDDQNAgent(
#     state_dim=8,
#     action_dim=env.action_space.n,
#     save_dir=save_dir,
#     checkpoint=checkpoint,
#     **hyperparams
# )

# fill_memory(agent, env, 5000)
train(agent, env, logger)
src/lunar-lander/run-lunar-dueling-dqn.py
ADDED
@@ -0,0 +1,46 @@
import os
import torch
from pathlib import Path

from agent import DuelingDQNAgent, DuelingDQNAgentWithStepDecay, MetricLogger
from wrappers import make_lunar
from train import train, fill_memory
from params import hyperparams

env = make_lunar()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/lunar-lander-dueling-dqn-rc"
save_dir = Path(path)

isExist = os.path.exists(path)
if not isExist:
    os.makedirs(path)

logger = MetricLogger(save_dir)

print("Training Dueling DQN Agent with step decay!")
agent = DuelingDQNAgentWithStepDecay(
    state_dim=8,
    action_dim=env.action_space.n,
    save_dir=save_dir,
    checkpoint=checkpoint,
    **hyperparams
)
# print("Training Dueling DQN Agent!")
# agent = DuelingDQNAgent(
#     state_dim=8,
#     action_dim=env.action_space.n,
#     save_dir=save_dir,
#     checkpoint=checkpoint,
#     **hyperparams
# )

# fill_memory(agent, env, 5000)
train(agent, env, logger)
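The four run-lunar-*.py scripts are identical apart from the agent class and the checkpoint directory. One compact way to express the same choice is a dispatch table; this is only a sketch, not part of the repository, and it assumes the agent classes and `hyperparams` are imported as in the scripts above:

# Hypothetical dispatch table collapsing the four runner scripts into one.
AGENTS = {
    "dqn": DQNAgentWithStepDecay,
    "ddqn": DDQNAgentWithStepDecay,
    "dueling-dqn": DuelingDQNAgentWithStepDecay,
    "dueling-ddqn": DuelingDDQNAgentWithStepDecay,
}

def make_agent(name, env, save_dir, checkpoint=None):
    cls = AGENTS[name]
    return cls(state_dim=8, action_dim=env.action_space.n,
               save_dir=save_dir, checkpoint=checkpoint, **hyperparams)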
src/lunar-lander/train.py
ADDED
@@ -0,0 +1,84 @@
from tqdm import trange


def fill_memory(agent, env, num_episodes=500):
    print("Filling up memory....")
    for _ in trange(num_episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.cache(state, next_state, action, reward, done)
            state = next_state


# def train(agent, env, logger):
#     episodes = 5000
#     for e in range(episodes):
#         state = env.reset()
#         # Play the game!
#         while True:
#             # Run agent on the state
#             action = agent.act(state)
#             # Agent performs action
#             next_state, reward, done, info = env.step(action)
#             # Remember
#             agent.cache(state, next_state, action, reward, done)
#             # Learn
#             q, loss = agent.learn()
#             # Logging
#             logger.log_step(reward, loss, q)
#             # Update state
#             state = next_state
#             # Check if end of game
#             if done:
#                 break
#         logger.log_episode(e)
#         if e % 20 == 0:
#             logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)


def train(agent, env, logger):
    episodes = 5000
    for e in range(episodes):

        state = env.reset()
        # Play the game!
        for i in range(1000):

            # Run agent on the state
            action = agent.act(state)
            env.render()
            # Agent performs action
            next_state, reward, done, info = env.step(action)

            # Remember (the step index i feeds the step-decay agents)
            agent.cache(state, next_state, action, reward, done, i)

            # Learn
            q, loss = agent.learn()

            # Logging
            logger.log_step(reward, loss, q)

            # Update state
            state = next_state

            # Check if end of game
            if done:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
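Note that `fill_memory` stores transitions without a step index, while the `*WithStepDecay` agents expect `cache(state, next_state, action, reward, done, stepnumber)`. If the warm-up call is re-enabled for those agents, the loop also needs to track the step counter; a sketch under that assumption (`fill_memory_with_steps` is a hypothetical helper, not in the repository):

def fill_memory_with_steps(agent, env, num_episodes=500):
    # Variant of fill_memory for agents whose cache() also takes the step index.
    print("Filling up memory....")
    for _ in trange(num_episodes):
        state = env.reset()
        done = False
        step = 0
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.cache(state, next_state, action, reward, done, step)
            state = next_state
            step += 1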
src/lunar-lander/wrappers.py
ADDED
@@ -0,0 +1,193 @@
import numpy as np
import os
from collections import deque
import gym
from gym import spaces
import cv2
import math

'''
Atari Wrapper copied from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
'''


class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.
        This object should only be converted to numpy array before being passed to the model.
        You'd not believe how complex the previous solution was."""
        self._frames = frames
        self._out = None

    def _force(self):
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]


class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        return self.env.step(ac)


class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            if i == self._skip - 2: self._obs_buffer[0] = obs
            if i == self._skip - 1: self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(self.height, self.width, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]


class WarpFrameNoResize(gym.ObservationWrapper):
    def __init__(self, env):
        """Convert frames to grayscale without resizing them."""
        gym.ObservationWrapper.__init__(self, env)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.
        Returns lazy array, which is much more memory efficient.
        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)

    def reset(self):
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))


class ImageToPyTorch(gym.ObservationWrapper):
    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)


class ScaledFloatFrame(gym.ObservationWrapper):
    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)

    def observation(self, observation):
        # careful! This undoes the memory optimization, use
        # with smaller replay buffers only.
        return np.array(observation).astype(np.float32) / 255.0


class ClipRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)


class TanRewardClipperEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Squash the reward smoothly into (-10, 10) with tanh."""
        return 10 * math.tanh(float(reward) / 30.)


def make_lunar(render=False):
    print("Environment: Lunar Lander")
    env = gym.make("LunarLander-v2")
    # env = TanRewardClipperEnv(env)
    # env = WarpFrameNoResize(env) ## Reshape image
    # env = ImageToPyTorch(env) ## Invert shape
    # env = FrameStack(env, 4) ## Stack last 4 frames
    return env
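`make_lunar` returns the raw state-vector environment and leaves the image wrappers commented out; they are meant for the pixel-based games. For reference, one workable chaining order for an image environment looks like the sketch below; the helper name and the exact ordering are assumptions, not taken from the repository's runners:

def make_pixel_env(env_id, skip=4, stack=4):
    # Sketch: a common wrapper order for RGB-frame environments.
    env = gym.make(env_id)
    env = MaxAndSkipEnv(env, skip=skip)   # repeat actions, max over the last frames
    env = WarpFrame(env)                  # grayscale + resize to 84x84
    env = FrameStack(env, stack)          # stack the last frames (LazyFrames)
    env = ImageToPyTorch(env)             # HWC -> CHW for the conv nets
    env = ScaledFloatFrame(env)           # uint8 -> float in [0, 1]
    return env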
src/procgen/agent.py
ADDED
@@ -0,0 +1,664 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import random
|
4 |
+
import torch.nn as nn
|
5 |
+
import copy
|
6 |
+
import time, datetime
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from collections import deque
|
9 |
+
from torch.utils.tensorboard import SummaryWriter
|
10 |
+
|
11 |
+
|
12 |
+
class DQNet(nn.Module):
|
13 |
+
"""mini cnn structure
|
14 |
+
input -> (conv2d + relu) x 3 -> flatten -> (dense + relu) x 2 -> output
|
15 |
+
"""
|
16 |
+
|
17 |
+
def __init__(self, input_dim, output_dim):
|
18 |
+
super().__init__()
|
19 |
+
print("#################################")
|
20 |
+
print("#################################")
|
21 |
+
print(input_dim)
|
22 |
+
print(output_dim)
|
23 |
+
print("#################################")
|
24 |
+
print("#################################")
|
25 |
+
c, h, w = input_dim
|
26 |
+
|
27 |
+
|
28 |
+
self.online = nn.Sequential(
|
29 |
+
nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
|
30 |
+
nn.ReLU(),
|
31 |
+
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
|
32 |
+
nn.ReLU(),
|
33 |
+
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
|
34 |
+
nn.ReLU(),
|
35 |
+
nn.Flatten(),
|
36 |
+
nn.Linear(7168, 512),
|
37 |
+
nn.ReLU(),
|
38 |
+
nn.Linear(512, output_dim),
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
self.target = copy.deepcopy(self.online)
|
43 |
+
|
44 |
+
# Q_target parameters are frozen.
|
45 |
+
for p in self.target.parameters():
|
46 |
+
p.requires_grad = False
|
47 |
+
|
48 |
+
def forward(self, input, model):
|
49 |
+
if model == "online":
|
50 |
+
return self.online(input)
|
51 |
+
elif model == "target":
|
52 |
+
return self.target(input)
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
class MetricLogger:
|
57 |
+
def __init__(self, save_dir):
|
58 |
+
self.writer = SummaryWriter(log_dir=save_dir)
|
59 |
+
self.save_log = save_dir / "log"
|
60 |
+
with open(self.save_log, "w") as f:
|
61 |
+
f.write(
|
62 |
+
f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
|
63 |
+
f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
|
64 |
+
f"{'TimeDelta':>15}{'Time':>20}\n"
|
65 |
+
)
|
66 |
+
self.ep_rewards_plot = save_dir / "reward_plot.jpg"
|
67 |
+
self.ep_lengths_plot = save_dir / "length_plot.jpg"
|
68 |
+
self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
|
69 |
+
self.ep_avg_qs_plot = save_dir / "q_plot.jpg"
|
70 |
+
|
71 |
+
# History metrics
|
72 |
+
self.ep_rewards = []
|
73 |
+
self.ep_lengths = []
|
74 |
+
self.ep_avg_losses = []
|
75 |
+
self.ep_avg_qs = []
|
76 |
+
|
77 |
+
# Moving averages, added for every call to record()
|
78 |
+
self.moving_avg_ep_rewards = []
|
79 |
+
self.moving_avg_ep_lengths = []
|
80 |
+
self.moving_avg_ep_avg_losses = []
|
81 |
+
self.moving_avg_ep_avg_qs = []
|
82 |
+
|
83 |
+
# Current episode metric
|
84 |
+
self.init_episode()
|
85 |
+
|
86 |
+
# Timing
|
87 |
+
self.record_time = time.time()
|
88 |
+
|
89 |
+
def log_step(self, reward, loss, q):
|
90 |
+
self.curr_ep_reward += reward
|
91 |
+
self.curr_ep_length += 1
|
92 |
+
if loss:
|
93 |
+
self.curr_ep_loss += loss
|
94 |
+
self.curr_ep_q += q
|
95 |
+
self.curr_ep_loss_length += 1
|
96 |
+
|
97 |
+
def log_episode(self, episode_number):
|
98 |
+
"Mark end of episode"
|
99 |
+
self.ep_rewards.append(self.curr_ep_reward)
|
100 |
+
self.ep_lengths.append(self.curr_ep_length)
|
101 |
+
if self.curr_ep_loss_length == 0:
|
102 |
+
ep_avg_loss = 0
|
103 |
+
ep_avg_q = 0
|
104 |
+
else:
|
105 |
+
ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
|
106 |
+
ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
|
107 |
+
self.ep_avg_losses.append(ep_avg_loss)
|
108 |
+
self.ep_avg_qs.append(ep_avg_q)
|
109 |
+
self.writer.add_scalar("Avg Loss for episode", ep_avg_loss, episode_number)
|
110 |
+
self.writer.add_scalar("Avg Q value for episode", ep_avg_q, episode_number)
|
111 |
+
self.writer.flush()
|
112 |
+
self.init_episode()
|
113 |
+
|
114 |
+
def init_episode(self):
|
115 |
+
self.curr_ep_reward = 0.0
|
116 |
+
self.curr_ep_length = 0
|
117 |
+
self.curr_ep_loss = 0.0
|
118 |
+
self.curr_ep_q = 0.0
|
119 |
+
self.curr_ep_loss_length = 0
|
120 |
+
|
121 |
+
def record(self, episode, epsilon, step):
|
122 |
+
mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
|
123 |
+
mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
|
124 |
+
mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
|
125 |
+
mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
|
126 |
+
self.moving_avg_ep_rewards.append(mean_ep_reward)
|
127 |
+
self.moving_avg_ep_lengths.append(mean_ep_length)
|
128 |
+
self.moving_avg_ep_avg_losses.append(mean_ep_loss)
|
129 |
+
self.moving_avg_ep_avg_qs.append(mean_ep_q)
|
130 |
+
|
131 |
+
last_record_time = self.record_time
|
132 |
+
self.record_time = time.time()
|
133 |
+
time_since_last_record = np.round(self.record_time - last_record_time, 3)
|
134 |
+
|
135 |
+
print(
|
136 |
+
f"Episode {episode} - "
|
137 |
+
f"Step {step} - "
|
138 |
+
f"Epsilon {epsilon} - "
|
139 |
+
f"Mean Reward {mean_ep_reward} - "
|
140 |
+
f"Mean Length {mean_ep_length} - "
|
141 |
+
f"Mean Loss {mean_ep_loss} - "
|
142 |
+
f"Mean Q Value {mean_ep_q} - "
|
143 |
+
f"Time Delta {time_since_last_record} - "
|
144 |
+
f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
|
145 |
+
)
|
146 |
+
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
|
147 |
+
self.writer.add_scalar("Mean length last 100 episodes", mean_ep_length, episode)
|
148 |
+
self.writer.add_scalar("Mean loss last 100 episodes", mean_ep_loss, episode)
|
149 |
+
self.writer.add_scalar("Mean reward last 100 episodes", mean_ep_reward, episode)
|
150 |
+
self.writer.add_scalar("Epsilon value", epsilon, episode)
|
151 |
+
self.writer.add_scalar("Mean Q Value last 100 episodes", mean_ep_q, episode)
|
152 |
+
self.writer.flush()
|
153 |
+
with open(self.save_log, "a") as f:
|
154 |
+
f.write(
|
155 |
+
f"{episode:8d}{step:8d}{epsilon:10.3f}"
|
156 |
+
f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
|
157 |
+
f"{time_since_last_record:15.3f}"
|
158 |
+
f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
|
159 |
+
)
|
160 |
+
|
161 |
+
for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
|
162 |
+
plt.plot(getattr(self, f"moving_avg_{metric}"))
|
163 |
+
plt.savefig(getattr(self, f"{metric}_plot"))
|
164 |
+
plt.clf()
|
165 |
+
|
166 |
+
|
167 |
+
class DQNAgent:
|
168 |
+
def __init__(self,
|
169 |
+
state_dim,
|
170 |
+
action_dim,
|
171 |
+
save_dir,
|
172 |
+
checkpoint=None,
|
173 |
+
learning_rate=0.00025,
|
174 |
+
max_memory_size=100000,
|
175 |
+
batch_size=32,
|
176 |
+
exploration_rate=1,
|
177 |
+
exploration_rate_decay=0.9999999,
|
178 |
+
exploration_rate_min=0.1,
|
179 |
+
training_frequency=1,
|
180 |
+
learning_starts=1000,
|
181 |
+
target_network_sync_frequency=500,
|
182 |
+
reset_exploration_rate=False,
|
183 |
+
save_frequency=100000,
|
184 |
+
gamma=0.9,
|
185 |
+
load_replay_buffer=True):
|
186 |
+
self.state_dim = state_dim
|
187 |
+
self.action_dim = action_dim
|
188 |
+
self.max_memory_size = max_memory_size
|
189 |
+
self.memory = deque(maxlen=max_memory_size)
|
190 |
+
self.batch_size = batch_size
|
191 |
+
|
192 |
+
self.exploration_rate = exploration_rate
|
193 |
+
self.exploration_rate_decay = exploration_rate_decay
|
194 |
+
self.exploration_rate_min = exploration_rate_min
|
195 |
+
self.gamma = gamma
|
196 |
+
|
197 |
+
self.curr_step = 0
|
198 |
+
self.learning_starts = learning_starts # min. experiences before training
|
199 |
+
|
200 |
+
self.training_frequency = training_frequency # no. of experiences between updates to Q_online
|
201 |
+
self.target_network_sync_frequency = target_network_sync_frequency # no. of experiences between Q_target & Q_online sync
|
202 |
+
|
203 |
+
self.save_every = save_frequency # no. of experiences between saving the network
|
204 |
+
self.save_dir = save_dir
|
205 |
+
|
206 |
+
self.use_cuda = torch.cuda.is_available()
|
207 |
+
|
208 |
+
self.net = DQNet(self.state_dim, self.action_dim).float()
|
209 |
+
if self.use_cuda:
|
210 |
+
self.net = self.net.to(device='cuda')
|
211 |
+
if checkpoint:
|
212 |
+
self.load(checkpoint, reset_exploration_rate, load_replay_buffer)
|
213 |
+
|
214 |
+
self.optimizer = torch.optim.AdamW(self.net.parameters(), lr=learning_rate, amsgrad=True)
|
215 |
+
self.loss_fn = torch.nn.SmoothL1Loss()
|
216 |
+
|
217 |
+
|
218 |
+
def act(self, state):
|
219 |
+
"""
|
220 |
+
Given a state, choose an epsilon-greedy action and update value of step.
|
221 |
+
|
222 |
+
Inputs:
|
223 |
+
state(LazyFrame): A single observation of the current state, dimension is (state_dim)
|
224 |
+
Outputs:
|
225 |
+
action_idx (int): An integer representing which action the agent will perform
|
226 |
+
"""
|
227 |
+
# EXPLORE
|
228 |
+
if np.random.rand() < self.exploration_rate:
|
229 |
+
action_idx = np.random.randint(self.action_dim)
|
230 |
+
|
231 |
+
# EXPLOIT
|
232 |
+
else:
|
233 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
234 |
+
state = state.unsqueeze(0)
|
235 |
+
action_values = self.net(state, model='online')
|
236 |
+
action_idx = torch.argmax(action_values, axis=1).item()
|
237 |
+
|
238 |
+
# decrease exploration_rate
|
239 |
+
self.exploration_rate *= self.exploration_rate_decay
|
240 |
+
self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)
|
241 |
+
|
242 |
+
# increment step
|
243 |
+
self.curr_step += 1
|
244 |
+
return action_idx
|
245 |
+
|
246 |
+
def cache(self, state, next_state, action, reward, done):
|
247 |
+
"""
|
248 |
+
Store the experience to self.memory (replay buffer)
|
249 |
+
|
250 |
+
Inputs:
|
251 |
+
state (LazyFrame),
|
252 |
+
next_state (LazyFrame),
|
253 |
+
action (int),
|
254 |
+
reward (float),
|
255 |
+
done(bool))
|
256 |
+
"""
|
257 |
+
state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
|
258 |
+
next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
|
259 |
+
action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
|
260 |
+
reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
|
261 |
+
done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])
|
262 |
+
|
263 |
+
self.memory.append( (state, next_state, action, reward, done,) )
|
264 |
+
|
265 |
+
|
266 |
+
def recall(self):
|
267 |
+
"""
|
268 |
+
Retrieve a batch of experiences from memory
|
269 |
+
"""
|
270 |
+
batch = random.sample(self.memory, self.batch_size)
|
271 |
+
state, next_state, action, reward, done = map(torch.stack, zip(*batch))
|
272 |
+
return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()
|
273 |
+
|
274 |
+
|
275 |
+
def td_estimate(self, states, actions):
|
276 |
+
actions = actions.reshape(-1, 1)
|
277 |
+
predicted_qs = self.net(states, model='online')# Q_online(s,a)
|
278 |
+
predicted_qs = predicted_qs.gather(1, actions)
|
279 |
+
return predicted_qs
|
280 |
+
|
281 |
+
|
282 |
+
@torch.no_grad()
|
283 |
+
def td_target(self, rewards, next_states, dones):
|
284 |
+
rewards = rewards.reshape(-1, 1)
|
285 |
+
dones = dones.reshape(-1, 1)
|
286 |
+
target_qs = self.net(next_states, model='target')
|
287 |
+
target_qs = torch.max(target_qs, dim=1).values
|
288 |
+
target_qs = target_qs.reshape(-1, 1)
|
289 |
+
target_qs[dones] = 0.0
|
290 |
+
return (rewards + (self.gamma * target_qs))
|
291 |
+
|
292 |
+
def update_Q_online(self, td_estimate, td_target) :
|
293 |
+
loss = self.loss_fn(td_estimate, td_target)
|
294 |
+
self.optimizer.zero_grad()
|
295 |
+
loss.backward()
|
296 |
+
self.optimizer.step()
|
297 |
+
return loss.item()
|
298 |
+
|
299 |
+
|
300 |
+
def sync_Q_target(self):
|
301 |
+
self.net.target.load_state_dict(self.net.online.state_dict())
|
302 |
+
|
303 |
+
|
304 |
+
def learn(self):
|
305 |
+
if self.curr_step % self.target_network_sync_frequency == 0:
|
306 |
+
self.sync_Q_target()
|
307 |
+
|
308 |
+
if self.curr_step % self.save_every == 0:
|
309 |
+
self.save()
|
310 |
+
|
311 |
+
if self.curr_step < self.learning_starts:
|
312 |
+
return None, None
|
313 |
+
|
314 |
+
if self.curr_step % self.training_frequency != 0:
|
315 |
+
return None, None
|
316 |
+
|
317 |
+
# Sample from memory
|
318 |
+
state, next_state, action, reward, done = self.recall()
|
319 |
+
|
320 |
+
# Get TD Estimate
|
321 |
+
td_est = self.td_estimate(state, action)
|
322 |
+
|
323 |
+
# Get TD Target
|
324 |
+
td_tgt = self.td_target(reward, next_state, done)
|
325 |
+
|
326 |
+
# Backpropagate loss through Q_online
|
327 |
+
loss = self.update_Q_online(td_est, td_tgt)
|
328 |
+
|
329 |
+
return (td_est.mean().item(), loss)
|
330 |
+
|
331 |
+
|
332 |
+
def save(self):
|
333 |
+
save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
|
334 |
+
torch.save(
|
335 |
+
dict(
|
336 |
+
model=self.net.state_dict(),
|
337 |
+
exploration_rate=self.exploration_rate,
|
338 |
+
replay_memory=self.memory
|
339 |
+
),
|
340 |
+
save_path
|
341 |
+
)
|
342 |
+
|
343 |
+
print(f"Airstriker model saved to {save_path} at step {self.curr_step}")
|
344 |
+
|
345 |
+
|
346 |
+
def load(self, load_path, reset_exploration_rate, load_replay_buffer):
|
347 |
+
if not load_path.exists():
|
348 |
+
raise ValueError(f"{load_path} does not exist")
|
349 |
+
|
350 |
+
ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
|
351 |
+
exploration_rate = ckp.get('exploration_rate')
|
352 |
+
state_dict = ckp.get('model')
|
353 |
+
|
354 |
+
|
355 |
+
print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
|
356 |
+
self.net.load_state_dict(state_dict)
|
357 |
+
|
358 |
+
if load_replay_buffer:
|
359 |
+
replay_memory = ckp.get('replay_memory')
|
360 |
+
print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
|
361 |
+
self.memory = replay_memory if replay_memory else self.memory
|
362 |
+
|
363 |
+
if reset_exploration_rate:
|
364 |
+
print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
|
365 |
+
else:
|
366 |
+
print(f"Setting exploration rate to {exploration_rate} not loaded.")
|
367 |
+
self.exploration_rate = exploration_rate
|
368 |
+
|
369 |
+
|
370 |
+
class DDQNAgent(DQNAgent):
|
371 |
+
@torch.no_grad()
|
372 |
+
def td_target(self, rewards, next_states, dones):
|
373 |
+
rewards = rewards.reshape(-1, 1)
|
374 |
+
dones = dones.reshape(-1, 1)
|
375 |
+
q_vals = self.net(next_states, model='online')
|
376 |
+
target_actions = torch.argmax(q_vals, axis=1)
|
377 |
+
target_actions = target_actions.reshape(-1, 1)
|
378 |
+
|
379 |
+
target_qs = self.net(next_states, model='target')
|
380 |
+
target_qs = target_qs.gather(1, target_actions)
|
381 |
+
target_qs = target_qs.reshape(-1, 1)
|
382 |
+
target_qs[dones] = 0.0
|
383 |
+
return (rewards + (self.gamma * target_qs))
|
384 |
+
|
385 |
+
|
386 |
+
class DuelingDQNet(nn.Module):
|
387 |
+
def __init__(self, input_dim, output_dim):
|
388 |
+
super().__init__()
|
389 |
+
print("#################################")
|
390 |
+
print("#################################")
|
391 |
+
print(input_dim)
|
392 |
+
print(output_dim)
|
393 |
+
print("#################################")
|
394 |
+
print("#################################")
|
395 |
+
c, h, w = input_dim
|
396 |
+
|
397 |
+
|
398 |
+
self.conv_layer = nn.Sequential(
|
399 |
+
nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
|
400 |
+
nn.ReLU(),
|
401 |
+
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
|
402 |
+
nn.ReLU(),
|
403 |
+
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
|
404 |
+
nn.ReLU(),
|
405 |
+
|
406 |
+
)
|
407 |
+
|
408 |
+
|
409 |
+
self.value_layer = nn.Sequential(
|
410 |
+
nn.Linear(7168, 128),
|
411 |
+
nn.ReLU(),
|
412 |
+
nn.Linear(128, 1)
|
413 |
+
)
|
414 |
+
|
415 |
+
self.advantage_layer = nn.Sequential(
|
416 |
+
nn.Linear(7168, 128),
|
417 |
+
nn.ReLU(),
|
418 |
+
nn.Linear(128, output_dim)
|
419 |
+
)
|
420 |
+
|
421 |
+
def forward(self, state):
|
422 |
+
conv_output = self.conv_layer(state)
|
423 |
+
conv_output = conv_output.view(conv_output.size(0), -1)
|
424 |
+
value = self.value_layer(conv_output)
|
425 |
+
advantage = self.advantage_layer(conv_output)
|
426 |
+
q_value = value + (advantage - advantage.mean())
|
427 |
+
|
428 |
+
return q_value
|
429 |
+
|
430 |
+
|
431 |
+
class DuelingDQNAgent:
    def __init__(self,
                 state_dim,
                 action_dim,
                 save_dir,
                 checkpoint=None,
                 learning_rate=0.00025,
                 max_memory_size=100000,
                 batch_size=32,
                 exploration_rate=1,
                 exploration_rate_decay=0.9999999,
                 exploration_rate_min=0.1,
                 training_frequency=1,
                 learning_starts=1000,
                 target_network_sync_frequency=500,
                 reset_exploration_rate=False,
                 save_frequency=100000,
                 gamma=0.9,
                 load_replay_buffer=True):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.max_memory_size = max_memory_size
        self.memory = deque(maxlen=max_memory_size)
        self.batch_size = batch_size

        self.exploration_rate = exploration_rate
        self.exploration_rate_decay = exploration_rate_decay
        self.exploration_rate_min = exploration_rate_min
        self.gamma = gamma

        self.curr_step = 0
        self.learning_starts = learning_starts  # min. experiences before training

        self.training_frequency = training_frequency  # no. of experiences between updates to Q_online
        self.target_network_sync_frequency = target_network_sync_frequency  # no. of experiences between Q_target & Q_online sync

        self.save_every = save_frequency  # no. of experiences between saving the network
        self.save_dir = save_dir

        self.use_cuda = torch.cuda.is_available()

        self.online_net = DuelingDQNet(self.state_dim, self.action_dim).float()
        self.target_net = copy.deepcopy(self.online_net)
        # Q_target parameters are frozen.
        for p in self.target_net.parameters():
            p.requires_grad = False

        if self.use_cuda:
            self.online_net = self.online_net.to(device='cuda')
            self.target_net = self.target_net.to(device='cuda')
        if checkpoint:
            self.load(checkpoint, reset_exploration_rate, load_replay_buffer)

        self.optimizer = torch.optim.AdamW(self.online_net.parameters(), lr=learning_rate, amsgrad=True)
        self.loss_fn = torch.nn.SmoothL1Loss()

    def act(self, state):
        """
        Given a state, choose an epsilon-greedy action and update value of step.

        Inputs:
            state (LazyFrame): A single observation of the current state, dimension is (state_dim)
        Outputs:
            action_idx (int): An integer representing which action the agent will perform
        """
        # EXPLORE
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # EXPLOIT
        else:
            state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
            state = state.unsqueeze(0)
            action_values = self.online_net(state)
            action_idx = torch.argmax(action_values, axis=1).item()

        # decrease exploration_rate
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        # increment step
        self.curr_step += 1
        return action_idx

    def cache(self, state, next_state, action, reward, done):
        """
        Store the experience to self.memory (replay buffer)

        Inputs:
            state (LazyFrame),
            next_state (LazyFrame),
            action (int),
            reward (float),
            done (bool)
        """
        state = torch.FloatTensor(state).cuda() if self.use_cuda else torch.FloatTensor(state)
        next_state = torch.FloatTensor(next_state).cuda() if self.use_cuda else torch.FloatTensor(next_state)
        action = torch.LongTensor([action]).cuda() if self.use_cuda else torch.LongTensor([action])
        reward = torch.DoubleTensor([reward]).cuda() if self.use_cuda else torch.DoubleTensor([reward])
        done = torch.BoolTensor([done]).cuda() if self.use_cuda else torch.BoolTensor([done])

        self.memory.append((state, next_state, action, reward, done))

    def recall(self):
        """
        Retrieve a batch of experiences from memory
        """
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()

    def td_estimate(self, states, actions):
        actions = actions.reshape(-1, 1)
        predicted_qs = self.online_net(states)  # Q_online(s, a)
        predicted_qs = predicted_qs.gather(1, actions)
        return predicted_qs

    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        # Vanilla DQN target: the target network both selects and evaluates the action.
        target_qs = self.target_net(next_states)
        target_qs = torch.max(target_qs, dim=1).values
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        return (rewards + (self.gamma * target_qs))

    def update_Q_online(self, td_estimate, td_target):
        loss = self.loss_fn(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    def learn(self):
        if self.curr_step % self.target_network_sync_frequency == 0:
            self.sync_Q_target()

        if self.curr_step % self.save_every == 0:
            self.save()

        if self.curr_step < self.learning_starts:
            return None, None

        if self.curr_step % self.training_frequency != 0:
            return None, None

        # Sample from memory
        state, next_state, action, reward, done = self.recall()

        # Get TD Estimate
        td_est = self.td_estimate(state, action)

        # Get TD Target
        td_tgt = self.td_target(reward, next_state, done)

        # Backpropagate loss through Q_online
        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self):
        save_path = self.save_dir / f"airstriker_net_{int(self.curr_step // self.save_every)}.chkpt"
        torch.save(
            dict(
                model=self.online_net.state_dict(),
                exploration_rate=self.exploration_rate,
                replay_memory=self.memory
            ),
            save_path
        )

        print(f"Model saved to {save_path} at step {self.curr_step}")

    def load(self, load_path, reset_exploration_rate, load_replay_buffer):
        if not load_path.exists():
            raise ValueError(f"{load_path} does not exist")

        ckp = torch.load(load_path, map_location=('cuda' if self.use_cuda else 'cpu'))
        exploration_rate = ckp.get('exploration_rate')
        state_dict = ckp.get('model')

        print(f"Loading model at {load_path} with exploration rate {exploration_rate}")
        self.online_net.load_state_dict(state_dict)
        self.target_net = copy.deepcopy(self.online_net)
        self.sync_Q_target()

        if load_replay_buffer:
            replay_memory = ckp.get('replay_memory')
            print(f"Loading replay memory. Len {len(replay_memory)}" if replay_memory else "Saved replay memory not found. Not restoring replay memory.")
            self.memory = replay_memory if replay_memory else self.memory

        if reset_exploration_rate:
            print(f"Reset exploration rate option specified. Not restoring saved exploration rate {exploration_rate}. The current exploration rate is {self.exploration_rate}")
        else:
            print(f"Restoring saved exploration rate {exploration_rate}.")
            self.exploration_rate = exploration_rate


class DuelingDDQNAgent(DuelingDQNAgent):
    @torch.no_grad()
    def td_target(self, rewards, next_states, dones):
        rewards = rewards.reshape(-1, 1)
        dones = dones.reshape(-1, 1)
        # Double DQN target: select the action with the online network...
        q_vals = self.online_net(next_states)
        target_actions = torch.argmax(q_vals, axis=1)
        target_actions = target_actions.reshape(-1, 1)

        # ...and evaluate it with the target network.
        target_qs = self.target_net(next_states)
        target_qs = target_qs.gather(1, target_actions)
        target_qs = target_qs.reshape(-1, 1)
        target_qs[dones] = 0.0
        return (rewards + (self.gamma * target_qs))

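The four target computations above differ only in which network selects the greedy action and in how the Q-values are produced. Below is a minimal, self-contained sketch (not part of the repository; toy linear layers stand in for the Q-networks) contrasting the vanilla DQN target with the double-DQN target:

```python
import torch
import torch.nn as nn

# Toy stand-ins for the online/target networks used by the agents above.
online_net = nn.Linear(8, 4)
target_net = nn.Linear(8, 4)

next_states = torch.randn(32, 8)
rewards = torch.randn(32, 1)
dones = torch.zeros(32, 1, dtype=torch.bool)
gamma = 0.9

with torch.no_grad():
    # Vanilla DQN: the target net both selects and evaluates the action.
    vanilla_target = rewards + gamma * target_net(next_states).max(dim=1, keepdim=True).values

    # Double DQN: the online net selects the action, the target net evaluates it.
    best_actions = online_net(next_states).argmax(dim=1, keepdim=True)
    double_target = rewards + gamma * target_net(next_states).gather(1, best_actions)

# Terminal transitions contribute only the reward, mirroring target_qs[dones] = 0.0 above.
vanilla_target[dones] = rewards[dones]
double_target[dones] = rewards[dones]
```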
src/procgen/run-starpilot-ddqn.py
ADDED
@@ -0,0 +1,45 @@
import os
import torch
from pathlib import Path

from agent import DDQNAgent, MetricLogger
from wrappers import make_starpilot
from train import train, fill_memory


env = make_starpilot()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/procgen-starpilot-ddqn"
save_dir = Path(path)

os.makedirs(path, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training DDQN Agent!")
agent = DDQNAgent(
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

fill_memory(agent, env, 300)
train(agent, env, logger)
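To resume training instead of starting from scratch, `checkpoint` can point at one of the `.chkpt` files the agent writes into `save_dir`; the file name below is a hypothetical example:

```python
checkpoint = Path('checkpoints/procgen-starpilot-ddqn/airstriker_net_1.chkpt')  # hypothetical file name
```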
src/procgen/run-starpilot-dqn.py
ADDED
@@ -0,0 +1,45 @@
import os
import torch
from pathlib import Path

from agent import DQNAgent, MetricLogger
from wrappers import make_starpilot
from train import train, fill_memory


env = make_starpilot()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/procgen-starpilot-dqn"
save_dir = Path(path)

os.makedirs(path, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training Vanilla DQN Agent!")
agent = DQNAgent(
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

fill_memory(agent, env, 300)
train(agent, env, logger)
src/procgen/run-starpilot-dueling-ddqn.py
ADDED
@@ -0,0 +1,45 @@
import os
import torch
from pathlib import Path

from agent import DuelingDDQNAgent, MetricLogger
from wrappers import make_starpilot
from train import train, fill_memory


env = make_starpilot()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/procgen-starpilot-dueling-ddqn"
save_dir = Path(path)

os.makedirs(path, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training Dueling Double DQN Agent!")
agent = DuelingDDQNAgent(
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

# fill_memory(agent, env, 300)
train(agent, env, logger)
src/procgen/run-starpilot-dueling-dqn.py
ADDED
@@ -0,0 +1,45 @@
import os
import torch
from pathlib import Path

from agent import DuelingDQNAgent, MetricLogger
from wrappers import make_starpilot
from train import train, fill_memory


env = make_starpilot()

use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}\n")

checkpoint = None
# checkpoint = Path('checkpoints/latest/airstriker_net_3.chkpt')

path = "checkpoints/procgen-starpilot-dueling-dqn"
save_dir = Path(path)

os.makedirs(path, exist_ok=True)

logger = MetricLogger(save_dir)

print("Training Dueling DQN Agent!")
agent = DuelingDQNAgent(
    state_dim=(1, 64, 64),
    action_dim=env.action_space.n,
    save_dir=save_dir,
    batch_size=256,
    checkpoint=checkpoint,
    exploration_rate_decay=0.999995,
    exploration_rate_min=0.05,
    training_frequency=1,
    target_network_sync_frequency=200,
    max_memory_size=50000,
    learning_rate=0.0005,
)

# fill_memory(agent, env, 300)
train(agent, env, logger)
src/procgen/test-procgen.py
ADDED
@@ -0,0 +1,12 @@
import gym

env = gym.make("procgen:procgen-starpilot-v0")

obs = env.reset()
step = 0
while True:
    obs, rew, done, info = env.step(env.action_space.sample())
    print(info)
    print(f"step {step} reward {rew} done {done}")
    step += 1
    if done:
        break
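To watch the random agent rather than just print transitions, the environment can be created with the same human render mode that `make_starpilot(render=True)` uses in `wrappers.py`:

```python
import gym

env = gym.make("procgen:procgen-starpilot-v0", render_mode="human")
```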
src/procgen/train.py
ADDED
@@ -0,0 +1,48 @@
from tqdm import trange


def fill_memory(agent, env, num_episodes=500):
    print("Filling up memory....")
    for _ in trange(num_episodes):
        state = env.reset()
        done = False
        while not done:
            # Note: act() also decays epsilon and advances agent.curr_step.
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            agent.cache(state, next_state, action, reward, done)
            state = next_state


def train(agent, env, logger):
    episodes = 5000
    for e in range(episodes):

        state = env.reset()
        # Play the game!
        while True:

            # Run agent on the state
            action = agent.act(state)

            # Agent performs action
            next_state, reward, done, info = env.step(action)

            # Remember
            agent.cache(state, next_state, action, reward, done)

            # Learn
            q, loss = agent.learn()

            # Logging
            logger.log_step(reward, loss, q)

            # Update state
            state = next_state

            # Check if end of game
            if done:
                break

        logger.log_episode(e)

        if e % 20 == 0:
            logger.record(episode=e, epsilon=agent.exploration_rate, step=agent.curr_step)
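train() always acts epsilon-greedily through agent.act(). A rough evaluation loop, assuming an agent that exposes exploration_rate / exploration_rate_min like the classes in agent.py (a sketch, not part of the repository; note that act() still decays epsilon and advances curr_step while evaluating):

```python
def evaluate(agent, env, num_episodes=10):
    # Temporarily drive exploration down to its floor so actions are (mostly) greedy.
    saved_rate = agent.exploration_rate
    agent.exploration_rate = agent.exploration_rate_min
    returns = []
    for _ in range(num_episodes):
        state = env.reset()
        done, episode_return = False, 0.0
        while not done:
            action = agent.act(state)
            state, reward, done, _ = env.step(action)
            episode_return += reward
        returns.append(episode_return)
    agent.exploration_rate = saved_rate  # restore the training epsilon
    return sum(returns) / len(returns)
```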
src/procgen/wrappers.py
ADDED
@@ -0,0 +1,187 @@
import numpy as np
import os
from collections import deque
import gym
from gym import spaces
import cv2


'''
Atari Wrapper copied from https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
'''


class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.
        This object should only be converted to numpy array before being passed to the model."""
        self._frames = frames
        self._out = None

    def _force(self):
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]


class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset(**kwargs)
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset(**kwargs)
        return obs

    def step(self, ac):
        return self.env.step(ac)


class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            if i == self._skip - 2: self._obs_buffer[0] = obs
            if i == self._skip - 1: self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(self.height, self.width, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]


class WarpFrameNoResize(gym.ObservationWrapper):
    def __init__(self, env):
        """Convert frames to grayscale without resizing them."""
        gym.ObservationWrapper.__init__(self, env)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        # frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.
        Returns lazy array, which is much more memory efficient.
        See Also
        --------
        baselines.common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=env.observation_space.dtype)

    def reset(self):
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))


class ImageToPyTorch(gym.ObservationWrapper):
    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_shape = self.observation_space.shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(old_shape[-1], old_shape[0], old_shape[1]), dtype=np.float32)

    def observation(self, observation):
        return np.moveaxis(observation, 2, 0)


class ScaledFloatFrame(gym.ObservationWrapper):
    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=env.observation_space.shape, dtype=np.float32)

    def observation(self, observation):
        # careful! This undoes the memory optimization, use
        # with smaller replay buffers only.
        return np.array(observation).astype(np.float32) / 255.0


class ClipRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)


def make_starpilot(render=False):
    print("Environment: Starpilot")
    if render:
        env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy", render_mode="human")
    else:
        env = gym.make("procgen:procgen-starpilot-v0", distribution_mode="easy")
    env = WarpFrameNoResize(env)  ## Convert frame to grayscale (no resize)
    env = ImageToPyTorch(env)     ## Move channels first (HWC -> CHW)
    env = FrameStack(env, 4)      ## Stack last 4 frames
    return env
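A quick way to sanity-check what make_starpilot() feeds the networks in agent.py (a sketch; LazyFrames only materialises once converted to an array, and the final shape follows from ImageToPyTorch running before FrameStack):

```python
import numpy as np
from wrappers import make_starpilot

env = make_starpilot()
obs = env.reset()
# Print the stacked observation shape and the size of the discrete action space.
print(np.array(obs).shape, env.action_space.n)
```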
troubleshooting.md
ADDED
@@ -0,0 +1,37 @@
# ml-reinforcement-learning

Python version: 3.7.3


## Troubleshooting

- RuntimeError: Polyfit sanity test emitted a warning, most likely due to using a buggy Accelerate backend. If you compiled yourself, more information is available at https://numpy.org/doc/stable/user/building.html#accelerated-blas-lapack-libraries Otherwise report this to the vendor that provided NumPy.
  RankWarning: Polyfit may be poorly conditioned

```
$ pip uninstall numpy
$ export OPENBLAS=$(brew --prefix openblas)
$ pip install --no-cache-dir numpy
```

- During grpcio installation 👇
  distutils.errors.CompileError: command 'clang' failed with exit status 1

```
CFLAGS="-I/Library/Developer/CommandLineTools/usr/include/c++/v1 -I/opt/homebrew/opt/openssl/include" LDFLAGS="-L/opt/homebrew/opt/openssl/lib" pip3 install grpcio
```

- ModuleNotFoundError: No module named 'gym.envs.classic_control.rendering'
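  This import appears to come from running a newer gym release than the project targets; pinning gym to the version used elsewhere in this repo (a suggested fix, not verified against every gym version) typically resolves it:

```
pip install "gym[atari]==0.21.0"
```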

## Setup

```
conda install pytorch torchvision -c pytorch
pip install gym-retro
conda install numpy
pip install "gym[atari]==0.21.0"
pip install importlib-metadata==4.13.0
```