# NOTE: this file appears to have been copied from a web-based file viewer
# (reported file size: 4,754 bytes). The viewer's status text, per-line commit
# hashes and line-number gutter were page chrome rather than source code.
# import spaces
import os
os.environ["KERAS_BACKEND"] = "torch" # "jax", "torch" or "tensorflow"
import keras_nlp
import keras
import torch
import chess
import chess.svg
import time
class Game:
    """Chess game between a human player and a Gemma causal language model.

    The instance owns the board state (``chess.Board``), the list of moves
    played so far in algebraic notation, and an optional scripted opening
    that Gemma follows before it starts generating its own moves.
    """

    # Number of retries granted after a rejected Gemma move before giving up.
    MAX_RETRIES = 10

    # Prompt layout sent to the model; ``response`` is left empty so the model
    # completes it.
    PROMPT_TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"

    def __init__(self):
        """Create a fresh board and load and compile the Gemma model."""
        # Initialize the chess board
        self.board = chess.Board()
        self.sequence = []          # moves played so far, as entered/generated
        self.counter = 0            # consecutive rejected Gemma moves
        self.arrow = None           # arrow highlighting the most recent move
        self.opening_name = None
        self.opening_moves = None   # scripted opening line, if one was loaded
        self.model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
        self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
        self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
        self.compile_model()

    def compile_model(self):
        """Compile the model so that generation uses ``self.sampler``."""
        self.model.compile(sampler=self.sampler)

    # @spaces.GPU
    def inference_gemma(self, prompt, max_length=256):
        """Return the model's generated text for ``prompt``.

        The original author notes that inference requires a GPU.
        """
        response = self.model.generate(prompt, max_length)
        return response

    def call_gemma(self, opening_move):
        """Play Gemma's move and return the board SVG, or ``None`` on failure.

        Args:
            opening_move: Move to play from the scripted opening. When falsy,
                the model is prompted with the current move sequence instead.

        Returns:
            The SVG produced by ``display_board()`` once a move is accepted,
            or ``None`` when ``MAX_RETRIES`` retries have been used up.
        """
        if opening_move:
            gemma_move = opening_move
        else:
            prompt = self.PROMPT_TEMPLATE.format(
                instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
                response="",
            )
            output = self.inference_gemma(prompt, max_length=256)
            # The move is taken to be the last space-separated token of the
            # generated text, with surrounding single quotes removed.
            gemma_move = output.split(' ')[-1].strip("'")

        if self.make_move(gemma_move):
            print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
            self.counter = 0
            return self.display_board()

        if self.counter < self.MAX_RETRIES:
            self.counter += 1
            print(self.counter)
            # Retry by asking the model for a fresh move. The rejected move
            # (scripted or generated) is deliberately not replayed, since it
            # would be rejected again.
            return self.call_gemma(None)

        print("Gemma quit...")
        return None

    def gemma_moves(self):
        """Calls Gemma to make a move, either self generated or from opening sequence"""
        if self.opening_moves and len(self.sequence) < len(self.opening_moves):
            # Still inside the scripted opening: play its next move.
            return self.call_gemma(self.opening_moves[len(self.sequence)])
        return self.call_gemma(None)

    def player_moves(self, move):
        """Apply the player's move; return ``True`` if it was accepted."""
        return self.make_move(move)

    def display_board(self):
        """Return SVG image of board state"""
        if self.arrow:
            return chess.svg.board(board=self.board, arrows=[self.arrow])
        return chess.svg.board(board=self.board)

    def make_move(self, move) -> bool:
        """Checks to see if move is valid, if so pushes move to board state

        Args:
            move: Move in standard algebraic notation (e.g. ``'e4'``).

        Returns:
            ``True`` when the move was parsed and pushed, ``False`` otherwise.
        """
        try:
            update = self.board.parse_san(move)
        except (ValueError, TypeError, AttributeError):
            # python-chess reports invalid, illegal and ambiguous SAN through
            # ValueError subclasses; TypeError/AttributeError cover non-string
            # input. Anything else is a real bug and is allowed to propagate.
            print(f"Invalid move '{move}'. Use algebraic notation (e.g., 'e4', 'Nf3', 'Bxc4') or ask Gemma for help.")
            return False

        self.board.push(update)
        self.sequence.append(move)
        self.arrow = chess.svg.Arrow(update.from_square, update.to_square, color="#0000cccc")
        return True

    def reset_board(self):
        """Reset board, move history, retry counter and arrow; return the SVG."""
        self.board = chess.Board()
        self.sequence = []
        self.counter = 0
        self.arrow = None
        return self.display_board()

    def generate_moves(self, move):
        """Generator function for one full turn of chess moves

        Yields ``(board_svg, status_message)`` tuples: the player's move, a
        "thinking" notice, then Gemma's reply. If the player's move is
        rejected, a single "Try again" tuple is yielded instead.
        """
        valid_move = self.player_moves(move)
        if not valid_move:
            print("Try again")
            yield self.display_board(), "Try again"
            return

        yield self.display_board(), f"You played: {move}"
        time.sleep(2)
        yield self.display_board(), f"Gemma is thinking...(Current Sequence: {self.sequence} {len(self.sequence)})"
        time.sleep(3)

        board_svg = self.gemma_moves()
        if board_svg is None:
            # Gemma ran out of retries: keep showing the current position and
            # do not report the player's own move as Gemma's.
            yield self.display_board(), "Gemma could not find a legal move."
        else:
            yield board_svg, f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})'

    def get_move_logs(self):
        """Return the list of moves played so far."""
        return self.sequence

    def load_opening(self, opening_name, opening_moves):
        """Store a scripted opening and return a confirmation message.

        Args:
            opening_name: Display name of the opening.
            opening_moves: Sequence of moves, indexed by the number of moves
                already played, that Gemma follows on its turns.
        """
        self.opening_name = opening_name
        self.opening_moves = opening_moves
        return f"Ok, lets play the {opening_name}! {opening_moves} Make your first move."
def main():
    """Run a console loop that applies the player's moves to a new game.

    Each line read from standard input is passed to ``Game.player_moves``
    until an input containing ``'No'`` is entered, which ends the loop.
    """
    play_match = Game()
    play_match.display_board()
    while True:
        move = input("Your move (or 'No' to end game):")
        if 'No' in move:
            # Drop the game object before leaving the loop.
            del play_match
            break
        play_match.player_moves(move)
if __name__ == '__main__':
    # Start the console game only when run as a script, not on import.
    main()