File size: 3,169 Bytes
5138881
 
 
 
 
 
 
327d028
 
 
 
 
 
 
 
 
 
afaee35
33bfe50
6f17538
4727d72
6f17538
 
 
327d028
afaee35
327d028
 
 
 
 
 
 
f510d64
327d028
d6e4b42
327d028
d6e4b42
327d028
 
 
 
 
 
 
 
bb1246f
327d028
 
 
 
afaee35
327d028
 
afaee35
327d028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab25196
 
327d028
 
 
afaee35
327d028
0f0db04
327d028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# import os
# os.environ["KERAS_BACKEND"] = "torch"  # "jax", "torch" or "tensorflow"

import keras_nlp
import keras
import torch

import chess
import chess.svg
import time

class Game:
    """A chess game between a human player and a Gemma causal language model.

    ``self.board`` holds the position (a ``chess.Board``) and ``self.sequence``
    the SAN moves played so far; that sequence is what the model is prompted
    with when it is asked for its next move.
    """

    # keras_nlp preset loaded when no ``model_id`` is supplied.
    DEFAULT_MODEL_ID = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
    # Instruction/response layout used to prompt the model.
    PROMPT_TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"

    def __init__(self, model_id=DEFAULT_MODEL_ID, sampler=None):
        """Create a fresh board, then load and compile the Gemma model.

        Args:
            model_id: keras_nlp preset passed to ``GemmaCausalLM.from_preset``.
            sampler: keras_nlp sampler used for generation; defaults to
                top-k sampling with ``k=50`` and ``temperature=0.7``.
        """
        self.board = chess.Board()
        self.sequence = []
        self.counter = 0  # failed generation attempts during the current turn
        self.model_id = model_id
        if sampler is None:
            sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
        self.sampler = sampler
        self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
        self.compile_model()

    def compile_model(self):
        """Compile the model so ``generate`` uses ``self.sampler``."""
        self.model.compile(sampler=self.sampler)

    @staticmethod
    def _extract_move(output):
        """Return the candidate move in the model's ``output``, or None.

        ``generate`` output may echo the prompt, so only the text after the
        last "Response:" marker is considered (the whole text if the marker is
        absent). The last whitespace-separated token is taken, and the quote,
        bracket, comma and period characters that a list-like reply such as
        ``['e4', 'e5']`` would leave around it are stripped.
        """
        response = output.rpartition("Response:")[2]
        # split() (not split(' ')) so newlines from the template separate tokens.
        tokens = response.split()
        if not tokens:
            return None
        move = tokens[-1].strip("'\"[],.")
        return move or None

    def call_gemma(self, max_retries=10):
        """Ask the model for the next move and play it.

        The model is re-queried with the same prompt until it yields a legal
        move, allowing up to ``max_retries`` extra attempts after the first.

        Returns:
            The board SVG after Gemma's move, or None if every attempt failed.
        """
        prompt = self.PROMPT_TEMPLATE.format(
            instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
            response="",
        )
        for attempt in range(max_retries + 1):
            output = self.model.generate(prompt, max_length=256)
            gemma_move = self._extract_move(output)
            if gemma_move is not None and self.make_move(gemma_move):
                print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
                self.counter = 0
                return self.display_board()
            if attempt < max_retries:
                self.counter = attempt + 1
                print(self.counter)
        # Reset so the next turn starts with the full retry budget again.
        self.counter = 0
        print("Gemma quit...")
        return None

    def gemma_moves(self):
        """Let Gemma play a move for the side to move.

        Returns:
            The board SVG after Gemma's move, or None if the game is already
            over or Gemma failed to produce a legal move.
        """
        if self.board.is_game_over():
            # No legal moves exist; querying the model would only waste calls.
            print(f"Game over ({self.board.result()}).")
            return None
        print(f"Gemma is thinking...(Current Sequence: {self.sequence} {len(self.sequence)})")
        time.sleep(3)
        return self.call_gemma()

    def player_moves(self, move):
        """Play the human's SAN ``move``; returns the board SVG or None if rejected."""
        return self.make_move(move)

    def display_board(self):
        """Return the current position rendered as SVG markup."""
        return chess.svg.board(board=self.board)

    def make_move(self, move):
        """Play ``move`` (SAN, e.g. 'e4', 'Nf3', 'Bxc4') if it is legal.

        On success the move is recorded in ``self.sequence`` in canonical SAN,
        so the prompt sent to the model stays consistently formatted.

        Returns:
            The board SVG after the move, or None if the move was rejected.
        """
        try:
            update = self.board.parse_san(move.strip())
        except ValueError:
            # python-chess's illegal/invalid/ambiguous SAN errors are ValueErrors.
            update = None
        # parse_san also accepts null moves ("--", "0000"); pushing one would
        # silently pass the turn, so treat it as invalid.
        if update is None or update == chess.Move.null():
            print(f"Invalid move '{move}'. Use algebraic notation (e.g., 'e4', 'Nf3', 'Bxc4') or ask Gemma for help.")
            return None
        san = self.board.san(update)  # must be computed before the push
        self.board.push(update)
        self.sequence.append(san)
        return self.display_board()

    def reset_board(self):
        """Start a new game and return the initial board SVG."""
        self.board = chess.Board()
        self.sequence = []
        self.counter = 0
        return self.display_board()

    def generate_moves(self, move):
        """Yield the board after the player's move, then after Gemma's reply.

        If the player's move is rejected, only ``None`` is yielded; Gemma must
        not move then, or it would play the player's colour.
        """
        board_svg = self.player_moves(move)
        yield board_svg
        if board_svg is None:
            return
        yield self.gemma_moves()

def main():
    """Run an interactive terminal game: you move, then Gemma replies.

    Enter moves in SAN. Any input containing 'No', or end-of-input, ends the
    game.
    """
    play_match = Game()
    # display_board() returns SVG markup; a terminal needs the ASCII board.
    print(play_match.board)

    while not play_match.board.is_game_over():
        try:
            move = input("Your move (or 'No' to end game):")
        except EOFError:
            break
        if 'No' in move:
            break
        if play_match.player_moves(move) is None:
            continue  # rejected; make_move already printed the reason
        print(play_match.board)
        if play_match.gemma_moves() is None:
            break  # game over, or Gemma could not produce a legal move
        print(play_match.board)

    if play_match.board.is_game_over():
        print(f"Game over: {play_match.board.result()}")


if __name__ == '__main__':
    main()