Spaces:
Sleeping
Sleeping
valentin urena
commited on
Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,8 @@ import time
|
|
12 |
|
13 |
from chess_board import Game
|
14 |
|
|
|
|
|
15 |
|
16 |
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
17 |
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
@@ -22,13 +24,13 @@ MAX_NEW_TOKENS = 2048
|
|
22 |
DEFAULT_MAX_NEW_TOKENS = 128
|
23 |
|
24 |
# model_id = "hf://google/gemma-2b-keras"
|
25 |
-
model_id = "hf://google/gemma-2-2b-it"
|
26 |
|
27 |
# model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
|
28 |
|
29 |
|
30 |
-
model = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
|
31 |
-
tokenizer = model.preprocessor.tokenizer
|
32 |
|
33 |
DESCRIPTION = """
|
34 |
# Gemma 2B
|
@@ -38,6 +40,16 @@ This game mode allows you to play a game against Gemma, the input must be in alg
|
|
38 |
If you need help learning algebraic notation ask Gemma!
|
39 |
"""
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
# @spaces.GPU
|
42 |
def generate(
|
43 |
message: str,
|
@@ -45,13 +57,15 @@ def generate(
|
|
45 |
max_new_tokens: int = 1024,
|
46 |
) -> Iterator[str]:
|
47 |
|
48 |
-
input_ids = tokenizer.tokenize(message)
|
49 |
|
50 |
-
if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
|
51 |
-
|
52 |
-
|
|
|
|
|
53 |
|
54 |
-
response =
|
55 |
|
56 |
outputs = ""
|
57 |
|
|
|
12 |
|
13 |
from chess_board import Game
|
14 |
|
15 |
+
import google.generativeai as genai
|
16 |
+
|
17 |
|
18 |
print(f"Is CUDA available: {torch.cuda.is_available()}")
|
19 |
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
|
|
24 |
DEFAULT_MAX_NEW_TOKENS = 128
|
25 |
|
26 |
# model_id = "hf://google/gemma-2b-keras"
|
27 |
+
# model_id = "hf://google/gemma-2-2b-it"
|
28 |
|
29 |
# model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
|
30 |
|
31 |
|
32 |
+
# model = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
|
33 |
+
# tokenizer = model.preprocessor.tokenizer
|
34 |
|
35 |
DESCRIPTION = """
|
36 |
# Gemma 2B
|
|
|
40 |
If you need help learning algebraic notation ask Gemma!
|
41 |
"""
|
42 |
|
43 |
+
|
44 |
+
user_secrets = UserSecretsClient()
|
45 |
+
api_key = user_secrets.get_secret("GEMINI_API_KEY")
|
46 |
+
genai.configure(api_key = api_key)
|
47 |
+
|
48 |
+
model = genai.GenerativeModel(model_name='gemini-1.5-flash-latest')
|
49 |
+
|
50 |
+
# Chat
|
51 |
+
chat = model.start_chat()
|
52 |
+
|
53 |
# @spaces.GPU
|
54 |
def generate(
|
55 |
message: str,
|
|
|
57 |
max_new_tokens: int = 1024,
|
58 |
) -> Iterator[str]:
|
59 |
|
60 |
+
# input_ids = tokenizer.tokenize(message)
|
61 |
|
62 |
+
# if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
|
63 |
+
# input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
|
64 |
+
# gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
65 |
+
|
66 |
+
# response = model.generate(message, max_length=max_new_tokens)
|
67 |
|
68 |
+
response = chat.send_message(message)
|
69 |
|
70 |
outputs = ""
|
71 |
|