# import spaces
import os

# The Keras backend must be selected before keras / keras_nlp are imported.
os.environ["KERAS_BACKEND"] = "torch"  # "jax", "torch" or "tensorflow"

import keras_nlp
import keras
import torch
import chess
import chess.svg
import time


class Game:
    """Chess game state paired with a fine-tuned Gemma model that picks moves."""

    def __init__(self):
        """Create a fresh board, reset the bookkeeping and load the model."""
        # Initialize the chess board
        self.board = chess.Board()
        # Moves played so far; interpolated into the model prompt.
        self.sequence = []
        # Number of consecutive rejected move attempts by Gemma.
        self.counter = 0
        self.arrow = None
        self.opening_name = None
        self.opening_moves = None

        self.model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
        self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
        self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
        self.compile_model()

    def compile_model(self):
        """Compile the model so that generation uses ``self.sampler``."""
        self.model.compile(sampler=self.sampler)

    # @spaces.GPU
    def inference_gemma(self, prompt, max_length=256):
        """Inference requires GPU.

        Args:
            prompt: Text passed to the model's ``generate`` call.
            max_length: Forwarded to ``generate`` as its second positional
                argument.

        Returns:
            Whatever ``self.model.generate`` returns for ``prompt``.
        """
        response = self.model.generate(prompt, max_length)
        return response

    def call_gemma(self, opening_move=None):
        """Play one move for Gemma, retrying when the move is rejected.

        Args:
            opening_move: Move to play directly. When falsy, the model is
                prompted with the current move sequence and the last
                space-separated token of its output is used as the move.

        Returns:
            The result of ``self.display_board()`` once ``self.make_move``
            accepts a move, or ``None`` after the retry budget is exhausted.
        """
        if opening_move:
            gemma_move = opening_move
        else:
            template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
            prompt = template.format(
                instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
                response="",)
            output = self.inference_gemma(prompt, max_length=256)
            # The move is taken to be the last token of the generated text.
            gemma_move = output.split(' ')[-1].strip("'")

        if self.make_move(gemma_move):
            print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
            self.counter = 0
            return self.display_board()

        if self.counter < 10:
            self.counter += 1
            print(self.counter)
            # Retry with a freshly generated move. The original call omitted
            # the required argument and raised TypeError on the first retry.
            return self.call_gemma(None)

        print("Gemma quit...")
        return None

    # NOTE(review): the source visible to this review is cut off inside
    # `gemma_moves`. The visible fragment is preserved verbatim below and must
    # be restored from the complete file rather than reconstructed by guesswork.
    #
    # def gemma_moves(self):
    #     """Calls Gemma to make a move, either self generated or from opening sequence"""
    #     if self.opening_moves and len(self.sequence)