# chess_llm_gemma / chess_board.py
# valentin-ub's picture
# Update chess_board.py
# cd064af verified
# import spaces
import os
os.environ["KERAS_BACKEND"] = "torch" # "jax", "torch" or "tensorflow"
import keras_nlp
import keras
import torch
import chess
import chess.svg
import time
class Game:
    """Chess game in which a Gemma causal LM supplies the reply moves.

    The instance owns a ``chess.Board``, the list of moves played so far (in
    the notation they were submitted in), the arrow highlighting the last
    move, and an optional scripted opening that Gemma follows before it
    starts generating moves itself.
    """

    # Number of *extra* attempts Gemma gets after proposing an unplayable
    # move before it gives up (so at most MAX_RETRIES + 1 attempts per turn).
    MAX_RETRIES = 10
    # Pauses (in seconds) inserted by generate_moves between the frames it
    # yields. Kept as attributes so callers can shorten them.
    PLAYER_MOVE_DELAY = 2
    GEMMA_THINK_DELAY = 3
    # Instruction/response layout used to build the prompt sent to the model.
    PROMPT_TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"

    def __init__(self):
        """Set up a fresh board and load and compile the Gemma model."""
        # Initialize the chess board
        self.board = chess.Board()
        self.sequence = []
        self.counter = 0
        self.arrow = None
        self.opening_name = None
        self.opening_moves = None
        self.model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
        self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
        self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
        self.compile_model()

    def compile_model(self):
        """Compile the model so generation uses ``self.sampler``."""
        self.model.compile(sampler=self.sampler)

    # @spaces.GPU
    def inference_gemma(self, prompt, max_length=256):
        """Run the model on ``prompt`` and return whatever ``generate`` returns.

        Original note: inference requires GPU.
        """
        return self.model.generate(prompt, max_length=max_length)

    def _generate_move(self):
        """Ask the model for the next move and return the extracted token."""
        prompt = self.PROMPT_TEMPLATE.format(
            instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
            response="",
        )
        output = self.inference_gemma(prompt, max_length=256)
        # The move is taken to be the last space-separated token of the
        # output. Surrounding whitespace is dropped before the quotes so a
        # trailing newline does not turn a good move into an invalid one.
        return output.split(' ')[-1].strip().strip("'")

    def call_gemma(self, opening_move):
        """Play Gemma's move, retrying with fresh generations when it is invalid.

        Args:
            opening_move: Move to try first (from a loaded opening), or a
                falsy value to generate the move with the model straight away.

        Returns:
            The board SVG after Gemma's move, or ``None`` when no playable
            move was found within the retry budget.
        """
        candidate = opening_move
        while True:
            if not candidate:
                candidate = self._generate_move()
            if self.make_move(candidate):
                print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
                self.counter = 0
                return self.display_board()
            if self.counter >= self.MAX_RETRIES:
                break
            self.counter += 1
            print(self.counter)
            # A rejected opening move is not retried; fall back to the model.
            candidate = None
        print("Gemma quit...")
        # Reset so the next turn starts with a full retry budget.
        self.counter = 0
        return None

    def gemma_moves(self):
        """Calls Gemma to make a move, either self generated or from opening sequence"""
        played = len(self.sequence)
        if self.opening_moves and played < len(self.opening_moves):
            return self.call_gemma(self.opening_moves[played])
        return self.call_gemma(None)

    def player_moves(self, move):
        """Apply the player's move; return ``True`` when it was accepted."""
        return self.make_move(move)

    def display_board(self):
        """Return SVG image of board state"""
        if self.arrow:
            return chess.svg.board(board=self.board, arrows=[self.arrow])
        return chess.svg.board(board=self.board)

    def make_move(self, move):
        """Checks to see if move is valid, if so pushes move to board state

        Args:
            move: Move in standard algebraic notation. Surrounding whitespace
                is ignored; anything that is not a string is rejected.

        Returns:
            ``True`` when the move was played, ``False`` otherwise.
        """
        san = move.strip() if isinstance(move, str) else ""
        update = None
        if san:
            try:
                update = self.board.parse_san(san)
            except ValueError:
                # python-chess reports invalid, illegal and ambiguous SAN
                # with ValueError subclasses.
                update = None
        if update is None:
            print(f"Invalid move '{move}'. Use algebraic notation (e.g., 'e4', 'Nf3', 'Bxc4') or ask Gemma for help.")
            return False
        self.board.push(update)
        self.sequence.append(san)
        self.arrow = chess.svg.Arrow(update.from_square, update.to_square, color="#0000cccc")
        return True

    def reset_board(self):
        """Start a new game (any loaded opening is kept) and return the board SVG."""
        self.board = chess.Board()
        self.sequence = []
        self.counter = 0
        self.arrow = None
        return self.display_board()

    def generate_moves(self, move):
        """Generator function for one full turn of chess moves

        Yields ``(board_svg, status_message)`` tuples: one after the player's
        move, one while Gemma is "thinking", and one with Gemma's reply. An
        invalid player move yields a single "Try again" frame.
        """
        if not self.player_moves(move):
            print("Try again")
            yield self.display_board(), "Try again"
            return

        yield self.display_board(), f"You played: {move}"
        # The player's move may have ended the game; Gemma has nothing to play.
        if self.board.is_game_over():
            yield self.display_board(), f"Game over ({self.board.result()})"
            return

        time.sleep(self.PLAYER_MOVE_DELAY)
        yield self.display_board(), f"Gemma is thinking...(Current Sequence: {self.sequence} {len(self.sequence)})"
        time.sleep(self.GEMMA_THINK_DELAY)

        board_svg = self.gemma_moves()
        if board_svg is None:
            # Gemma gave up: show the unchanged board instead of announcing
            # the player's own last move as Gemma's.
            yield self.display_board(), "Gemma quit..."
            return
        yield board_svg, f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})'

    def get_move_logs(self):
        """Return the list of moves played so far (the live list, not a copy)."""
        return self.sequence

    def load_opening(self, opening_name, opening_moves):
        """Store the opening Gemma should follow and return a confirmation message."""
        self.opening_name = opening_name
        self.opening_moves = opening_moves
        return f"Ok, lets play the {opening_name}! {opening_moves} Make your first move."
def main():
    """Run a minimal console loop that feeds typed moves to a ``Game``.

    Each line read from the prompt is passed to ``Game.player_moves``; any
    input containing ``'No'`` ends the loop. Only player moves are applied
    here: Gemma is never asked to reply, and the SVG returned by
    ``display_board`` is not used.
    """
    play_match = Game()
    play_match.display_board()
    while True:
        move = input("Your move (or 'No' to end game):")
        if 'No' in move:
            break
        play_match.player_moves(move)
    # Drop the reference to the game (and the model it holds) before returning.
    del play_match


if __name__ == '__main__':
    main()