BaljinderH committed on
Commit
880396d
·
verified ·
1 Parent(s): d467399

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +8 -8
train.py CHANGED
@@ -6,10 +6,10 @@ from callbacks import SaveFramesCallback
6
  import os
7
 
8
  def main():
9
- # Create the environment
10
  env = TetrisEnv()
11
 
12
- # Initialize the RL model (DQN)
13
  model = DQN('MlpPolicy', env, verbose=1,
14
  learning_rate=1e-3,
15
  buffer_size=50000,
@@ -20,21 +20,21 @@ def main():
20
  exploration_fraction=0.1,
21
  exploration_final_eps=0.02)
22
 
23
- # Define the number of training timesteps
24
- TIMESTEPS = 550000 # Adjust as needed
25
 
26
- # Initialize the callback
27
  callback = SaveFramesCallback(save_freq=5000, save_path="models/frames", verbose=1)
28
 
29
- # Train the model with the callback
30
  model.learn(total_timesteps=TIMESTEPS, callback=callback)
31
 
32
- # Save the model
33
  os.makedirs("models", exist_ok=True)
34
  model.save("models/dqn_tetris")
35
  print("Model saved to models/dqn_tetris.zip")
36
 
37
- # Evaluate the trained agent
38
  mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
39
  print(f"Mean Reward: {mean_reward} +/- {std_reward}")
40
 
 
6
  import os
7
 
8
  def main():
9
+
10
  env = TetrisEnv()
11
 
12
+
13
  model = DQN('MlpPolicy', env, verbose=1,
14
  learning_rate=1e-3,
15
  buffer_size=50000,
 
20
  exploration_fraction=0.1,
21
  exploration_final_eps=0.02)
22
 
23
+
24
+ TIMESTEPS = 550000
25
 
26
+
27
  callback = SaveFramesCallback(save_freq=5000, save_path="models/frames", verbose=1)
28
 
29
+
30
  model.learn(total_timesteps=TIMESTEPS, callback=callback)
31
 
32
+
33
  os.makedirs("models", exist_ok=True)
34
  model.save("models/dqn_tetris")
35
  print("Model saved to models/dqn_tetris.zip")
36
 
37
+
38
  mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
39
  print(f"Mean Reward: {mean_reward} +/- {std_reward}")
40