File size: 3,362 Bytes
2b49a9b
 
 
03335e7
2b49a9b
 
 
 
03335e7
2b49a9b
d8fec69
03335e7
a8472ea
03335e7
ebf05e2
 
03335e7
2b49a9b
 
03335e7
2b49a9b
03335e7
2b49a9b
 
03335e7
817e88e
ebf05e2
817e88e
aa6dea3
03335e7
 
ebf05e2
 
03335e7
2b49a9b
501eb80
 
2b49a9b
501eb80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03335e7
2b49a9b
6a5a1f2
ebf05e2
 
 
 
 
 
 
2b49a9b
 
 
 
 
 
 
ebf05e2
2b49a9b
ebf05e2
 
 
 
 
2b49a9b
ebf05e2
2b49a9b
 
 
d16e146
2b49a9b
 
 
 
 
 
 
 
3201bb7
 
2b49a9b
 
 
03335e7
 
2ca989c
7e0d668
2b49a9b
 
 
 
 
 
3172baf
 
2b49a9b
 
 
3318dc1
2ca989c
2b49a9b
 
9f1409e
a73ab00
0ee6f09
9f1409e
2ca989c
2b49a9b
 
 
 
03335e7
 
2b49a9b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
os.environ["KERAS_BACKEND"] = "torch"  # "jax", "torch" or "tensorflow"

import gradio as gr
import keras_nlp
import keras
import spaces
import torch

from typing import Iterator
import time

from chess_board import Game

import google.generativeai as genai


# Startup diagnostics. The active chat path talks to the remote Gemini API, so
# a GPU is not needed to serve it. Only query the device name when CUDA is
# present: torch.cuda.current_device() raises on hosts without a usable GPU,
# which would otherwise abort the app at import time.
print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA device: none")

# Token budgets. None of these are referenced by the active (Gemini-backed)
# code visible in this file. They appear to be left over from a disabled path
# that generated locally with keras_nlp.models.GemmaCausalLM.from_preset(...),
# using presets such as "hf://google/gemma-2b-keras",
# "hf://google/gemma-2-2b-it" or
# "kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess".
MAX_INPUT_TOKEN_LENGTH = 4_096  # prompt-length cap for the disabled local path
MAX_NEW_TOKENS = 2_048
DEFAULT_MAX_NEW_TOKENS = 128

# Markdown rendered at the top of the page via gr.Markdown in the layout below.
# NOTE: the two trailing spaces after "Enjoy your game!" are a Markdown hard
# line break; keep them when editing this text.
DESCRIPTION = """
# Chess Tutor AI
**Welcome to the Chess Chatbot!**

The goal of this project is to showcase the use of AI in learning chess. This app allows you to play a game against a custom fine-tuned model (Gemma 2B). The challenge is that input must be in *algebraic notation*.

## Features

### For New & Beginner Players
- The chat interface uses the Gemini API, if you need help with chess rules or learning algebraic notation, just ask!

### For Advanced Users
- Pick an opening to play, and ask Gemini for more info.


<br>

Enjoy your game!  
**- Valentin**
"""

# Configure the Gemini client from the environment once, at import time.
# Warn up front when the key is absent: otherwise the problem only surfaces
# as a failed request the first time someone chats.
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
    print("Warning: GEMINI_API_KEY is not set; Gemini chat requests may fail "
          "unless credentials are configured another way.")
genai.configure(api_key=api_key)

model = genai.GenerativeModel(model_name="gemini-1.5-flash-latest")

# Module-level chat session.
# NOTE(review): a single session object here would be shared by every visitor
# of the app. Prefer creating a session per request from the UI's own history.
chat = model.start_chat()

# @spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    ) -> Iterator[str]:

    # input_ids = tokenizer.tokenize(message)
    
    # if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
    #     input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
    #     gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # response = model.generate(message, max_length=max_new_tokens)

    response = chat.send_message(message)

    outputs = ""
    
    for char in response.text:
        outputs += char
        yield outputs


# Chat panel backed by generate(), which answers through the Gemini API.
# Example prompts address Gemini because that is the model replying here; the
# fine-tuned Gemma model is the chess opponent, not the chat assistant.
chat_interface = gr.ChatInterface(
    fn=generate,
    stop_btn=None,
    examples=[
        ["Hi Gemini, what is a good first move in chess?"],
        ["How does the Knight move?"],
    ],
    cache_examples=False,
    type="messages",
)

    
with gr.Blocks(css_paths="styles.css", fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)

    # NOTE(review): one Game instance is created while the UI is built, so every
    # visitor appears to play on the same board. Isolating games would need
    # per-session state (e.g. gr.State).
    play_match = Game()

    # Left column: rendered board. Right column: the chat assistant.
    with gr.Row():
        with gr.Column():
            board_image = gr.HTML(play_match.display_board())
        with gr.Column():
            chat_interface.render()

    game_logs = gr.Label(label="Game Logs", elem_id="game_logs_label")

    move_input = gr.Textbox(label="Enter your move in algebraic notation (e.g., e4, Nf3, Bxc4)")
    btn = gr.Button("Submit Move")
    # Play the submitted move, then clear the textbox. The clear step has no
    # inputs, so Gradio invokes its callback with zero arguments. The previous
    # `lambda x: ...` therefore raised TypeError and the box never cleared.
    # Chaining with .then() runs the clear after the move handler finishes.
    btn.click(
        play_match.generate_moves,
        inputs=move_input,
        outputs=[board_image, game_logs],
    ).then(lambda: gr.update(value=""), inputs=None, outputs=move_input)

    reset_btn = gr.Button("Reset Game")
    reset_btn.click(play_match.reset_board, outputs=board_image)


if __name__ == "__main__":
    # Allow at most 20 pending events in the request queue, then serve the UI.
    demo.queue(max_size=20)
    demo.launch()