Spaces:
Sleeping
Sleeping
File size: 3,496 Bytes
5138881 327d028 afaee35 33bfe50 6f17538 4727d72 6f17538 327d028 afaee35 327d028 f510d64 327d028 d6e4b42 327d028 d6e4b42 327d028 bb1246f 327d028 afaee35 327d028 afaee35 327d028 bdc377f 327d028 827af69 327d028 ab25196 327d028 afaee35 827af69 07e3a21 bdc377f 827af69 327d028 95064f6 327d028 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
# import os
# os.environ["KERAS_BACKEND"] = "torch" # "jax", "torch" or "tensorflow"
import keras_nlp
import keras
import torch
import chess
import chess.svg
import time
class Game:
    """Chess game between a human player and a chess fine-tuned Gemma model.

    The position lives in ``self.board`` (a ``chess.Board``). Every accepted
    move is also appended, in SAN, to ``self.sequence``; that list is the move
    history sent to the model as its prompt.
    """

    # Kaggle/KerasNLP preset of the chess fine-tuned Gemma 2B model.
    DEFAULT_MODEL_ID = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'

    def __init__(self, model_id=DEFAULT_MODEL_ID):
        """Create an empty board and load the Gemma model.

        Args:
            model_id: Preset handle passed to
                ``keras_nlp.models.GemmaCausalLM.from_preset``.
        """
        # Initialize the chess board
        self.board = chess.Board()
        self.sequence = []  # accepted moves (SAN), in play order
        self.counter = 0  # failed generation attempts in the current Gemma turn
        self.model_id = model_id
        # Stochastic top-k sampling means a retry can produce a different
        # move than the one just rejected.
        self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
        self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
        self.compile_model()

    def compile_model(self):
        """Compile the model so that ``generate`` uses ``self.sampler``."""
        self.model.compile(sampler=self.sampler)

    @staticmethod
    def _extract_move(output, prompt=""):
        """Return the candidate move token from raw model output.

        If ``output`` starts with ``prompt``, only the generated continuation
        is considered. The last whitespace-separated token is taken, with
        surrounding quotes/punctuation removed. Returns ``""`` when there is
        no token at all.
        """
        if prompt and output.startswith(prompt):
            output = output[len(prompt):]
        # split() (not split(' ')) so newlines also separate tokens.
        tokens = output.split()
        if not tokens:
            return ""
        return tokens[-1].strip("'\"`.,")

    def call_gemma(self, max_retries=10):
        """Ask Gemma for its next move and play it on the board.

        The model is prompted with the current move sequence. An answer that
        is not a legal move is re-sampled, up to ``max_retries`` additional
        times (so at most ``max_retries + 1`` generations).

        Returns:
            The board as an SVG string after Gemma's move, or ``None`` if no
            legal move was produced within the allowed attempts.
        """
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        prompt = template.format(
            instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
            response="",
        )
        # Attempts are counted per turn: start fresh so an earlier give-up
        # does not use up this turn's budget.
        self.counter = 0
        while True:
            # NOTE(review): max_length bounds prompt + generated tokens; very
            # long games may approach this limit — confirm against KerasNLP.
            output = self.model.generate(prompt, max_length=256)
            gemma_move = self._extract_move(output, prompt)
            if self.make_move(gemma_move):
                print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
                self.counter = 0
                return self.display_board()
            if self.counter >= max_retries:
                print("Gemma quit...")
                self.counter = 0
                return None
            self.counter += 1
            print(self.counter)

    def gemma_moves(self):
        """Announce, pause 3 seconds, then let Gemma play (see ``call_gemma``)."""
        print(f"Gemma is thinking...(Current Sequence: {self.sequence} {len(self.sequence)})")
        time.sleep(3)
        return self.call_gemma()

    def player_moves(self, move):
        """Play the human player's ``move``; returns True if it was accepted."""
        return self.make_move(move)

    def display_board(self):
        """Return the current position rendered as an SVG string."""
        return chess.svg.board(board=self.board)

    def make_move(self, move):
        """Validate ``move`` (SAN) and, if legal, play and record it.

        Returns:
            True if the move was played and appended to ``self.sequence``,
            False otherwise (a diagnostic is printed).
        """
        try:
            if not isinstance(move, str):
                raise ValueError(f"expected a SAN string, got {type(move).__name__}")
            # python-chess raises ValueError subclasses (invalid, illegal or
            # ambiguous move) for anything it cannot play.
            update = self.board.parse_san(move)
            # parse_san maps null-move notation ("--", "0000", ...) to
            # chess.Move.null(), which is falsy; passing is not a legal move.
            if not update:
                raise ValueError("null move")
        except ValueError:
            print(f"Invalid move '{move}'. Use algebraic notation (e.g., 'e4', 'Nf3', 'Bxc4') or ask Gemma for help.")
            return False
        # Record canonical SAN so the history sent to the model is normalized;
        # san() needs the pre-move position, so compute it before push().
        self.sequence.append(self.board.san(update))
        self.board.push(update)
        return True

    def reset_board(self):
        """Start a new game and return the empty board as SVG."""
        self.board = chess.Board()
        self.sequence = []
        self.counter = 0
        return self.display_board()

    def generate_moves(self, move):
        """Drive one turn: the player's move, then Gemma's reply.

        Yields ``(board_svg, status_message)`` pairs. The first pair reflects
        the player's move, or rejects it. A later pair carries Gemma's
        response or the reason there is none (game over / Gemma gave up).
        """
        if not self.player_moves(move):
            print("Try again")
            yield self.display_board(), f"Invalid move '{move}'. Try again."
            return
        yield self.display_board(), f"You played: {move}"
        # No reply is possible once the game has ended; skip the model.
        if self.board.is_game_over():
            yield self.display_board(), f"Game over: {self.board.result()}"
            return
        board_svg = self.gemma_moves()
        if board_svg is None:
            # Gemma produced no legal move; don't misreport the player's move.
            yield self.display_board(), "Gemma quit..."
            return
        yield board_svg, f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})'
        if self.board.is_game_over():
            yield board_svg, f"Game over: {self.board.result()}"

    def get_move_logs(self):
        """Return the list of moves played so far (SAN, in order)."""
        return self.sequence
def main():
    """Run a minimal console game, reading SAN moves from stdin.

    The user enters the moves for both sides in turn; Gemma is not consulted
    here. The session ends when the input contains 'No', on end-of-input
    (Ctrl-D) or Ctrl-C, or when the game is over.
    """
    play_match = Game()
    while not play_match.board.is_game_over():
        try:
            move = input("Your move (or 'No' to end game):")
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if 'No' in move:
            break
        # make_move prints its own diagnostic for an invalid move.
        play_match.player_moves(move)
    else:
        # Loop condition became false: the position is terminal.
        print(f"Game over: {play_match.board.result()}")


if __name__ == '__main__':
    main()