File size: 4,754 Bytes
cd064af
b1e1394
 
5138881
b1e1394
 
 
5138881
327d028
 
 
 
 
 
 
 
 
 
8e138a8
e916685
 
8e138a8
b1e1394
 
 
 
6f17538
 
 
86e4d05
cd064af
86e4d05
 
 
 
 
9c3e344
327d028
 
9c3e344
 
 
 
327d028
9c3e344
 
 
327d028
86e4d05
9c3e344
b1e1394
327d028
 
 
 
 
 
 
 
bb1246f
327d028
 
 
 
afaee35
b1e1394
9c3e344
 
 
 
327d028
 
 
 
 
b1e1394
06f7d3d
 
 
 
327d028
 
b1e1394
327d028
b1e1394
327d028
 
 
 
2927381
bdc377f
327d028
 
827af69
327d028
 
 
ab25196
 
bc1c948
327d028
 
afaee35
b1e1394
827af69
 
07e3a21
bc1c948
 
 
07e3a21
bdc377f
 
6c50855
327d028
95064f6
 
9c3e344
 
 
 
 
95064f6
 
327d028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# import spaces
import os
os.environ["KERAS_BACKEND"] = "torch"  # "jax", "torch" or "tensorflow"

import keras_nlp
import keras
import torch

import chess
import chess.svg
import time

class Game:
    """A chess match between a human player and a Gemma model.

    Holds the python-chess board, the SAN move history, and an optional
    opening line that Gemma plays from before it starts querying the model.
    """

    # Kaggle preset loaded by default for the chess Gemma 2B model.
    MODEL_ID = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
    # Instruction/response prompt format sent to the model.
    PROMPT_TEMPLATE = "Instruction:\n{instruction}\n\nResponse:\n{response}"
    # Number of extra attempts Gemma gets after producing an invalid move.
    MAX_RETRIES = 10
    # Characters that may wrap a move in the model output (quotes, list
    # brackets, separators) but never occur inside SAN notation.
    _MOVE_PUNCTUATION = "'\"[],."

    def __init__(self, model_id=MODEL_ID):
        """Create a fresh board and load the Gemma model.

        Args:
            model_id: Keras-NLP preset to load; defaults to ``MODEL_ID``.
        """
        self.board = chess.Board()
        self.sequence = []       # SAN moves played so far, in order
        self.counter = 0         # consecutive invalid Gemma moves this turn
        self.arrow = None        # arrow highlighting the last move, if any
        self.opening_name = None
        self.opening_moves = None

        self.model_id = model_id
        self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
        self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
        self.compile_model()

    def compile_model(self):
        """Compile the model with the configured Top-K sampler."""
        self.model.compile(sampler=self.sampler)

    # Presumably re-enabled together with `import spaces` when deployed on
    # Hugging Face ZeroGPU — confirm before uncommenting.
    # @spaces.GPU
    def inference_gemma(self, prompt, max_length=256):
        """Run text generation for ``prompt`` (requires a GPU in practice)."""
        return self.model.generate(prompt, max_length=max_length)

    @staticmethod
    def _parse_model_move(output):
        """Extract the predicted move (last whitespace-separated token) from raw model output.

        Returns an empty string if the output contains no tokens.
        """
        tokens = output.split()
        if not tokens:
            return ""
        return tokens[-1].strip(Game._MOVE_PUNCTUATION)

    def call_gemma(self, opening_move=None):
        """Play one Gemma move and return the updated board SVG.

        Args:
            opening_move: SAN move from the loaded opening line to play
                instead of querying the model, or None to ask the model.

        Returns:
            The board SVG after Gemma's move, or None if Gemma did not
            produce a legal move within ``MAX_RETRIES`` retries.
        """
        if opening_move:
            gemma_move = opening_move
        else:
            prompt = self.PROMPT_TEMPLATE.format(
                instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
                response="",
            )
            output = self.inference_gemma(prompt, max_length=256)
            gemma_move = self._parse_model_move(output)

        if self.make_move(gemma_move):
            print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
            self.counter = 0
            return self.display_board()

        if self.counter < self.MAX_RETRIES:
            self.counter += 1
            print(self.counter)
            # Retry by querying the model: an opening move that was illegal
            # here (e.g. the player left the opening line) would fail again.
            return self.call_gemma(None)

        print("Gemma quit...")
        # Reset so the next turn starts with a full set of retries.
        self.counter = 0
        return None

    def gemma_moves(self):
        """Calls Gemma to make a move, either self generated or from opening sequence"""
        if self.opening_moves and len(self.sequence) < len(self.opening_moves):
            return self.call_gemma(self.opening_moves[len(self.sequence)])
        return self.call_gemma(None)

    def player_moves(self, move):
        """Apply the human player's SAN move; returns True if it was legal."""
        return self.make_move(move)

    def display_board(self):
        """Return SVG image of board state, with the last move highlighted."""
        arrows = [self.arrow] if self.arrow else []
        return chess.svg.board(board=self.board, arrows=arrows)

    def make_move(self, move):
        """Checks to see if move is valid, if so pushes move to board state.

        Returns True if the move was played, False otherwise.
        """
        try:
            update = self.board.parse_san(move)
        except (ValueError, TypeError):
            # python-chess signals invalid/illegal/ambiguous SAN with
            # ValueError subclasses; TypeError covers non-string input.
            update = None
        # parse_san also accepts null-move notation ("--", "0000", ...);
        # reject it so neither side can pass its turn.
        if update is None or update == chess.Move.null():
            print(f"Invalid move '{move}'. Use algebraic notation (e.g., 'e4', 'Nf3', 'Bxc4') or ask Gemma for help.")
            return False
        self.board.push(update)
        self.sequence.append(move)
        self.arrow = chess.svg.Arrow(update.from_square, update.to_square, color="#0000cccc")
        return True

    def reset_board(self):
        """Start a new game (the loaded opening is kept) and return the board SVG."""
        self.board = chess.Board()
        self.sequence = []
        self.counter = 0
        self.arrow = None
        return self.display_board()

    def generate_moves(self, move):
        """Generator function for one full turn of chess moves.

        Yields (board_svg, status_message) pairs: the player's move, a
        "thinking" update, then Gemma's reply (or why there is none).
        """
        if not self.player_moves(move):
            print("Try again")
            yield self.display_board(), "Try again"
            return

        yield self.display_board(), f"You played: {move}"
        # Don't ask Gemma for a move when the player's move ended the game.
        if self.board.is_game_over():
            yield self.display_board(), f"Game over: {self.board.result()}"
            return

        time.sleep(2)
        yield self.display_board(), f"Gemma is thinking...(Current Sequence: {self.sequence} {len(self.sequence)})"
        time.sleep(3)
        board_svg = self.gemma_moves()
        if board_svg is None:
            # Gemma exhausted its retries; keep showing the current board.
            yield self.display_board(), "Gemma quit..."
            return
        yield board_svg, f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})'

    def get_move_logs(self):
        """Return the list of SAN moves played so far."""
        return self.sequence

    def load_opening(self, opening_name, opening_moves):
        """Set the opening line Gemma follows and return a greeting message."""
        self.opening_name = opening_name
        self.opening_moves = opening_moves
        return f"Ok, lets play the {opening_name}! {opening_moves} Make your first move."
        

def main():
    """Run a console loop that applies the user's SAN moves to a new game.

    Entering text that contains 'No', or closing stdin, ends the game.
    The ASCII board is printed before the first prompt and after each input.
    """
    play_match = Game()
    print(play_match.board)

    while True:
        try:
            move = input("Your move (or 'No' to end game):")
        except EOFError:
            # stdin was closed (e.g. piped input ran out): stop cleanly.
            break
        if 'No' in move:
            break
        play_match.player_moves(move)
        print(play_match.board)

if __name__ == "__main__":
    # Start the console game only when executed as a script, not on import.
    main()