valentin urena commited on
Commit
ebf05e2
·
verified ·
1 Parent(s): 6c50855

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -8
app.py CHANGED
@@ -12,6 +12,8 @@ import time
12
 
13
  from chess_board import Game
14
 
 
 
15
 
16
  print(f"Is CUDA available: {torch.cuda.is_available()}")
17
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
@@ -22,13 +24,13 @@ MAX_NEW_TOKENS = 2048
22
  DEFAULT_MAX_NEW_TOKENS = 128
23
 
24
  # model_id = "hf://google/gemma-2b-keras"
25
- model_id = "hf://google/gemma-2-2b-it"
26
 
27
  # model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
28
 
29
 
30
- model = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
31
- tokenizer = model.preprocessor.tokenizer
32
 
33
  DESCRIPTION = """
34
  # Gemma 2B
@@ -38,6 +40,16 @@ This game mode allows you to play a game against Gemma, the input must be in alg
38
  If you need help learning algebraic notation ask Gemma!
39
  """
40
 
 
 
 
 
 
 
 
 
 
 
41
  # @spaces.GPU
42
  def generate(
43
  message: str,
@@ -45,13 +57,15 @@ def generate(
45
  max_new_tokens: int = 1024,
46
  ) -> Iterator[str]:
47
 
48
- input_ids = tokenizer.tokenize(message)
49
 
50
- if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
51
- input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
52
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
 
53
 
54
- response = model.generate(message, max_length=max_new_tokens)
55
 
56
  outputs = ""
57
 
 
12
 
13
  from chess_board import Game
14
 
15
+ import google.generativeai as genai
16
+
17
 
18
  print(f"Is CUDA available: {torch.cuda.is_available()}")
19
  print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
 
24
  DEFAULT_MAX_NEW_TOKENS = 128
25
 
26
  # model_id = "hf://google/gemma-2b-keras"
27
+ # model_id = "hf://google/gemma-2-2b-it"
28
 
29
  # model_id = 'kaggle://valentinbaltazar/gemma-chess/keras/gemma_2b_en_chess'
30
 
31
 
32
+ # model = keras_nlp.models.GemmaCausalLM.from_preset(model_id)
33
+ # tokenizer = model.preprocessor.tokenizer
34
 
35
  DESCRIPTION = """
36
  # Gemma 2B
 
40
  If you need help learning algebraic notation ask Gemma!
41
  """
42
 
43
+
44
+ user_secrets = UserSecretsClient()
45
+ api_key = user_secrets.get_secret("GEMINI_API_KEY")
46
+ genai.configure(api_key = api_key)
47
+
48
+ model = genai.GenerativeModel(model_name='gemini-1.5-flash-latest')
49
+
50
+ # Chat
51
+ chat = model.start_chat()
52
+
53
  # @spaces.GPU
54
  def generate(
55
  message: str,
 
57
  max_new_tokens: int = 1024,
58
  ) -> Iterator[str]:
59
 
60
+ # input_ids = tokenizer.tokenize(message)
61
 
62
+ # if len(input_ids) > MAX_INPUT_TOKEN_LENGTH:
63
+ # input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
64
+ # gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
65
+
66
+ # response = model.generate(message, max_length=max_new_tokens)
67
 
68
+ response = chat.send_message(message)
69
 
70
  outputs = ""
71