NotASI committed on
Commit
c0d0de3
·
1 Parent(s): 59f9c0e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -45
app.py CHANGED
@@ -265,8 +265,8 @@ def preprocess_image(image: Image.Image) -> Image.Image:
265
  image_height = int(image.height * IMAGE_WIDTH / image.width)
266
  return image.resize((IMAGE_WIDTH, image_height))
267
 
268
- def user(text_prompt: str, chatbot: List[Tuple[str, str]]):
269
- return "", chatbot + [[text_prompt, None]]
270
 
271
  def bot(
272
  google_key: str,
@@ -276,13 +276,13 @@ def bot(
276
  stop_sequences: str,
277
  top_k: int,
278
  top_p: float,
279
- chatbot: List[Tuple[str, str]]
280
  ):
281
  google_key = google_key or GEMINI_API_KEY
282
  if not google_key:
283
  raise ValueError("GOOGLE_API_KEY is not set. Please set it up.")
284
 
285
- text_prompt = chatbot[-1][0]
286
  genai.configure(api_key=google_key)
287
  generation_config = genai.types.GenerationConfig(
288
  temperature=temperature,
@@ -292,16 +292,19 @@ def bot(
292
  top_p=top_p,
293
  )
294
 
295
- model_name = "gemini-1.5-flash" # if image_prompt is None else "gemini-pro-vision"
296
  model = genai.GenerativeModel(model_name)
297
  inputs = [text_prompt] if image_prompt is None else [text_prompt, preprocess_image(image_prompt)]
298
 
299
- # Use gr.ChatInference for streaming response
300
  response = model.generate_content(inputs, stream=True, generation_config=generation_config)
301
- chatbot[-1][1] = ""
 
 
302
  for chunk in response:
303
- chatbot[-1][1] += chunk.text
304
- yield chatbot
 
 
305
 
306
  google_key_component = gr.Textbox(
307
  label = "GOOGLE API KEY",
@@ -314,17 +317,12 @@ image_prompt_component = gr.Image(
314
  type = "pil",
315
  label = "Image"
316
  )
317
- chatbot_component = gr.Chatbot(
318
- # label = 'Gemini',
319
- bubble_full_width = False
320
- )
321
  text_prompt_component = gr.Textbox(
322
  placeholder = "Chat with Gemini",
323
  label = "Ask me anything and press Enter"
324
  )
325
- run_button_component = gr.Button(
326
- "Run"
327
- )
328
  temperature_component = gr.Slider(
329
  minimum = 0,
330
  maximum = 1.0,
@@ -332,6 +330,7 @@ temperature_component = gr.Slider(
332
  step = 0.05,
333
  label = "Temperature"
334
  )
 
335
  max_output_tokens_component = gr.Slider(
336
  minimum = 1,
337
  maximum = 8192,
@@ -339,10 +338,12 @@ max_output_tokens_component = gr.Slider(
339
  step = 1,
340
  label = "Max Output Tokens"
341
  )
 
342
  stop_sequences_component = gr.Textbox(
343
  label = "Add stop sequence",
344
  placeholder = "STOP, END"
345
  )
 
346
  top_k_component = gr.Slider(
347
  minimum = 1,
348
  maximum = 40,
@@ -350,6 +351,7 @@ top_k_component = gr.Slider(
350
  step = 1,
351
  label = "Top-K"
352
  )
 
353
  top_p_component = gr.Slider(
354
  minimum = 0,
355
  maximum = 1,
@@ -360,8 +362,9 @@ top_p_component = gr.Slider(
360
 
361
  user_inputs = [
362
  text_prompt_component,
363
- chatbot_component
364
  ]
 
365
  bot_inputs = [
366
  google_key_component,
367
  image_prompt_component,
@@ -370,7 +373,7 @@ bot_inputs = [
370
  stop_sequences_component,
371
  top_k_component,
372
  top_p_component,
373
- chatbot_component
374
  ]
375
 
376
  with gr.Blocks(theme = gr.themes.Soft()) as demo:
@@ -385,9 +388,7 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
385
  google_key_component.render()
386
  with gr.Row():
387
  image_prompt_component.render()
388
- chatbot_component.render()
389
  text_prompt_component.render()
390
- run_button_component.render()
391
  with gr.Accordion("Parameters", open=False):
392
  temperature_component.render()
393
  max_output_tokens_component.render()
@@ -396,31 +397,16 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
396
  top_k_component.render()
397
  top_p_component.render()
398
 
399
- # Use gr.ChatInference for streaming response
400
- chat_inference = gr.ChatInference(
401
- user, bot,
402
- [text_prompt_component, chatbot_component],
403
- [chatbot_component]
404
- )
405
- chat_inference.chatbot = chatbot_component
406
- chat_inference.api_key = google_key_component
407
- chat_inference.image_prompt = image_prompt_component
408
- chat_inference.temperature = temperature_component
409
- chat_inference.max_output_tokens = max_output_tokens_component
410
- chat_inference.stop_sequences = stop_sequences_component
411
- chat_inference.top_k = top_k_component
412
- chat_inference.top_p = top_p_component
413
-
414
- run_button_component.click(
415
- fn=chat_inference.submit,
416
- inputs=user_inputs,
417
- outputs=[text_prompt_component, chatbot_component]
418
- )
419
- text_prompt_component.submit(
420
- fn=chat_inference.submit,
421
- inputs=user_inputs,
422
- outputs=[text_prompt_component, chatbot_component]
423
  )
 
424
  with gr.Tab("Chat with Gemma 2"):
425
  gr.HTML(
426
  """
@@ -428,4 +414,4 @@ with gr.Blocks(theme = gr.themes.Soft()) as demo:
428
  """
429
  )
430
 
431
- demo.queue().launch(debug = True, show_error = True)
 
265
  image_height = int(image.height * IMAGE_WIDTH / image.width)
266
  return image.resize((IMAGE_WIDTH, image_height))
267
 
268
+ def user(text_prompt: str, history: List[Tuple[str, str]]):
269
+ return text_prompt, history
270
 
271
  def bot(
272
  google_key: str,
 
276
  stop_sequences: str,
277
  top_k: int,
278
  top_p: float,
279
+ history: List[Tuple[str, str]]
280
  ):
281
  google_key = google_key or GEMINI_API_KEY
282
  if not google_key:
283
  raise ValueError("GOOGLE_API_KEY is not set. Please set it up.")
284
 
285
+ text_prompt = history[-1][0]
286
  genai.configure(api_key=google_key)
287
  generation_config = genai.types.GenerationConfig(
288
  temperature=temperature,
 
292
  top_p=top_p,
293
  )
294
 
295
+ model_name = "gemini-1.5-flash"
296
  model = genai.GenerativeModel(model_name)
297
  inputs = [text_prompt] if image_prompt is None else [text_prompt, preprocess_image(image_prompt)]
298
 
 
299
  response = model.generate_content(inputs, stream=True, generation_config=generation_config)
300
+ response.resolve()
301
+
302
+ output_text = ""
303
  for chunk in response:
304
+ for i in range(0, len(chunk.text), 10):
305
+ output_text += chunk.text[i:i + 10]
306
+ time.sleep(0.01)
307
+ yield history + [(text_prompt, output_text)]
308
 
309
  google_key_component = gr.Textbox(
310
  label = "GOOGLE API KEY",
 
317
  type = "pil",
318
  label = "Image"
319
  )
320
+
 
 
 
321
  text_prompt_component = gr.Textbox(
322
  placeholder = "Chat with Gemini",
323
  label = "Ask me anything and press Enter"
324
  )
325
+
 
 
326
  temperature_component = gr.Slider(
327
  minimum = 0,
328
  maximum = 1.0,
 
330
  step = 0.05,
331
  label = "Temperature"
332
  )
333
+
334
  max_output_tokens_component = gr.Slider(
335
  minimum = 1,
336
  maximum = 8192,
 
338
  step = 1,
339
  label = "Max Output Tokens"
340
  )
341
+
342
  stop_sequences_component = gr.Textbox(
343
  label = "Add stop sequence",
344
  placeholder = "STOP, END"
345
  )
346
+
347
  top_k_component = gr.Slider(
348
  minimum = 1,
349
  maximum = 40,
 
351
  step = 1,
352
  label = "Top-K"
353
  )
354
+
355
  top_p_component = gr.Slider(
356
  minimum = 0,
357
  maximum = 1,
 
362
 
363
  user_inputs = [
364
  text_prompt_component,
365
+ gr.State([])
366
  ]
367
+
368
  bot_inputs = [
369
  google_key_component,
370
  image_prompt_component,
 
373
  stop_sequences_component,
374
  top_k_component,
375
  top_p_component,
376
+ gr.State([])
377
  ]
378
 
379
  with gr.Blocks(theme = gr.themes.Soft()) as demo:
 
388
  google_key_component.render()
389
  with gr.Row():
390
  image_prompt_component.render()
 
391
  text_prompt_component.render()
 
392
  with gr.Accordion("Parameters", open=False):
393
  temperature_component.render()
394
  max_output_tokens_component.render()
 
397
  top_k_component.render()
398
  top_p_component.render()
399
 
400
+ chat_interface = gr.ChatInterface(
401
+ fn=bot,
402
+ user_fn=user,
403
+ inputs=bot_inputs,
404
+ outputs="chatbot",
405
+ submit=text_prompt_component,
406
+ state="chatbot",
407
+ queue=False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
408
  )
409
+
410
  with gr.Tab("Chat with Gemma 2"):
411
  gr.HTML(
412
  """
 
414
  """
415
  )
416
 
417
+ demo.queue().launch(debug = True, show_error = True)