Spaces:
Sleeping
Sleeping
Update train.py
Browse files
train.py
CHANGED
@@ -6,10 +6,10 @@ from callbacks import SaveFramesCallback
|
|
6 |
import os
|
7 |
|
8 |
def main():
|
9 |
-
|
10 |
env = TetrisEnv()
|
11 |
|
12 |
-
|
13 |
model = DQN('MlpPolicy', env, verbose=1,
|
14 |
learning_rate=1e-3,
|
15 |
buffer_size=50000,
|
@@ -20,21 +20,21 @@ def main():
|
|
20 |
exploration_fraction=0.1,
|
21 |
exploration_final_eps=0.02)
|
22 |
|
23 |
-
|
24 |
-
TIMESTEPS = 550000
|
25 |
|
26 |
-
|
27 |
callback = SaveFramesCallback(save_freq=5000, save_path="models/frames", verbose=1)
|
28 |
|
29 |
-
|
30 |
model.learn(total_timesteps=TIMESTEPS, callback=callback)
|
31 |
|
32 |
-
|
33 |
os.makedirs("models", exist_ok=True)
|
34 |
model.save("models/dqn_tetris")
|
35 |
print("Model saved to models/dqn_tetris.zip")
|
36 |
|
37 |
-
|
38 |
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
|
39 |
print(f"Mean Reward: {mean_reward} +/- {std_reward}")
|
40 |
|
|
|
6 |
import os
|
7 |
|
8 |
def main():
|
9 |
+
|
10 |
env = TetrisEnv()
|
11 |
|
12 |
+
|
13 |
model = DQN('MlpPolicy', env, verbose=1,
|
14 |
learning_rate=1e-3,
|
15 |
buffer_size=50000,
|
|
|
20 |
exploration_fraction=0.1,
|
21 |
exploration_final_eps=0.02)
|
22 |
|
23 |
+
|
24 |
+
TIMESTEPS = 550000
|
25 |
|
26 |
+
|
27 |
callback = SaveFramesCallback(save_freq=5000, save_path="models/frames", verbose=1)
|
28 |
|
29 |
+
|
30 |
model.learn(total_timesteps=TIMESTEPS, callback=callback)
|
31 |
|
32 |
+
|
33 |
os.makedirs("models", exist_ok=True)
|
34 |
model.save("models/dqn_tetris")
|
35 |
print("Model saved to models/dqn_tetris.zip")
|
36 |
|
37 |
+
|
38 |
mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
|
39 |
print(f"Mean Reward: {mean_reward} +/- {std_reward}")
|
40 |
|