BaljinderH committed
Commit b100cf9 · verified · 1 Parent(s): 97b342e

Upload 8 files

Files changed (8)
  1. callbacks.py +32 -0
  2. compile_video.py +16 -0
  3. evaluate.py +41 -0
  4. push_to_hub.py +39 -0
  5. requirements.txt +8 -0
  6. sandtris.py +123 -0
  7. tetris_env.py +212 -0
  8. train.py +42 -0
callbacks.py ADDED
@@ -0,0 +1,32 @@
+ import os
+ from stable_baselines3.common.callbacks import BaseCallback
+ import imageio
+ import numpy as np
+
+ class SaveFramesCallback(BaseCallback):
+     """
+     Callback for saving frames during training.
+     """
+     def __init__(self, save_freq, save_path, verbose=0):
+         super(SaveFramesCallback, self).__init__(verbose)
+         self.save_freq = save_freq
+         self.save_path = save_path
+         self.frames = []
+         os.makedirs(self.save_path, exist_ok=True)
+
+     def _on_step(self) -> bool:
+         if self.num_timesteps % self.save_freq == 0:
+             # Render the environment and get the RGB array
+             frame = self.training_env.render(mode='rgb_array')
+             self.frames.append(frame)
+             if self.verbose > 0:
+                 print(f"Saved frame at timestep {self.num_timesteps}")
+         return True
+
+     def _on_training_end(self) -> None:
+         # Save the frames as a GIF
+         if self.frames:
+             gif_path = os.path.join(self.save_path, "training.gif")
+             imageio.mimsave(gif_path, self.frames, fps=10)
+             if self.verbose > 0:
+                 print(f"Saved training GIF to {gif_path}")
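Note: inside a Stable-Baselines3 callback, self.training_env is the vectorized wrapper (a VecEnv), not the raw TetrisEnv, so render(mode='rgb_array') returns the frame(s) of the wrapped environment(s); with the single environment used in train.py this should amount to one RGB frame per call, though the exact render signature varies across SB3 versions.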
compile_video.py ADDED
@@ -0,0 +1,16 @@
+ import imageio
+ import os
+
+ def compile_gif_to_video(gif_path, video_path):
+     reader = imageio.get_reader(gif_path)
+     fps = 10
+     writer = imageio.get_writer(video_path, fps=fps)
+     for frame in reader:
+         writer.append_data(frame)
+     writer.close()
+     print(f"Video saved to {video_path}")
+
+ if __name__ == "__main__":
+     gif_path = "models/frames/training.gif"
+     video_path = "models/training.mp4"
+     compile_gif_to_video(gif_path, video_path)
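Note: MP4 output relies on imageio's FFmpeg backend, which is not installed by default; if imageio.get_writer fails for .mp4 files, installing the imageio-ffmpeg package should provide it.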
evaluate.py ADDED
@@ -0,0 +1,41 @@
+ import gym
+ from stable_baselines3 import DQN
+ from tetris_env import TetrisEnv
+ import pygame
+ import time
+
+ def main():
+     # Create the environment
+     env = TetrisEnv()
+
+     # Load the trained model
+     model = DQN.load("models/dqn_tetris")
+
+     # Number of evaluation episodes
+     episodes = 5
+
+     for ep in range(1, episodes + 1):
+         obs = env.reset()
+         done = False
+         total_reward = 0
+         while not done:
+             # Render the game (optional)
+             env.render(mode='human')
+
+             # Predict the action using the trained model
+             action, _states = model.predict(obs, deterministic=True)
+
+             # Take the action in the environment
+             obs, reward, done, info = env.step(action)
+
+             total_reward += reward
+
+             # Control the rendering speed
+             pygame.time.wait(100)  # Wait 100 ms between steps
+
+         print(f"Episode {ep}: Total Reward = {total_reward}")
+
+     env.close()
+
+ if __name__ == "__main__":
+     main()
push_to_hub.py ADDED
@@ -0,0 +1,39 @@
+ from stable_baselines3 import DQN
+ from huggingface_hub import HfApi, HfFolder, Repository
+ import os
+
+ def main():
+     # Define repository details
+     repo_name = "dqn-tetris"
+     model_path = "models/dqn_tetris.zip"
+     model_dir = "models"
+
+     # Load the trained model (verifies the checkpoint is readable)
+     model = DQN.load(model_path)
+
+     # Initialize the Hugging Face API
+     api = HfApi()
+     user = api.whoami()["name"]
+
+     # Create the repository if it doesn't exist
+     try:
+         api.create_repo(repo_id=repo_name, repo_type="model", exist_ok=True)
+         print(f"Repository '{repo_name}' created.")
+     except Exception as e:
+         print(f"Repository '{repo_name}' already exists or failed to create: {e}")
+
+     # Clone the repository locally
+     repo = Repository(local_dir=repo_name, clone_from=f"{user}/{repo_name}", use_auth_token=True)
+
+     # Move the model file into the repository directory
+     os.makedirs(repo_name, exist_ok=True)
+     os.rename(model_path, os.path.join(repo_name, "dqn_tetris.zip"))
+
+     # Add, commit, and push the model
+     repo.git_add(auto_lfs_track=True)
+     repo.git_commit("Add trained DQN Tetris model")
+     repo.git_push()
+     print(f"Model pushed to the Hugging Face Hub at {user}/{repo_name}")
+
+ if __name__ == "__main__":
+     main()
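Note: the Repository class used above is deprecated in recent huggingface_hub releases. A minimal alternative sketch using the HTTP-based upload API (assuming a prior huggingface-cli login; repo and file names taken from the script above):

from huggingface_hub import HfApi

api = HfApi()
user = api.whoami()["name"]
# Create (or reuse) the model repo, then upload the checkpoint directly
api.create_repo(repo_id=f"{user}/dqn-tetris", repo_type="model", exist_ok=True)
api.upload_file(
    path_or_fileobj="models/dqn_tetris.zip",
    path_in_repo="dqn_tetris.zip",
    repo_id=f"{user}/dqn-tetris",
)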
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ pygame
+ gym
+ stable-baselines3
+ numpy
+ torch
+ huggingface-hub
+ imageio
+ gradio
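Install with pip install -r requirements.txt. Note that the scripts use the classic gym API (reset returning only the observation, step returning four values), which matches stable-baselines3 1.x; newer SB3 releases moved to gymnasium and would require updating the environment's reset/step signatures.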
sandtris.py ADDED
@@ -0,0 +1,123 @@
+ import pygame
+ import random
+
+ colors = [
+     (0, 0, 0),
+     (120, 37, 179),
+     (100, 179, 179),
+     (80, 34, 22),
+     (80, 134, 22),
+     (180, 34, 22),
+     (180, 34, 122),
+ ]
+
+ class Figure:
+     x = 0
+     y = 0
+
+     figures = [
+         [[1, 5, 9, 13], [4, 5, 6, 7]],
+         [[4, 5, 9, 10], [2, 6, 5, 9]],
+         [[6, 7, 9, 10], [1, 5, 6, 10]],
+         [[1, 2, 5, 9], [0, 4, 5, 6], [1, 5, 9, 8], [4, 5, 6, 10]],
+         [[1, 2, 6, 10], [5, 6, 7, 9], [2, 6, 10, 11], [3, 5, 6, 7]],
+         [[1, 4, 5, 6], [1, 4, 5, 9], [4, 5, 6, 9], [1, 5, 6, 9]],
+         [[1, 2, 5, 6]],
+     ]
+
+     def __init__(self, x, y):
+         self.x = x
+         self.y = y
+         self.type = random.randint(0, len(self.figures) - 1)
+         self.color = random.randint(1, len(colors) - 1)
+         self.rotation = 0
+
+     def image(self):
+         return self.figures[self.type][self.rotation]
+
+     def rotate(self):
+         self.rotation = (self.rotation + 1) % len(self.figures[self.type])
+
+ class Tetris:
+     def __init__(self, height, width):
+         self.level = 2
+         self.score = 0
+         self.state = "start"
+         self.field = []
+         self.height = height
+         self.width = width
+         self.x = 100
+         self.y = 60
+         self.zoom = 20
+         self.figure = None
+
+         for i in range(height):
+             new_line = []
+             for j in range(width):
+                 new_line.append(0)
+             self.field.append(new_line)
+
+     def new_figure(self):
+         self.figure = Figure(3, 0)
+
+     def intersects(self):
+         intersection = False
+         for i in range(4):
+             for j in range(4):
+                 if i * 4 + j in self.figure.image():
+                     if (
+                         i + self.figure.y > self.height - 1
+                         or j + self.figure.x > self.width - 1
+                         or j + self.figure.x < 0
+                         or self.field[i + self.figure.y][j + self.figure.x] > 0
+                     ):
+                         intersection = True
+         return intersection
+
+     def break_lines(self):
+         lines = 0
+         for i in range(1, self.height):
+             zeros = 0
+             for j in range(self.width):
+                 if self.field[i][j] == 0:
+                     zeros += 1
+             if zeros == 0:
+                 lines += 1
+                 for i1 in range(i, 1, -1):
+                     for j in range(self.width):
+                         self.field[i1][j] = self.field[i1 - 1][j]
+         self.score += lines ** 2
+
+     def go_space(self):
+         while not self.intersects():
+             self.figure.y += 1
+         self.figure.y -= 1
+         self.freeze()
+
+     def go_down(self):
+         self.figure.y += 1
+         if self.intersects():
+             self.figure.y -= 1
+             self.freeze()
+
+     def freeze(self):
+         for i in range(4):
+             for j in range(4):
+                 if i * 4 + j in self.figure.image():
+                     self.field[i + self.figure.y][j + self.figure.x] = self.figure.color
+         self.break_lines()
+         self.new_figure()
+         if self.intersects():
+             self.state = "gameover"
+
+     def go_side(self, dx):
+         old_x = self.figure.x
+         self.figure.x += dx
+         if self.intersects():
+             self.figure.x = old_x
+
+     def rotate(self):
+         old_rotation = self.figure.rotation
+         self.figure.rotate()
+         if self.intersects():
+             self.figure.rotation = old_rotation
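Each inner list in Figure.figures encodes one rotation of a piece as the occupied cells of a flattened 4x4 grid (index = row * 4 + column), which is how image() is consumed by intersects() and freeze(). An illustrative snippet (the show helper is hypothetical, for visualization only):

def show(cells):
    # Print a 4x4 grid with '#' marking the occupied cells
    for i in range(4):
        print("".join("#" if i * 4 + j in cells else "." for j in range(4)))

show([1, 5, 9, 13])  # vertical I-piece: column 1 of every row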
tetris_env.py ADDED
@@ -0,0 +1,212 @@
+ import gym
+ from gym import spaces
+ import numpy as np
+ import pygame
+ import random
+ from sandtris import Tetris  # Ensure sandtris.py is in the same directory
+ import os
+
+ colors = [
+     (0, 0, 0),
+     (120, 37, 179),
+     (100, 179, 179),
+     (80, 34, 22),
+     (80, 134, 22),
+     (180, 34, 22),
+     (180, 34, 122),
+ ]
+
+ class TetrisEnv(gym.Env):
+     """
+     Custom Environment for the Tetris game, compatible with OpenAI Gym
+     """
+     metadata = {'render.modes': ['human', 'rgb_array']}
+
+     def __init__(self):
+         super(TetrisEnv, self).__init__()
+
+         # Define action space: 0=left, 1=right, 2=rotate, 3=drop, 4=noop
+         self.action_space = spaces.Discrete(5)
+
+         # Observation space: 2D grid representing the game board
+         self.height = 20
+         self.width = 10
+         self.observation_space = spaces.Box(low=0, high=6,
+                                             shape=(self.height, self.width), dtype=np.int32)
+
+         # Initialize the game
+         self.game = Tetris(self.height, self.width)
+
+         # Setup for rendering
+         self.screen = None
+         self.zoom = 20
+         self.x = 100
+         self.y = 60
+
+     def reset(self):
+         """
+         Reset the game to its initial state
+         """
+         self.game = Tetris(self.height, self.width)
+         self.game.new_figure()
+         self.last_score = 0  # game.score is cumulative; track it between steps
+         return self._get_obs()
+
+     def step(self, action):
+         """
+         Execute one time step within the environment
+         """
+         done = False
+         reward = 0
+
+         # Apply action
+         if action == 0:
+             self.game.go_side(-1)  # Move left
+         elif action == 1:
+             self.game.go_side(1)  # Move right
+         elif action == 2:
+             self.game.rotate()  # Rotate
+         elif action == 3:
+             self.game.go_space()  # Drop
+         elif action == 4:
+             pass  # No operation
+
+         # Move the piece down automatically
+         self.game.go_down()
+
+         # Reward only the lines cleared since the last step
+         lines_cleared = self.game.score - self.last_score
+         self.last_score = self.game.score
+         reward += lines_cleared * 10
+
+         # Additional reward shaping
+         aggregate_height = self.calculate_aggregate_height()
+         holes = self.calculate_holes()
+         bumpiness = self.calculate_bumpiness()
+
+         reward -= (aggregate_height * 0.5 + holes * 0.7 + bumpiness * 0.3)
+
+         if self.game.state == "gameover":
+             done = True
+             reward -= 10  # Penalty for losing
+
+         return self._get_obs(), reward, done, {}
+
+     def _get_obs(self):
+         """
+         Get the current state of the game as an observation
+         """
+         return np.array(self.game.field, dtype=np.int32)
+
+     def render(self, mode='human'):
+         """
+         Render the game state as an RGB array or display it
+         """
+         if mode == 'rgb_array':
+             if self.screen is None:
+                 pygame.init()
+                 size = (self.x * 2 + self.zoom * self.width, self.y * 2 + self.zoom * self.height)
+                 self.screen = pygame.Surface(size)
+
+             self.screen.fill((173, 216, 230))  # Light blue background
+
+             # Draw the game field
+             for i in range(self.game.height):
+                 for j in range(self.game.width):
+                     rect = pygame.Rect(self.x + self.zoom * j, self.y + self.zoom * i, self.zoom, self.zoom)
+                     pygame.draw.rect(self.screen, (128, 128, 128), rect, 1)  # Grid lines
+                     if self.game.field[i][j] > 0:
+                         pygame.draw.rect(self.screen,
+                                          colors[self.game.field[i][j]],
+                                          rect.inflate(-2, -2))
+
+             # Draw the current figure
+             if self.game.figure is not None:
+                 for i in range(4):
+                     for j in range(4):
+                         p = i * 4 + j
+                         if p in self.game.figure.image():
+                             rect = pygame.Rect(self.x + self.zoom * (j + self.game.figure.x),
+                                                self.y + self.zoom * (i + self.game.figure.y),
+                                                self.zoom, self.zoom)
+                             pygame.draw.rect(self.screen,
+                                              colors[self.game.figure.color],
+                                              rect.inflate(-2, -2))
+
+             # Convert the Pygame surface (width-major) to an (H, W, 3) RGB array
+             return np.transpose(pygame.surfarray.array3d(self.screen), (1, 0, 2))
+
+         elif mode == 'human':
+             if self.screen is None:
+                 pygame.init()
+                 size = (self.x * 2 + self.zoom * self.width, self.y * 2 + self.zoom * self.height)
+                 self.screen = pygame.display.set_mode(size)
+                 pygame.display.set_caption("Tetris RL")
+
+             self.screen.fill((173, 216, 230))  # Light blue background
+
+             # Draw the game field
+             for i in range(self.game.height):
+                 for j in range(self.game.width):
+                     rect = pygame.Rect(self.x + self.zoom * j, self.y + self.zoom * i, self.zoom, self.zoom)
+                     pygame.draw.rect(self.screen, (128, 128, 128), rect, 1)  # Grid lines
+                     if self.game.field[i][j] > 0:
+                         pygame.draw.rect(self.screen,
+                                          colors[self.game.field[i][j]],
+                                          rect.inflate(-2, -2))
+
+             # Draw the current figure
+             if self.game.figure is not None:
+                 for i in range(4):
+                     for j in range(4):
+                         p = i * 4 + j
+                         if p in self.game.figure.image():
+                             rect = pygame.Rect(self.x + self.zoom * (j + self.game.figure.x),
+                                                self.y + self.zoom * (i + self.game.figure.y),
+                                                self.zoom, self.zoom)
+                             pygame.draw.rect(self.screen,
+                                              colors[self.game.figure.color],
+                                              rect.inflate(-2, -2))
+
+             pygame.display.flip()
+
+     def close(self):
+         """
+         Clean up resources
+         """
+         if self.screen is not None:
+             pygame.display.quit()
+             pygame.quit()
+
+     # Reward shaping helper functions
+     def calculate_aggregate_height(self):
+         heights = [0 for _ in range(self.width)]
+         for j in range(self.width):
+             for i in range(self.height):
+                 if self.game.field[i][j] != 0:
+                     heights[j] = self.height - i
+                     break
+         return sum(heights)
+
+     def calculate_holes(self):
+         holes = 0
+         for j in range(self.width):
+             block_found = False
+             for i in range(self.height):
+                 if self.game.field[i][j] != 0:
+                     block_found = True
+                 elif block_found and self.game.field[i][j] == 0:
+                     holes += 1
+         return holes
+
+     def calculate_bumpiness(self):
+         heights = [0 for _ in range(self.width)]
+         for j in range(self.width):
+             for i in range(self.height):
+                 if self.game.field[i][j] != 0:
+                     heights[j] = self.height - i
+                     break
+         bumpiness = 0
+         for j in range(self.width - 1):
+             bumpiness += abs(heights[j] - heights[j + 1])
+         return bumpiness
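Before training, the environment can be sanity-checked against the Gym interface with Stable-Baselines3's built-in checker; a minimal sketch:

from stable_baselines3.common.env_checker import check_env
from tetris_env import TetrisEnv

check_env(TetrisEnv())  # warns on any deviation from the expected spaces/API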
train.py ADDED
@@ -0,0 +1,42 @@
+ import gym
+ from stable_baselines3 import DQN
+ from stable_baselines3.common.evaluation import evaluate_policy
+ from tetris_env import TetrisEnv
+ from callbacks import SaveFramesCallback
+ import os
+
+ def main():
+     # Create the environment
+     env = TetrisEnv()
+
+     # Initialize the RL model (DQN)
+     model = DQN('MlpPolicy', env, verbose=1,
+                 learning_rate=1e-3,
+                 buffer_size=50000,
+                 learning_starts=1000,
+                 batch_size=32,
+                 gamma=0.99,
+                 target_update_interval=1000,
+                 exploration_fraction=0.1,
+                 exploration_final_eps=0.02)
+
+     # Define the number of training timesteps
+     TIMESTEPS = 100000  # Adjust as needed
+
+     # Initialize the callback
+     callback = SaveFramesCallback(save_freq=5000, save_path="models/frames", verbose=1)
+
+     # Train the model with the callback
+     model.learn(total_timesteps=TIMESTEPS, callback=callback)
+
+     # Save the model
+     os.makedirs("models", exist_ok=True)
+     model.save("models/dqn_tetris")
+     print("Model saved to models/dqn_tetris.zip")
+
+     # Evaluate the trained agent
+     mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
+     print(f"Mean Reward: {mean_reward} +/- {std_reward}")
+
+ if __name__ == "__main__":
+     main()
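Taken together, the intended workflow appears to be: python train.py (writes models/dqn_tetris.zip and models/frames/training.gif), python evaluate.py to watch the trained agent, python compile_video.py to convert the GIF to MP4, and python push_to_hub.py to publish the model.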