valentin urena commited on
Commit
b1e1394
·
verified ·
1 Parent(s): 003246a

Update chess_board.py

Browse files
Files changed (1) hide show
  1. chess_board.py +16 -20
chess_board.py CHANGED
@@ -1,9 +1,9 @@
1
- # import os
2
- # os.environ["KERAS_BACKEND"] = "torch" # "jax", "torch" or "tensorflow"
3
 
4
- # import keras_nlp
5
- # import keras
6
- # import torch
7
 
8
  import chess
9
  import chess.svg
@@ -17,10 +17,10 @@ class Game:
17
  self.counter = 0
18
  self.arrow= None
19
 
20
- # self.model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
21
- # self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
22
- # self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
23
- # self.compile_model()
24
 
25
  def compile_model(self):
26
  self.model.compile(sampler=self.sampler)
@@ -37,9 +37,9 @@ class Game:
37
  instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
38
  response="",)
39
 
40
- # output = self.model.generate(prompt, max_length=256)
41
 
42
- # gemma_move = output.split(' ')[-1].strip("'")
43
 
44
  if self.make_move(gemma_move):
45
  print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
@@ -54,8 +54,7 @@ class Game:
54
  return None
55
 
56
  def gemma_moves(self):
57
- # print(f"Gemma is thinking...(Current Sequence: {self.sequence} {len(self.sequence)})")
58
- # time.sleep(3)
59
  if self.opening_moves and len(self.sequence)<len(self.opening_moves):
60
  return self.call_gemma(self.opening_moves[len(self.sequence)])
61
  else:
@@ -64,23 +63,20 @@ class Game:
64
  def player_moves(self, move):
65
  return self.make_move(move)
66
 
67
- # Function to display the board
68
  def display_board(self):
69
- # clear_output(wait=True)
70
- # display(SVG(chess.svg.board(board=self.board)))
71
  if self.arrow:
72
  board_svg = chess.svg.board(board=self.board, arrows=[self.arrow])
73
  else:
74
  board_svg = chess.svg.board(board=self.board)
75
- # return svg2png(bytestring=board_svg)
76
  return board_svg
77
 
78
- # Function to make a move
79
  def make_move(self, move):
 
80
  try:
81
  update = self.board.parse_san(move)
82
  self.board.push(update)
83
- # self.display_board()
84
  self.sequence.append(move)
85
  self.arrow = chess.svg.Arrow(update.from_square, update.to_square, color="#0000cccc")
86
  return True
@@ -93,10 +89,10 @@ class Game:
93
  self.sequence = []
94
  self.counter = 0
95
  self.arrow = None
96
- # self.board.reset
97
  return self.display_board()
98
 
99
  def generate_moves(self, move):
 
100
  valid_move = self.player_moves(move)
101
  if valid_move:
102
  yield self.display_board(), f"You played: {move}"
 
1
+ import os
2
+ os.environ["KERAS_BACKEND"] = "torch" # "jax", "torch" or "tensorflow"
3
 
4
+ import keras_nlp
5
+ import keras
6
+ import torch
7
 
8
  import chess
9
  import chess.svg
 
17
  self.counter = 0
18
  self.arrow= None
19
 
20
+ self.model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
21
+ self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
22
+ self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
23
+ self.compile_model()
24
 
25
  def compile_model(self):
26
  self.model.compile(sampler=self.sampler)
 
37
  instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
38
  response="",)
39
 
40
+ output = self.model.generate(prompt, max_length=256)
41
 
42
+ gemma_move = output.split(' ')[-1].strip("'")
43
 
44
  if self.make_move(gemma_move):
45
  print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
 
54
  return None
55
 
56
  def gemma_moves(self):
57
+ """Calls Gemma to make a move, either self generated or from opening sequence"""
 
58
  if self.opening_moves and len(self.sequence)<len(self.opening_moves):
59
  return self.call_gemma(self.opening_moves[len(self.sequence)])
60
  else:
 
63
  def player_moves(self, move):
64
  return self.make_move(move)
65
 
 
66
  def display_board(self):
67
+ """Return SVG image of board state"""
 
68
  if self.arrow:
69
  board_svg = chess.svg.board(board=self.board, arrows=[self.arrow])
70
  else:
71
  board_svg = chess.svg.board(board=self.board)
 
72
  return board_svg
73
 
74
+
75
  def make_move(self, move):
76
+ """Checks to see if move is valid, if so pushes move to board state"""
77
  try:
78
  update = self.board.parse_san(move)
79
  self.board.push(update)
 
80
  self.sequence.append(move)
81
  self.arrow = chess.svg.Arrow(update.from_square, update.to_square, color="#0000cccc")
82
  return True
 
89
  self.sequence = []
90
  self.counter = 0
91
  self.arrow = None
 
92
  return self.display_board()
93
 
94
  def generate_moves(self, move):
95
+ """Generator function for one full turn of chess moves"""
96
  valid_move = self.player_moves(move)
97
  if valid_move:
98
  yield self.display_board(), f"You played: {move}"