Spaces:
Sleeping
Sleeping
valentin urena
commited on
Update chess_board.py
Browse files- chess_board.py +16 -20
chess_board.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
-
|
2 |
-
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
|
8 |
import chess
|
9 |
import chess.svg
|
@@ -17,10 +17,10 @@ class Game:
|
|
17 |
self.counter = 0
|
18 |
self.arrow= None
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
|
25 |
def compile_model(self):
|
26 |
self.model.compile(sampler=self.sampler)
|
@@ -37,9 +37,9 @@ class Game:
|
|
37 |
instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
|
38 |
response="",)
|
39 |
|
40 |
-
|
41 |
|
42 |
-
|
43 |
|
44 |
if self.make_move(gemma_move):
|
45 |
print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
|
@@ -54,8 +54,7 @@ class Game:
|
|
54 |
return None
|
55 |
|
56 |
def gemma_moves(self):
|
57 |
-
|
58 |
-
# time.sleep(3)
|
59 |
if self.opening_moves and len(self.sequence)<len(self.opening_moves):
|
60 |
return self.call_gemma(self.opening_moves[len(self.sequence)])
|
61 |
else:
|
@@ -64,23 +63,20 @@ class Game:
|
|
64 |
def player_moves(self, move):
|
65 |
return self.make_move(move)
|
66 |
|
67 |
-
# Function to display the board
|
68 |
def display_board(self):
|
69 |
-
|
70 |
-
# display(SVG(chess.svg.board(board=self.board)))
|
71 |
if self.arrow:
|
72 |
board_svg = chess.svg.board(board=self.board, arrows=[self.arrow])
|
73 |
else:
|
74 |
board_svg = chess.svg.board(board=self.board)
|
75 |
-
# return svg2png(bytestring=board_svg)
|
76 |
return board_svg
|
77 |
|
78 |
-
|
79 |
def make_move(self, move):
|
|
|
80 |
try:
|
81 |
update = self.board.parse_san(move)
|
82 |
self.board.push(update)
|
83 |
-
# self.display_board()
|
84 |
self.sequence.append(move)
|
85 |
self.arrow = chess.svg.Arrow(update.from_square, update.to_square, color="#0000cccc")
|
86 |
return True
|
@@ -93,10 +89,10 @@ class Game:
|
|
93 |
self.sequence = []
|
94 |
self.counter = 0
|
95 |
self.arrow = None
|
96 |
-
# self.board.reset
|
97 |
return self.display_board()
|
98 |
|
99 |
def generate_moves(self, move):
|
|
|
100 |
valid_move = self.player_moves(move)
|
101 |
if valid_move:
|
102 |
yield self.display_board(), f"You played: {move}"
|
|
|
1 |
+
import os
|
2 |
+
os.environ["KERAS_BACKEND"] = "torch" # "jax", "torch" or "tensorflow"
|
3 |
|
4 |
+
import keras_nlp
|
5 |
+
import keras
|
6 |
+
import torch
|
7 |
|
8 |
import chess
|
9 |
import chess.svg
|
|
|
17 |
self.counter = 0
|
18 |
self.arrow= None
|
19 |
|
20 |
+
self.model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
|
21 |
+
self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
|
22 |
+
self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
|
23 |
+
self.compile_model()
|
24 |
|
25 |
def compile_model(self):
|
26 |
self.model.compile(sampler=self.sampler)
|
|
|
37 |
instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
|
38 |
response="",)
|
39 |
|
40 |
+
output = self.model.generate(prompt, max_length=256)
|
41 |
|
42 |
+
gemma_move = output.split(' ')[-1].strip("'")
|
43 |
|
44 |
if self.make_move(gemma_move):
|
45 |
print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
|
|
|
54 |
return None
|
55 |
|
56 |
def gemma_moves(self):
|
57 |
+
"""Calls Gemma to make a move, either self generated or from opening sequence"""
|
|
|
58 |
if self.opening_moves and len(self.sequence)<len(self.opening_moves):
|
59 |
return self.call_gemma(self.opening_moves[len(self.sequence)])
|
60 |
else:
|
|
|
63 |
def player_moves(self, move):
|
64 |
return self.make_move(move)
|
65 |
|
|
|
66 |
def display_board(self):
|
67 |
+
"""Return SVG image of board state"""
|
|
|
68 |
if self.arrow:
|
69 |
board_svg = chess.svg.board(board=self.board, arrows=[self.arrow])
|
70 |
else:
|
71 |
board_svg = chess.svg.board(board=self.board)
|
|
|
72 |
return board_svg
|
73 |
|
74 |
+
|
75 |
def make_move(self, move):
|
76 |
+
"""Checks to see if move is valid, if so pushes move to board state"""
|
77 |
try:
|
78 |
update = self.board.parse_san(move)
|
79 |
self.board.push(update)
|
|
|
80 |
self.sequence.append(move)
|
81 |
self.arrow = chess.svg.Arrow(update.from_square, update.to_square, color="#0000cccc")
|
82 |
return True
|
|
|
89 |
self.sequence = []
|
90 |
self.counter = 0
|
91 |
self.arrow = None
|
|
|
92 |
return self.display_board()
|
93 |
|
94 |
def generate_moves(self, move):
|
95 |
+
"""Generator function for one full turn of chess moves"""
|
96 |
valid_move = self.player_moves(move)
|
97 |
if valid_move:
|
98 |
yield self.display_board(), f"You played: {move}"
|