File size: 3,496 Bytes
5138881
 
 
 
 
 
 
327d028
 
 
 
 
 
 
 
 
 
afaee35
33bfe50
6f17538
4727d72
6f17538
 
 
327d028
afaee35
327d028
 
 
 
 
 
 
f510d64
327d028
d6e4b42
327d028
d6e4b42
327d028
 
 
 
 
 
 
 
bb1246f
327d028
 
 
 
afaee35
327d028
 
afaee35
327d028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bdc377f
327d028
 
827af69
327d028
 
 
ab25196
 
327d028
 
 
afaee35
827af69
 
07e3a21
 
bdc377f
 
827af69
327d028
95064f6
 
 
 
327d028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
# import os
# os.environ["KERAS_BACKEND"] = "torch"  # "jax", "torch" or "tensorflow"

import keras_nlp
import keras
import torch

import chess
import chess.svg
import time

class Game:
    """A chess match between a human player and a fine-tuned Gemma model.

    Holds the python-chess board, the list of moves accepted so far (stored
    exactly as the strings passed to ``make_move``), and the Keras-NLP Gemma
    causal LM that proposes the opponent's moves.
    """

    def __init__(self):
        """Create an empty board and load + compile the Gemma chess model.

        Loading the ``kaggle://`` preset is presumably a network download on
        first use, so construction can be slow.
        """
        self.board = chess.Board()
        self.sequence = []  # accepted moves, in the order they were pushed
        self.counter = 0  # retry counter used by call_gemma
        self.model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
        self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
        self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
        self.compile_model()

    def compile_model(self):
        """Compile the model so ``generate`` uses the configured Top-K sampler."""
        self.model.compile(sampler=self.sampler)

    @staticmethod
    def _extract_move(output):
        """Pull the candidate move out of raw model output.

        Takes the last whitespace-separated token (``generate`` echoes the
        prompt, so the model's answer comes last). It then strips surrounding
        quotes, list brackets, commas and periods, none of which can be part
        of a SAN move. Returns ``""`` for empty/blank output.
        """
        tokens = output.split()
        if not tokens:
            return ""
        return tokens[-1].strip("'\"[],.")

    def call_gemma(self, max_retries=10):
        """Ask Gemma for the next move and play it on the board.

        The model is sampled again (Top-K, so outputs differ) until it yields
        a legal move. At most ``max_retries`` extra attempts are made after
        the first one.

        Args:
            max_retries: Number of retries after the first failed attempt.

        Returns:
            The board SVG after Gemma's move, or ``None`` if every attempt
            produced an invalid move.
        """
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        prompt = template.format(
            instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
            response="",
        )

        # Start every request from a clean counter; previously a "quit" left
        # it at the limit, so later calls gave up after a single attempt.
        self.counter = 0
        while True:
            output = self.model.generate(prompt, max_length=256)
            gemma_move = self._extract_move(output)

            if self.make_move(gemma_move):
                print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
                self.counter = 0
                return self.display_board()

            if self.counter >= max_retries:
                print("Gemma quit...")
                self.counter = 0
                return None

            self.counter += 1
            print(self.counter)

    def gemma_moves(self):
        """Announce that Gemma is thinking, pause, then let Gemma move.

        Returns:
            The board SVG after Gemma's move, or ``None`` if Gemma gave up.
        """
        print(f"Gemma is thinking...(Current Sequence: {self.sequence} {len(self.sequence)})")
        time.sleep(3)  # presumably a UX pause; not needed for correctness
        return self.call_gemma()

    def player_moves(self, move):
        """Apply the human player's SAN move; return True if it was legal."""
        return self.make_move(move)

    def display_board(self):
        """Return the current position rendered as an SVG string."""
        return chess.svg.board(board=self.board)

    def make_move(self, move):
        """Parse ``move`` as SAN and push it onto the board.

        On success the move string is appended to ``self.sequence``.

        Returns:
            True if the move was legal and played, False otherwise (a hint is
            printed in that case).
        """
        try:
            update = self.board.parse_san(move)
        except ValueError:
            # python-chess raises ValueError subclasses (invalid, illegal or
            # ambiguous SAN); anything else is a real error and propagates.
            print(f"Invalid move '{move}'. Use algebraic notation (e.g., 'e4', 'Nf3', 'Bxc4') or ask Gemma for help.")
            return False
        self.board.push(update)
        self.sequence.append(move)
        return True

    def reset_board(self):
        """Start a new game and return the empty-board SVG."""
        self.board = chess.Board()
        self.sequence = []
        self.counter = 0
        return self.display_board()

    def generate_moves(self, move):
        """Play one full turn, yielding ``(board_svg, message)`` updates.

        Yields the board after the player's move, then the board after
        Gemma's reply. An invalid player move yields a single update and
        stops. So does a move that ends the game, or Gemma giving up.
        """
        if not self.player_moves(move):
            print("Try again")
            yield self.display_board(), f"Invalid move '{move}'. Try again."
            return

        yield self.display_board(), f"You played: {move}"

        if self.board.is_game_over():
            # No legal reply exists; asking the model would only burn retries.
            yield self.display_board(), f"Game over: {self.board.result()}"
            return

        board_svg = self.gemma_moves()
        if board_svg is None:
            # sequence[-1] is still the player's move here, so don't credit it to Gemma.
            yield self.display_board(), "Gemma quit..."
            return

        yield board_svg, f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})'

    def get_move_logs(self):
        """Return the list of moves played so far."""
        return self.sequence
        

def main():
    """Run an interactive command-line match against Gemma.

    Reads SAN moves from stdin. After each legal player move Gemma replies
    and the ASCII board is printed. The session ends in any of these cases:
    - the input contains ``'No'``;
    - stdin reaches EOF;
    - the game is over;
    - Gemma fails to produce a legal move.
    """
    play_match = Game()
    print(play_match.board)  # ASCII board; the SVG form is only useful in a UI

    while not play_match.board.is_game_over():
        try:
            move = input("Your move (or 'No' to end game):")
        except EOFError:
            break
        if 'No' in move:
            break
        if not play_match.player_moves(move):
            continue  # make_move already printed why the move was rejected
        if play_match.board.is_game_over():
            break
        if play_match.gemma_moves() is None:
            break  # Gemma could not come up with a legal move
        print(play_match.board)

    if play_match.board.is_game_over():
        print(play_match.board)
        print(f"Game over: {play_match.board.result()}")


if __name__ == '__main__':
    main()