Upload 8 files
- callbacks.py +32 -0
- compile_video.py +16 -0
- evaluate.py +41 -0
- push_to_hub.py +39 -0
- requirements.txt +8 -0
- sandtris.py +123 -0
- tetris_env.py +210 -0
- train.py +42 -0
callbacks.py
ADDED
@@ -0,0 +1,32 @@
import os
from stable_baselines3.common.callbacks import BaseCallback
import imageio
import numpy as np

class SaveFramesCallback(BaseCallback):
    """
    Callback for saving frames during training.
    """
    def __init__(self, save_freq, save_path, verbose=0):
        super(SaveFramesCallback, self).__init__(verbose)
        self.save_freq = save_freq
        self.save_path = save_path
        self.frames = []
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self) -> bool:
        if self.num_timesteps % self.save_freq == 0:
            # Render the environment and get the RGB array
            frame = self.training_env.render(mode='rgb_array')
            self.frames.append(frame)
            if self.verbose > 0:
                print(f"Saved frame at timestep {self.num_timesteps}")
        return True

    def _on_training_end(self) -> None:
        # Save the frames as a GIF
        if self.frames:
            gif_path = os.path.join(self.save_path, "training.gif")
            imageio.mimsave(gif_path, self.frames, fps=10)
            if self.verbose > 0:
                print(f"Saved training GIF to {gif_path}")

compile_video.py
ADDED
@@ -0,0 +1,16 @@
import imageio
import os

def compile_gif_to_video(gif_path, video_path):
    reader = imageio.get_reader(gif_path)
    fps = 10
    # Note: writing MP4 output with imageio typically requires the imageio-ffmpeg backend
    writer = imageio.get_writer(video_path, fps=fps)
    for frame in reader:
        writer.append_data(frame)
    writer.close()
    print(f"Video saved to {video_path}")

if __name__ == "__main__":
    gif_path = "models/frames/training.gif"
    video_path = "models/training.mp4"
    compile_gif_to_video(gif_path, video_path)

evaluate.py
ADDED
@@ -0,0 +1,41 @@
import gym
from stable_baselines3 import DQN
from tetris_env import TetrisEnv
import pygame
import time

def main():
    # Create the environment
    env = TetrisEnv()

    # Load the trained model
    model = DQN.load("models/dqn_tetris")

    # Number of evaluation episodes
    episodes = 5

    for ep in range(1, episodes + 1):
        obs = env.reset()
        done = False
        total_reward = 0
        while not done:
            # Render the game (optional)
            env.render(mode='human')

            # Predict the action using the trained model
            action, _states = model.predict(obs, deterministic=True)

            # Take the action in the environment
            obs, reward, done, info = env.step(action)

            total_reward += reward

            # Control the rendering speed
            pygame.time.wait(100)  # Wait 100 ms between steps

        print(f"Episode {ep}: Total Reward = {total_reward}")

    env.close()

if __name__ == "__main__":
    main()

push_to_hub.py
ADDED
@@ -0,0 +1,39 @@
from stable_baselines3 import DQN
from huggingface_hub import HfApi, HfFolder, Repository
import os

def main():
    # Define repository details
    repo_name = "dqn-tetris"
    model_path = "models/dqn_tetris.zip"
    model_dir = "models"

    # Load the trained model
    model = DQN.load(model_path)

    # Initialize Hugging Face API
    api = HfApi()
    user = api.whoami()["name"]

    # Create the repository if it doesn't exist
    try:
        api.create_repo(repo_id=repo_name, repo_type="model", exist_ok=True)
        print(f"Repository '{repo_name}' created.")
    except Exception as e:
        print(f"Repository '{repo_name}' already exists or failed to create: {e}")

    # Clone the repository locally
    repo = Repository(local_dir=repo_name, clone_from=f"{user}/{repo_name}", use_auth_token=True)

    # Move the model file into the repository directory
    os.makedirs(repo_name, exist_ok=True)
    os.rename(model_path, os.path.join(repo_name, "dqn_tetris.zip"))

    # Add and commit the model
    repo.git_add(auto_lfs_track=True)
    repo.git_commit("Add trained DQN Tetris model")
    repo.git_push()
    print(f"Model pushed to Hugging Face Hub at {user}/{repo_name}")

if __name__ == "__main__":
    main()

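For comparison, the same upload can be done without a local git clone; the following is a minimal sketch, assuming a reasonably recent huggingface_hub, reusing the repo and file names from push_to_hub.py above:

from huggingface_hub import HfApi

api = HfApi()
user = api.whoami()["name"]
repo_id = f"{user}/dqn-tetris"

# Create the repo if it does not exist, then upload the zipped SB3 model directly
api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
api.upload_file(
    path_or_fileobj="models/dqn_tetris.zip",
    path_in_repo="dqn_tetris.zip",
    repo_id=repo_id,
    repo_type="model",
)
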
requirements.txt
ADDED
@@ -0,0 +1,8 @@
pygame
gym
stable-baselines3
numpy
torch
huggingface-hub
imageio
gradio

sandtris.py
ADDED
@@ -0,0 +1,123 @@
import pygame
import random

colors = [
    (0, 0, 0),
    (120, 37, 179),
    (100, 179, 179),
    (80, 34, 22),
    (80, 134, 22),
    (180, 34, 22),
    (180, 34, 122),
]

class Figure:
    x = 0
    y = 0

    figures = [
        [[1, 5, 9, 13], [4, 5, 6, 7]],
        [[4, 5, 9, 10], [2, 6, 5, 9]],
        [[6, 7, 9, 10], [1, 5, 6, 10]],
        [[1, 2, 5, 9], [0, 4, 5, 6], [1, 5, 9, 8], [4, 5, 6, 10]],
        [[1, 2, 6, 10], [5, 6, 7, 9], [2, 6, 10, 11], [3, 5, 6, 7]],
        [[1, 4, 5, 6], [1, 4, 5, 9], [4, 5, 6, 9], [1, 5, 6, 9]],
        [[1, 2, 5, 6]],
    ]

    def __init__(self, x, y):
        self.x = x
        self.y = y
        self.type = random.randint(0, len(self.figures) - 1)
        self.color = random.randint(1, len(colors) - 1)
        self.rotation = 0

    def image(self):
        return self.figures[self.type][self.rotation]

    def rotate(self):
        self.rotation = (self.rotation + 1) % len(self.figures[self.type])

class Tetris:
    def __init__(self, height, width):
        self.level = 2
        self.score = 0
        self.state = "start"
        self.field = []
        self.height = height
        self.width = width
        self.x = 100
        self.y = 60
        self.zoom = 20
        self.figure = None

        for i in range(height):
            new_line = []
            for j in range(width):
                new_line.append(0)
            self.field.append(new_line)

    def new_figure(self):
        self.figure = Figure(3, 0)

    def intersects(self):
        intersection = False
        for i in range(4):
            for j in range(4):
                if i * 4 + j in self.figure.image():
                    if (
                        i + self.figure.y > self.height - 1
                        or j + self.figure.x > self.width - 1
                        or j + self.figure.x < 0
                        or self.field[i + self.figure.y][j + self.figure.x] > 0
                    ):
                        intersection = True
        return intersection

    def break_lines(self):
        lines = 0
        for i in range(1, self.height):
            zeros = 0
            for j in range(self.width):
                if self.field[i][j] == 0:
                    zeros += 1
            if zeros == 0:
                lines += 1
                for i1 in range(i, 1, -1):
                    for j in range(self.width):
                        self.field[i1][j] = self.field[i1 - 1][j]
        self.score += lines ** 2

    def go_space(self):
        while not self.intersects():
            self.figure.y += 1
        self.figure.y -= 1
        self.freeze()

    def go_down(self):
        self.figure.y += 1
        if self.intersects():
            self.figure.y -= 1
            self.freeze()

    def freeze(self):
        for i in range(4):
            for j in range(4):
                if i * 4 + j in self.figure.image():
                    self.field[i + self.figure.y][j + self.figure.x] = self.figure.color
        self.break_lines()
        self.new_figure()
        if self.intersects():
            self.state = "gameover"

    def go_side(self, dx):
        old_x = self.figure.x
        self.figure.x += dx
        if self.intersects():
            self.figure.x = old_x

    def rotate(self):
        old_rotation = self.figure.rotation
        self.figure.rotate()
        if self.intersects():
            self.figure.rotation = old_rotation

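As a quick sanity check of the game logic, the Tetris class above can be driven directly, without pygame rendering or the Gym wrapper. A minimal sketch (random horizontal nudges followed by a hard drop per piece):

import random
from sandtris import Tetris

game = Tetris(20, 10)
game.new_figure()
while game.state != "gameover":
    game.go_side(random.randint(-1, 1))  # nudge the piece left or right
    game.go_space()                      # hard drop; freeze() spawns the next piece
print("Final score:", game.score)
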
tetris_env.py
ADDED
@@ -0,0 +1,210 @@
import gym
from gym import spaces
import numpy as np
import pygame
import random
from sandtris import Tetris  # Ensure sandtris.py is in the same directory
import os

colors = [
    (0, 0, 0),
    (120, 37, 179),
    (100, 179, 179),
    (80, 34, 22),
    (80, 134, 22),
    (180, 34, 22),
    (180, 34, 122),
]

class TetrisEnv(gym.Env):
    """
    Custom environment for the Tetris game, compatible with OpenAI Gym
    """
    metadata = {'render.modes': ['human', 'rgb_array']}

    def __init__(self):
        super(TetrisEnv, self).__init__()

        # Define action space: 0=left, 1=right, 2=rotate, 3=drop, 4=noop
        self.action_space = spaces.Discrete(5)

        # Observation space: 2D grid representing the game board
        self.height = 20
        self.width = 10
        self.observation_space = spaces.Box(low=0, high=6,
                                            shape=(self.height, self.width), dtype=np.int32)

        # Initialize the game
        self.game = Tetris(self.height, self.width)

        # Setup for rendering
        self.screen = None
        self.zoom = 20
        self.x = 100
        self.y = 60

    def reset(self):
        """
        Reset the game to its initial state
        """
        self.game = Tetris(self.height, self.width)
        self.game.new_figure()
        return self._get_obs()

    def step(self, action):
        """
        Execute one time step within the environment
        """
        done = False
        reward = 0

        # Apply action
        if action == 0:
            self.game.go_side(-1)  # Move left
        elif action == 1:
            self.game.go_side(1)  # Move right
        elif action == 2:
            self.game.rotate()  # Rotate
        elif action == 3:
            self.game.go_space()  # Drop
        elif action == 4:
            pass  # No operation

        # Move the piece down automatically
        self.game.go_down()

        # Calculate reward
        lines_cleared = self.game.score  # Assuming score increments with lines cleared
        reward += lines_cleared * 10

        # Additional reward shaping
        aggregate_height = self.calculate_aggregate_height()
        holes = self.calculate_holes()
        bumpiness = self.calculate_bumpiness()

        reward -= (aggregate_height * 0.5 + holes * 0.7 + bumpiness * 0.3)

        if self.game.state == "gameover":
            done = True
            reward -= 10  # Penalty for losing

        return self._get_obs(), reward, done, {}

    def _get_obs(self):
        """
        Get the current state of the game as an observation
        """
        return np.array(self.game.field)

    def render(self, mode='human'):
        """
        Render the game state as an RGB array or display it on screen
        """
        if mode == 'rgb_array':
            if self.screen is None:
                pygame.init()
                size = (self.x * 2 + self.zoom * self.width, self.y * 2 + self.zoom * self.height)
                self.screen = pygame.Surface(size)

            self.screen.fill((173, 216, 230))  # Light blue background

            # Draw the game field
            for i in range(self.game.height):
                for j in range(self.game.width):
                    rect = pygame.Rect(self.x + self.zoom * j, self.y + self.zoom * i, self.zoom, self.zoom)
                    pygame.draw.rect(self.screen, (128, 128, 128), rect, 1)  # Grid lines
                    if self.game.field[i][j] > 0:
                        pygame.draw.rect(self.screen,
                                         colors[self.game.field[i][j]],
                                         rect.inflate(-2, -2))

            # Draw the current figure
            if self.game.figure is not None:
                for i in range(4):
                    for j in range(4):
                        p = i * 4 + j
                        if p in self.game.figure.image():
                            rect = pygame.Rect(self.x + self.zoom * (j + self.game.figure.x),
                                               self.y + self.zoom * (i + self.game.figure.y),
                                               self.zoom, self.zoom)
                            pygame.draw.rect(self.screen,
                                             colors[self.game.figure.color],
                                             rect.inflate(-2, -2))

            # Convert the Pygame surface to an RGB array
            return pygame.surfarray.array3d(self.screen)

        elif mode == 'human':
            if self.screen is None:
                pygame.init()
                size = (self.x * 2 + self.zoom * self.width, self.y * 2 + self.zoom * self.height)
                self.screen = pygame.display.set_mode(size)
                pygame.display.set_caption("Tetris RL")

            self.screen.fill((173, 216, 230))  # Light blue background

            # Draw the game field
            for i in range(self.game.height):
                for j in range(self.game.width):
                    rect = pygame.Rect(self.x + self.zoom * j, self.y + self.zoom * i, self.zoom, self.zoom)
                    pygame.draw.rect(self.screen, (128, 128, 128), rect, 1)  # Grid lines
                    if self.game.field[i][j] > 0:
                        pygame.draw.rect(self.screen,
                                         colors[self.game.field[i][j]],
                                         rect.inflate(-2, -2))

            # Draw the current figure
            if self.game.figure is not None:
                for i in range(4):
                    for j in range(4):
                        p = i * 4 + j
                        if p in self.game.figure.image():
                            rect = pygame.Rect(self.x + self.zoom * (j + self.game.figure.x),
                                               self.y + self.zoom * (i + self.game.figure.y),
                                               self.zoom, self.zoom)
                            pygame.draw.rect(self.screen,
                                             colors[self.game.figure.color],
                                             rect.inflate(-2, -2))

            pygame.display.flip()

    def close(self):
        """
        Clean up resources
        """
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()

    # Reward shaping helper functions
    def calculate_aggregate_height(self):
        heights = [0 for _ in range(self.width)]
        for j in range(self.width):
            for i in range(self.height):
                if self.game.field[i][j] != 0:
                    heights[j] = self.height - i
                    break
        return sum(heights)

    def calculate_holes(self):
        holes = 0
        for j in range(self.width):
            block_found = False
            for i in range(self.height):
                if self.game.field[i][j] != 0:
                    block_found = True
                elif block_found and self.game.field[i][j] == 0:
                    holes += 1
        return holes

    def calculate_bumpiness(self):
        heights = [0 for _ in range(self.width)]
        for j in range(self.width):
            for i in range(self.height):
                if self.game.field[i][j] != 0:
                    heights[j] = self.height - i
                    break
        bumpiness = 0
        for j in range(self.width - 1):
            bumpiness += abs(heights[j] - heights[j + 1])
        return bumpiness

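Before training, the environment can be smoke-tested with random actions using the classic Gym loop that the files above rely on. A minimal sketch:

from tetris_env import TetrisEnv

env = TetrisEnv()
obs = env.reset()
done = False
total_reward = 0
while not done:
    action = env.action_space.sample()  # random action: left/right/rotate/drop/noop
    obs, reward, done, info = env.step(action)
    total_reward += reward
print("Random-policy episode reward:", total_reward)
env.close()
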
train.py
ADDED
@@ -0,0 +1,42 @@
import gym
from stable_baselines3 import DQN
from stable_baselines3.common.evaluation import evaluate_policy
from tetris_env import TetrisEnv
from callbacks import SaveFramesCallback
import os

def main():
    # Create the environment
    env = TetrisEnv()

    # Initialize the RL model (DQN)
    model = DQN('MlpPolicy', env, verbose=1,
                learning_rate=1e-3,
                buffer_size=50000,
                learning_starts=1000,
                batch_size=32,
                gamma=0.99,
                target_update_interval=1000,
                exploration_fraction=0.1,
                exploration_final_eps=0.02)

    # Define the number of training timesteps
    TIMESTEPS = 100000  # Adjust as needed

    # Initialize the callback
    callback = SaveFramesCallback(save_freq=5000, save_path="models/frames", verbose=1)

    # Train the model with the callback
    model.learn(total_timesteps=TIMESTEPS, callback=callback)

    # Save the model
    os.makedirs("models", exist_ok=True)
    model.save("models/dqn_tetris")
    print("Model saved to models/dqn_tetris.zip")

    # Evaluate the trained agent
    mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
    print(f"Mean Reward: {mean_reward} +/- {std_reward}")

if __name__ == "__main__":
    main()