Spaces:
Sleeping
Sleeping
File size: 3,007 Bytes
2b49a9b 03335e7 2b49a9b 03335e7 2b49a9b d8fec69 03335e7 a8472ea 03335e7 2b49a9b 03335e7 2b49a9b 03335e7 2b49a9b 03335e7 817e88e 74cda9e 817e88e 74cda9e 03335e7 2b49a9b 03335e7 2b49a9b 03335e7 2b49a9b 03335e7 2b49a9b 03335e7 2b49a9b 3201bb7 2b49a9b 03335e7 2ca989c 0ee6f09 2ca989c 0ee6f09 d8fec69 09e95ed d8fec69 2ca989c e53d595 2b49a9b 2ca989c 2b49a9b 17fca8c 0ee6f09 d8fec69 2ca989c 2b49a9b 03335e7 2b49a9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import os
os.environ["KERAS_BACKEND"] = "torch" # "jax", "torch" or "tensorflow"
import gradio as gr
import keras_nlp
import keras
import spaces
import torch
from typing import Iterator
import time
from chess_board import Game
# Startup diagnostics: report whether a CUDA device is usable.
CUDA_AVAILABLE = torch.cuda.is_available()
print(f"Is CUDA available: {CUDA_AVAILABLE}")
if CUDA_AVAILABLE:
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    # torch.cuda.current_device() raises when no CUDA device/driver is present,
    # which previously crashed the whole app at import time on CPU hardware.
    print("CUDA device: none")
# Token limits. MAX_INPUT_TOKEN_LENGTH bounds the prompt that `generate` keeps;
# the other two configure the "Max new tokens" slider of the chat UI.
MAX_INPUT_TOKEN_LENGTH = 4_096
MAX_NEW_TOKENS = 2_048
DEFAULT_MAX_NEW_TOKENS = 128

# Preset to load. Other presets that have been used here:
#   "hf://google/gemma-2b-keras"
#   "hf://google/gemma-2-2b-it"
model_id = "kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess"

# Build the causal LM from the preset and keep its tokenizer for prompt-length checks.
model = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
tokenizer = model.preprocessor.tokenizer
# Markdown header shown at the top of the app (rendered via gr.Markdown).
DESCRIPTION = """
# Gemma 2B
**Welcome to the Gemma Chess Chatbot!**
This game mode lets you play a game of chess against Gemma; moves must be entered in algebraic notation. \n
If you need help learning algebraic notation, ask Gemma!
"""
# @spaces.GPU  # NOTE(review): presumably re-enabled when running on ZeroGPU hardware — confirm
def generate(
    message: str,
    chat_history: list[dict],
    max_new_tokens: int = 1024,
) -> Iterator[str]:
    """Reply to ``message`` with the Gemma model and stream the answer.

    Args:
        message: The latest user message.
        chat_history: Prior turns supplied by ``gr.ChatInterface``. Not used:
            each reply is generated from ``message`` alone.
        max_new_tokens: Token budget for the reply, on top of the prompt.

    Yields:
        Successively longer prefixes of the model output, one character at a
        time, so the chat UI renders it progressively.
    """
    # The UI slider may hand back a float; the model needs an int length.
    max_new_tokens = int(max_new_tokens)

    input_ids = tokenizer.tokenize(message)
    if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens and actually feed the trimmed prompt
        # to the model (previously the trimmed ids were discarded).
        input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
        message = tokenizer.detokenize(input_ids)
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # `max_length` caps the whole sequence (prompt + reply), so reserve room for
    # the prompt. The +1 presumably covers the start token the preprocessor
    # prepends — TODO confirm against the preset's preprocessor.
    prompt_length = len(input_ids) + 1
    response = model.generate(message, max_length=prompt_length + max_new_tokens)

    # NOTE(review): the output still includes the prompt text (no prompt
    # stripping is applied here) — confirm that echoing it is intended.
    for end in range(1, len(response) + 1):
        yield response[:end]
# Slider controlling how many new tokens `generate` may produce.
_max_new_tokens_slider = gr.Slider(
    label="Max new tokens",
    minimum=1,
    maximum=MAX_NEW_TOKENS,
    step=1,
    value=DEFAULT_MAX_NEW_TOKENS,
)

# Clickable starter prompts shown under the chat box (not pre-computed).
_chat_examples = [
    ["Hi Gemma, what is a good first move in chess?"],
    ["How does the Knight move?"],
]

# Chat panel backed by `generate`; with type="messages" the history is passed
# as dicts, matching generate's `chat_history: list[dict]` annotation.
chat_interface = gr.ChatInterface(
    fn=generate,
    type="messages",
    additional_inputs=[_max_new_tokens_slider],
    examples=_chat_examples,
    cache_examples=False,
    stop_btn=None,
)
def display_text(seq, delay: float = 2.0):
    """Stream the moves in ``seq`` one at a time, pausing between them.

    Previously ``seq`` was ignored and a hard-coded ``['e4', 'e5']`` was shown.

    Args:
        seq: Iterable of moves to show (e.g. ``["e4", "e5"]``); ``None`` yields nothing.
        delay: Seconds to wait between consecutive moves.

    Yields:
        Each move in order, for display in a streaming Gradio output.
    """
    if seq is None:
        return
    for index, move in enumerate(seq):
        if index:
            # Pause only *between* moves so the stream ends as soon as the last
            # move is shown, instead of idling for an extra `delay`.
            time.sleep(delay)
        yield move
with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)

    # NOTE(review): this single Game instance is created once when the UI is
    # built, so every visitor of the app shares it — confirm this is intended
    # (per-session games would need gr.State).
    play_match = Game()

    with gr.Row():
        # Board HTML comes from the Game object; move/reset handlers replace it.
        board_image = gr.HTML(play_match.display_board())
        with gr.Column():
            chat_interface.render()

    text_output = gr.Label(label="Display Text for Logs")
    move_input = gr.Textbox(label="Enter your move in algebraic notation (e.g., e4, Nf3, Bxc4)")
    btn = gr.Button("Submit Move")

    def _stream_move_logs():
        """Stream the game's move log into the log panel, one move at a time."""
        # Presumably get_move_logs() returns the moves played so far — confirm.
        yield from display_text(play_match.get_move_logs())

    # Apply the move first, then stream the updated log. Previously the log was
    # a second, independent listener whose `inputs` was a bound method rather
    # than a component, and it could run before the move had been applied.
    btn.click(
        play_match.generate_moves, inputs=move_input, outputs=board_image
    ).then(_stream_move_logs, outputs=text_output)

    reset_btn = gr.Button("Reset Game")
    reset_btn.click(play_match.reset_board, outputs=board_image)
if __name__ == "__main__":
    # Allow at most 20 pending requests in the queue, then start the server.
    demo.queue(max_size=20)
    demo.launch()
|