File size: 4,581 Bytes
5138881
 
 
9c3e344
 
 
5138881
327d028
 
 
 
 
 
 
 
 
 
8e138a8
 
9c3e344
 
 
 
6f17538
 
 
327d028
9c3e344
327d028
 
9c3e344
 
 
 
327d028
9c3e344
 
 
327d028
9c3e344
 
 
327d028
 
 
 
 
 
 
 
bb1246f
327d028
 
 
 
afaee35
bc1c948
 
9c3e344
 
 
 
327d028
 
 
 
 
 
 
 
06f7d3d
 
 
 
327d028
 
 
 
 
 
 
 
 
 
2927381
bdc377f
327d028
 
827af69
327d028
 
 
ab25196
 
bc1c948
327d028
 
 
afaee35
827af69
 
07e3a21
bc1c948
 
 
07e3a21
bdc377f
 
6c50855
327d028
95064f6
 
9c3e344
 
 
 
 
 
95064f6
 
327d028
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
# import os
# os.environ["KERAS_BACKEND"] = "torch"  # "jax", "torch" or "tensorflow"

# import keras_nlp
# import keras
# import torch

import chess
import chess.svg
import time

class Game:
    """Chess game between a human player and the Gemma model.

    The position lives in a ``chess.Board``. Every accepted move is also
    recorded, exactly as it was typed/generated, in ``self.sequence``; that
    list is what the model is prompted with. Board renderings are returned as
    SVG strings produced by ``chess.svg.board``.
    """

    def __init__(self):
        # Initialize the chess board
        self.board = chess.Board()
        self.sequence = []  # moves played so far, in play order
        self.counter = 0  # consecutive failed attempts by Gemma on this turn
        self.arrow = None  # arrow marking the last move, drawn by display_board

        # Opening state, filled in by load_opening. Initialised here so that
        # gemma_moves works even when no opening has been loaded.
        self.opening = False
        self.opening_name = None
        self.opening_moves = []

        # Model loading is currently disabled. While ``self.model`` is absent,
        # Gemma can only play moves taken from a loaded opening.
        # self.model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
        # self.sampler = keras_nlp.samplers.TopKSampler(k=50, temperature=0.7)
        # self.model = keras_nlp.models.GemmaCausalLM.from_preset(self.model_id)
        # self.compile_model()

    def compile_model(self):
        """Compile ``self.model`` with ``self.sampler``.

        Both attributes are only set by the (currently commented-out) model
        loading code in ``__init__``.
        """
        self.model.compile(sampler=self.sampler)

    def _generate_move(self):
        """Ask the model for the next move; return None if no model is loaded."""
        model = getattr(self, "model", None)
        if model is None:
            return None

        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        prompt = template.format(
            instruction=f"Predict the next chess move in the sequence {str(self.sequence)}",
            response="",
        )
        output = model.generate(prompt, max_length=256)
        # Take the last space-separated token, stripped of quotes, as the
        # predicted move (assumes the model appends the move at the end of
        # its response -- TODO confirm against the fine-tuning format).
        return output.split(' ')[-1].strip("'")

    def call_gemma(self, opening_move=None, max_retries=10):
        """Have Gemma make its move.

        Args:
            opening_move: move to play instead of querying the model (used to
                follow a loaded opening). If it is None, or turns out to be
                illegal, the model is queried instead.
            max_retries: number of extra attempts Gemma gets after producing
                an illegal move before giving up.

        Returns:
            The board SVG after Gemma's move, or None if Gemma gave up
            (including when no model is loaded and no opening move applies).
        """
        move = opening_move
        while True:
            if move is None:
                move = self._generate_move()
                if move is None:
                    # No model loaded: retrying cannot yield a different answer.
                    break
            if self.make_move(move):
                print(f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})')
                self.counter = 0
                return self.display_board()
            if self.counter >= max_retries:
                break
            self.counter += 1
            print(self.counter)
            # Retry by sampling the model again. Replaying the same fixed
            # opening move would just fail the same way.
            move = None

        print("Gemma quit...")
        # Reset so that the next turn starts with a fresh retry budget.
        self.counter = 0
        return None

    def gemma_moves(self):
        """Play Gemma's turn, following the loaded opening while it lasts.

        Returns the board SVG after the move, or None if Gemma gave up.
        """
        opening = self.opening_moves
        ply = len(self.sequence)
        if opening and ply < len(opening):
            return self.call_gemma(opening[ply])
        return self.call_gemma(None)

    def player_moves(self, move):
        """Play the human's *move*; return True if it was legal and played."""
        return self.make_move(move)

    def display_board(self):
        """Return the current position as an SVG string, highlighting the last move."""
        arrows = [self.arrow] if self.arrow else []
        return chess.svg.board(board=self.board, arrows=arrows)

    def make_move(self, move):
        """Parse *move* as SAN and play it if it is legal.

        On success the move is pushed onto the board, appended to
        ``self.sequence`` and marked with an arrow, and True is returned.
        Otherwise a hint is printed and False is returned.
        """
        try:
            update = self.board.parse_san(move)
        except (ValueError, TypeError):
            # python-chess signals invalid/illegal/ambiguous SAN with
            # ValueError subclasses; TypeError covers non-string input.
            update = None

        # parse_san also accepts null moves ("--", "0000", ...), which are
        # falsy. Reject them so that a turn cannot simply be skipped.
        if not update:
            print(f"Invalid move '{move}'. Use algebraic notation (e.g., 'e4', 'Nf3', 'Bxc4') or ask Gemma for help.")
            return False

        self.board.push(update)
        self.sequence.append(move)
        self.arrow = chess.svg.Arrow(update.from_square, update.to_square, color="#0000cccc")
        return True

    def reset_board(self):
        """Start a new game (any loaded opening is kept) and return the empty board SVG."""
        self.board = chess.Board()
        self.sequence = []
        self.counter = 0
        self.arrow = None
        return self.display_board()

    def generate_moves(self, move):
        """Play the human's *move*, then Gemma's reply, yielding UI updates.

        Yields ``(board_svg, status_message)`` tuples. Pauses between updates
        so that a UI can show the intermediate states.
        """
        if not self.player_moves(move):
            print("Try again")
            yield self.display_board(), "Try again"
            return

        yield self.display_board(), f"You played: {move}"
        if self.board.is_game_over():
            # No reply is possible once the game has ended.
            yield self.display_board(), f"Game over: {self.board.result()}"
            return

        time.sleep(2)
        yield self.display_board(), f"Gemma is thinking...(Current Sequence: {self.sequence} {len(self.sequence)})"
        time.sleep(3)

        board_svg = self.gemma_moves()
        if board_svg is None:
            # Gemma gave up. Keep showing the current board rather than None,
            # and do not report the player's move as Gemma's.
            yield self.display_board(), "Gemma quit..."
            return
        yield board_svg, f'Gemma plays {self.sequence[-1]}! (Current Sequence: {self.sequence} {len(self.sequence)})'

    def get_move_logs(self):
        """Return the list of moves played so far."""
        return self.sequence

    def load_opening(self, opening_name, opening_moves):
        """Load an opening for Gemma to follow and return a greeting message."""
        self.opening = True
        self.opening_name = opening_name
        self.opening_moves = opening_moves
        return f"Ok, lets play the {opening_name}! {opening_moves} Make your first move."
        

def main():
    """Console loop: read player moves until the input contains 'No'."""
    game = Game()
    game.display_board()  # the returned SVG string is discarded in console mode

    while True:
        reply = input("Your move (or 'No' to end game):")
        if 'No' in reply:
            break
        game.player_moves(reply)


if __name__ == '__main__':
    main()