File size: 3,007 Bytes
2b49a9b
 
 
03335e7
2b49a9b
 
 
 
03335e7
2b49a9b
d8fec69
03335e7
a8472ea
03335e7
 
2b49a9b
 
03335e7
2b49a9b
03335e7
2b49a9b
 
03335e7
817e88e
74cda9e
817e88e
74cda9e
03335e7
 
2b49a9b
 
03335e7
2b49a9b
 
 
 
 
 
03335e7
2b49a9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03335e7
 
2b49a9b
 
 
 
 
03335e7
 
2b49a9b
 
3201bb7
 
2b49a9b
 
 
03335e7
 
2ca989c
0ee6f09
2ca989c
0ee6f09
d8fec69
09e95ed
 
d8fec69
2ca989c
e53d595
2b49a9b
 
 
 
 
 
 
 
 
 
2ca989c
 
2b49a9b
 
17fca8c
0ee6f09
d8fec69
2ca989c
2b49a9b
 
 
 
03335e7
 
2b49a9b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import os
os.environ["KERAS_BACKEND"] = "torch"  # "jax", "torch" or "tensorflow"

import gradio as gr
import keras_nlp
import keras
import spaces
import torch

from typing import Iterator
import time

from chess_board import Game


def cuda_device_name() -> str:
    """Return the name of the current CUDA device, or a CPU fallback label.

    ``torch.cuda.current_device()`` raises when no CUDA device is usable, so the
    availability check must come first to keep CPU-only hosts from crashing at
    import time.
    """
    if not torch.cuda.is_available():
        return "none (CPU only)"
    return torch.cuda.get_device_name(torch.cuda.current_device())


print(f"Is CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device: {cuda_device_name()}")

# Token budgets shared by the chat handler and its UI slider.
MAX_INPUT_TOKEN_LENGTH = 4096  # prompt length (tokens) above which generate() trims and warns
MAX_NEW_TOKENS = 2048  # upper bound of the "Max new tokens" slider
DEFAULT_MAX_NEW_TOKENS = 128  # starting value of that slider

# Fine-tuned Gemma 2B chess preset hosted on Kaggle.
# Previously tried presets: "hf://google/gemma-2b-keras", "hf://google/gemma-2-2b-it".
model_id = "kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess"

# Fetch the preset and keep a handle on its tokenizer (used to count prompt tokens).
model = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
tokenizer = model.preprocessor.tokenizer

# Markdown header shown at the top of the app.
DESCRIPTION = """
# Gemma 2B
**Welcome to the Gemma Chess Chatbot!**

This game mode allows you to play a game against Gemma, the input must be in algebraic notation. \n
If you need help learning algebraic notation ask Gemma!
"""

def _as_text(value) -> str:
    """Coerce a detokenizer result (``str``, ``bytes`` or 0-d string tensor) to ``str``."""
    if hasattr(value, "numpy"):  # e.g. a TensorFlow string tensor
        value = value.numpy()
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return str(value)


# @spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
    ) -> Iterator[str]:
    """Generate a reply to *message* and stream it back one character at a time.

    Args:
        message: The latest user message from the chat box.
        chat_history: Previous turns supplied by ``gr.ChatInterface``
            (``type="messages"``). Currently unused: only *message* is sent
            to the model.
        max_new_tokens: Maximum number of tokens to generate after the prompt.

    Yields:
        The reply produced so far, growing by one character per step.
    """
    input_ids = tokenizer.tokenize(message)

    if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
        # Rebuild the prompt from the kept tokens so the trim actually reaches
        # the model; sending the original text would ignore it.
        message = _as_text(tokenizer.detokenize(input_ids))
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # keras_nlp's ``max_length`` caps the *whole* sequence (prompt + output),
    # so reserve room for the prompt plus its start token on top of the
    # requested number of new tokens.
    prompt_length = len(input_ids) + 1
    response = model.generate(message, max_length=prompt_length + max_new_tokens)

    # generate() returns the prompt followed by the completion; drop the echo
    # so the chat shows only the model's reply.
    if response.startswith(message):
        response = response[len(message):]

    for end in range(1, len(response) + 1):
        yield response[:end]


# Slider whose value is passed to generate() as ``max_new_tokens``.
_max_tokens_slider = gr.Slider(
    label="Max new tokens",
    minimum=1,
    maximum=MAX_NEW_TOKENS,
    step=1,
    value=DEFAULT_MAX_NEW_TOKENS,
)

# Chat panel backed by generate(); rendered later inside the main Blocks layout.
chat_interface = gr.ChatInterface(
    fn=generate,
    type="messages",
    additional_inputs=[_max_tokens_slider],
    examples=[
        ["Hi Gemma, what is a good first move in chess?"],
        ["How does the Knight move?"],
    ],
    cache_examples=False,
    stop_btn=None,
)


def display_text(seq=None, delay: float = 2.0):
    """Yield the moves of *seq* one at a time, pausing between them.

    Args:
        seq: Iterable of move strings to display. When ``None``, a fixed
            placeholder sequence (``["e4", "e5"]``) is shown instead.
        delay: Seconds to wait between two consecutive moves.

    Yields:
        Each move in turn.
    """
    moves = ["e4", "e5"] if seq is None else seq
    for index, move in enumerate(moves):
        # Pause only *between* moves so the stream ends right after the last one.
        if index:
            time.sleep(delay)
        yield move
    
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)

    # NOTE(review): one Game instance is created at import time and used by every
    # handler below, so all connected sessions appear to share the same board.
    # Consider per-session state (gr.State) if players must be isolated.
    play_match = Game()

    with gr.Row():
        board_image = gr.HTML(play_match.display_board())
        with gr.Column():
            chat_interface.render()

    text_output = gr.Label(label="Display Text for Logs")

    move_input = gr.Textbox(label="Enter your move in algebraic notation (e.g., e4, Nf3, Bxc4)")
    btn = gr.Button("Submit Move")

    def stream_move_logs():
        """Stream ``play_match.get_move_logs()`` through display_text.

        Assumes get_move_logs() returns a sequence of move strings — confirm
        against chess_board.Game.
        """
        yield from display_text(play_match.get_move_logs())

    # Event ``inputs`` must be Gradio components, so the log is read inside a
    # generator handler instead. ``.then`` starts it only after the move has
    # been applied, avoiding a race between two independent click handlers.
    btn.click(play_match.generate_moves, inputs=move_input, outputs=board_image).then(
        stream_move_logs, outputs=text_output
    )

    reset_btn = gr.Button("Reset Game")
    reset_btn.click(play_match.reset_board, outputs=board_image)


if __name__ == "__main__":
    # Cap the request queue at 20 pending events, then start the server.
    demo.queue(max_size=20)
    demo.launch()