File size: 4,181 Bytes
e7f5c27
 
2b49a9b
 
 
03335e7
003246a
 
e7f5c27
003246a
03335e7
2b49a9b
d8fec69
03335e7
a8472ea
9c3e344
003246a
03335e7
ebf05e2
003246a
 
03335e7
 
2b49a9b
501eb80
 
2b49a9b
9c3e344
 
501eb80
 
 
 
 
 
 
 
 
 
 
03335e7
2b49a9b
003246a
 
ebf05e2
003246a
 
 
 
 
 
 
 
ebf05e2
 
04bd3a7
2b49a9b
 
 
 
 
 
003246a
2b49a9b
 
 
0a0ab7b
2b49a9b
 
 
 
9c3e344
 
 
 
 
 
 
 
 
 
003246a
9c3e344
2b49a9b
 
 
 
b881822
 
 
2b49a9b
 
 
03335e7
 
2ca989c
2739b6b
2b49a9b
 
 
 
be93a92
3172baf
 
2b49a9b
 
 
9c3e344
2ca989c
9c3e344
 
 
2b49a9b
9c3e344
 
 
 
2b49a9b
9c3e344
 
 
03335e7
9c3e344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03335e7
2b49a9b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import spaces

import os
os.environ["KERAS_BACKEND"] = "torch"  # "jax", "torch" or "tensorflow"

import gradio as gr
import keras_nlp
import keras
# import spaces
import torch

from typing import Iterator
import time

from chess_board import Game
from datasets import load_dataset
import google.generativeai as genai


print(f"Is CUDA available: {torch.cuda.is_available()}")
# Only query the device when CUDA is present: torch.cuda.current_device()
# raises on a CPU-only host, which would otherwise crash the app at import.
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")


# Markdown introduction shown at the top of the app (rendered with gr.Markdown).
# The two trailing spaces after "Enjoy your game!" are a Markdown hard line break.
DESCRIPTION = """
# Chess Tutor AI
**Welcome to the Chess Chatbot!**

The goal of this project is to showcase the use of AI in learning chess. This app allows you to play a game against a custom fine-tuned model (Gemma 2B).\n
The challenge is that input must be in *algebraic notation*.

## Features

### For New & Beginner Players
- The chat interface uses the Gemini API. If you need help with chess rules or learning algebraic notation, just ask!

### For Advanced Users
- Pick an opening to play, and ask Gemini for more info.

Enjoy your game!  
**- Valentin**
"""

# Gemini backs the tutoring chat; the key comes from the GEMINI_API_KEY
# environment variable.
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
    # Report the misconfiguration at startup rather than letting it surface
    # later as an opaque error on the first chat request.
    print("Warning: GEMINI_API_KEY is not set; Gemini chat requests may fail.")
genai.configure(api_key=api_key)

model = genai.GenerativeModel(model_name='gemini-1.5-flash-latest')

# NOTE: a single ChatSession created at import time. Anything that sends
# messages through it shares one conversation history across all users.
chat = model.start_chat()

# Lichess opening catalogue, loaded once at startup and converted to a
# DataFrame. The opening lookups below query it by its 'name' and 'pgn' columns.
ds = load_dataset(path="Lichess/chess-openings", split="train")
df = ds.to_pandas()

# Distinct opening names in first-seen order, used as the dropdown choices.
opening_names = df["name"].drop_duplicates().tolist()


def _to_gemini_history(chat_history: list[dict]) -> list[dict]:
    """Convert Gradio "messages"-format history into Gemini chat history.

    Gradio's ``"assistant"`` role becomes Gemini's ``"model"``; every other
    role is sent as ``"user"``. Entries whose content is not a plain string
    (e.g. file attachments) are skipped, and consecutive entries with the same
    role are merged into a single turn so user/model turns keep alternating.
    """
    history: list[dict] = []
    for turn in chat_history or []:
        content = turn.get("content")
        if not isinstance(content, str):
            continue
        role = "model" if turn.get("role") == "assistant" else "user"
        if history and history[-1]["role"] == role:
            history[-1]["parts"].append(content)
        else:
            history.append({"role": role, "parts": [content]})
    return history


# @spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    ) -> Iterator[str]:
    """Stream Gemini's reply to ``message`` for the chat interface.

    A fresh Gemini chat session is seeded from ``chat_history`` on every call.
    Each visitor therefore only sees their own conversation, and clearing the
    chat in the UI really clears the context the model receives.

    Args:
        message: The user's latest chat message.
        chat_history: Previous turns as passed by ``gr.ChatInterface`` with
            ``type="messages"`` (dicts with ``"role"`` and ``"content"`` keys).
        max_new_tokens: Currently unused; kept for a stable signature.

    Yields:
        The reply text accumulated so far, growing with each streamed chunk.
    """
    session = model.start_chat(history=_to_gemini_history(chat_history))
    # stream=True delivers the reply in chunks, so the UI updates as text
    # arrives instead of once per character after the full reply is ready.
    response = session.send_message(message, stream=True)

    outputs = ""
    for chunk in response:
        outputs += chunk.text
        yield outputs


def get_opening_details(opening_name, openings=None):
    """Return a two-line text summary of the named opening.

    Args:
        opening_name: Value of the opening dropdown; may be ``None`` when
            nothing is selected.
        openings: Table with ``name`` and ``pgn`` columns to search. Defaults
            to the module-level Lichess openings frame ``df``.

    Returns:
        "Opening: <name>" and "Moves: <pgn>" joined by a newline, taken from
        the first row whose ``name`` matches (names can repeat). Returns an
        empty string when no row matches.
    """
    table = df if openings is None else openings
    matches = table[table["name"] == opening_name]
    if matches.empty:
        return ""
    opening_data = matches.iloc[0]
    return f"Opening: {opening_data['name']}\nMoves: {opening_data['pgn']}"

def get_move_list(opening_name, openings=None):
    """Return the SAN moves of the named opening as a list of strings.

    Args:
        opening_name: Value of the opening dropdown; may be ``None``.
        openings: Table with ``name`` and ``pgn`` columns to search. Defaults
            to the module-level Lichess openings frame ``df``.

    Returns:
        The moves of the first row whose ``name`` matches, with move-number
        tokens removed (e.g. ``"1. e4 e5 2. Nf3"`` gives ``["e4", "e5", "Nf3"]``).
        Returns an empty list when no row matches.
    """
    import re

    table = df if openings is None else openings
    matches = table[table["name"] == opening_name]
    if matches.empty:
        return []
    pgn = matches.iloc[0]["pgn"]
    # Strip move-number prefixes ("1.", "12.", "3...") instead of assuming every
    # third token is a number, so compact PGN such as "1.e4 e5" works too.
    stripped = (re.sub(r"^\d+\.+", "", token) for token in pgn.split())
    return [move for move in stripped if move]
   

# Chat panel for the Gemini tutor: every user message is handled by `generate`.
# Built here and placed inside the page layout later via chat_interface.render().
chat_interface = gr.ChatInterface(
    fn=generate,
    # NOTE(review): presumably meant to hide the stop button; confirm against
    # the installed Gradio version.
    stop_btn=None,
    # Clickable starter prompts shown under the chat box.
    examples=[
        ["Hi Gemini, what is a good first move in chess?"],
        ["How does the Knight move?"],
        ["Explain algebraic notation for capturing a piece in chess?"]
    ],
    # Examples are not pre-computed at startup.
    cache_examples=False,
    # History is passed to `generate` as role/content message dicts.
    type="messages",
)

    
with gr.Blocks(css_paths="styles.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)

    # NOTE(review): one Game instance is created when the UI is built and its
    # bound methods serve as callbacks, so all visitors appear to share a
    # single board. Confirm whether per-session state (gr.State) is intended.
    play_match = Game()

    # Top row: the chess board next to the Gemini chat panel.
    with gr.Row():
        with gr.Column():
            board_image = gr.HTML(play_match.display_board())
        with gr.Column():
            chat_interface.render()

    game_logs = gr.Label(label="Game Logs", elem_classes=["big-text"])

    with gr.Row():
        # Bottom-left: play moves against the Gemma model.
        with gr.Column():
            gr.Markdown("### Play a Match vs Gemma")

            move_input = gr.Textbox(label="Enter your move in algebraic notation: (e.g., e4, Nf3, Bxc4)")
            submit_move = gr.Button("Submit Move")
            # Play the move, then clear the textbox. The clearing callback
            # takes no parameters because the event has no inputs; a
            # one-argument lambda would be called with zero arguments and fail.
            submit_move.click(
                play_match.generate_moves, inputs=move_input, outputs=[board_image, game_logs]
            ).then(lambda: gr.update(value=''), outputs=[move_input])

            reset_board = gr.Button("Reset Game")
            reset_board.click(
                play_match.reset_board, outputs=board_image
            ).then(lambda: gr.update(value=''), outputs=[game_logs])

        # Bottom-right: browse the Lichess openings and load one onto the board.
        with gr.Column():
            gr.Markdown("### Chess Openings Explorer")

            opening_choice = gr.Dropdown(label="Choose a Chess Opening", choices=opening_names)
            opening_output = gr.Textbox(label="Opening Details", lines=4)
            # Holds the move list of the selected opening for "Load Opening".
            opening_moves = gr.State()

            opening_choice.change(fn=get_opening_details, inputs=opening_choice, outputs=opening_output)
            opening_choice.change(fn=get_move_list, inputs=opening_choice, outputs=opening_moves)

            load_opening = gr.Button("Load Opening")
            # The reset must finish before the opening is loaded. Two separate
            # listeners on the same click are not guaranteed to run in order,
            # so chain them.
            load_opening.click(
                play_match.reset_board, outputs=board_image
            ).then(play_match.load_opening, inputs=[opening_choice, opening_moves], outputs=game_logs)
    
if __name__ == "__main__":
    # Enable request queuing (at most 20 pending events), then start the server.
    app = demo.queue(max_size=20)
    app.launch()