|
from typing import Any, List, Optional, Tuple, Literal |
|
import google.generativeai as genai |
|
from dotenv import load_dotenv |
|
import os |
|
from google.generativeai.types.generation_types import GenerateContentResponse |
|
import gradio as gr |
|
from PIL import Image |
|
import numpy as np |
|
|
|
# Populate os.environ from a local .env file (if present) before the key is read.
load_dotenv()

# NOTE(review): if GOOGLE_API_KEY is not set, the placeholder text below is
# handed to genai.configure verbatim; nothing here validates it, so a missing
# key presumably only surfaces later, when the first model request is made.
GOOGLE_API_KEY: str = os.environ.get("GOOGLE_API_KEY", "Enter correct API key")
genai.configure(api_key=GOOGLE_API_KEY)
|
|
|
|
|
def save_image(
    file_input: Any, image_name: str, save_directory: str = "images"
) -> str:
    """Write image data to disk and return the path of the written file.

    Args:
        file_input (Any): Image data as delivered by ``gr.Image()`` (a NumPy
            array by default), or anything ``np.uint8`` can turn into an
            array that ``Image.fromarray`` accepts.
        image_name (str): File name to write, including the extension; the
            extension selects the output format.
        save_directory (str, optional): Directory to write into; created if
            it does not exist. Defaults to ``"images"``.

    Returns:
        str: Path of the saved image, ``os.path.join(save_directory, image_name)``.
    """
    image_pil: Image.Image = Image.fromarray(np.uint8(file_input))

    # exist_ok=True already makes this a no-op for an existing directory, so a
    # separate exists() check is redundant (and could race with other callers).
    os.makedirs(save_directory, exist_ok=True)

    image_path: str = os.path.join(save_directory, image_name)

    # JPEG cannot store an alpha channel: Pillow raises OSError when asked to
    # write e.g. an RGBA or LA image as .jpg, so flatten those to RGB first.
    extension = os.path.splitext(image_name)[1].lower()
    if extension in (".jpg", ".jpeg") and image_pil.mode not in ("L", "RGB"):
        image_pil = image_pil.convert("RGB")

    image_pil.save(image_path)
    return image_path
|
|
|
|
|
def _history_to_contents(chat_history: List[Tuple[str, str]]) -> List[dict]:
    """Convert Gradio ``(user, bot)`` pairs into Gemini chat-history contents."""
    contents: List[dict] = []
    for user_message, bot_response in chat_history:
        # An images-only turn is stored with an empty user message. Only text
        # is replayed here, and an empty text part is rejected by the API, so
        # substitute a short placeholder for it.
        contents.append(
            {"role": "user", "parts": [{"text": user_message or "(image)"}]}
        )
        contents.append({"role": "model", "parts": [{"text": bot_response}]})
    return contents


def generate_response(
    text_input: str,
    file_inputs: Optional[List[Any]] = None,
    chat_history: Optional[List[Tuple[str, str]]] = None,
) -> Tuple[str, Any | List[Any]]:
    """Send a (possibly multimodal) message to gemini-1.5-flash and return the reply.

    Each entry of ``file_inputs`` is saved as ``images/image_<n>.jpg`` via
    ``save_image`` and also sent to the model alongside ``text_input``.
    Earlier turns from ``chat_history`` are replayed to the model as
    text-only history.

    Args:
        text_input (str): User message; may be empty when images are sent.
        file_inputs (List[Any], optional): Image data for the uploaded images
            (e.g. the NumPy arrays produced by ``gr.Image``). Defaults to None.
        chat_history (List[Tuple[str, str]], optional): Previous
            ``(user, bot)`` pairs. It is not modified. Defaults to None.

    Returns:
        Tuple[str, Any | List[Any]]: The model's reply text, and a new history
        list equal to ``chat_history`` with ``(text_input, reply)`` appended.
    """
    # Work on a copy so the caller's list is never mutated.
    history: List[Tuple[str, str]] = list(chat_history) if chat_history else []

    images: List[Image.Image] = []
    for idx, file_input in enumerate(file_inputs or []):
        # Keep a copy of every upload on disk...
        save_image(file_input, f"image_{idx + 1}.jpg")
        # ...but send the in-memory image rather than reading the file back:
        # the names are fixed per slot, so a concurrent request could overwrite
        # image_<n>.jpg between the save and the read, and Image.open would
        # hold that file handle lazily.
        images.append(Image.fromarray(np.uint8(file_input)))

    model = genai.GenerativeModel(model_name="gemini-1.5-flash")
    chat: genai.ChatSession = model.start_chat(
        history=_history_to_contents(history)
    )

    if images:
        # Leave out an empty text part for images-only messages; the API
        # rejects empty parts.
        content: Any = [*images, text_input] if text_input else images
    else:
        content = text_input

    response: GenerateContentResponse = chat.send_message(content)
    reply: str = response.text

    history.append((text_input, reply))
    return reply, history
|
|
|
|
|
|
|
with gr.Blocks(title="Gemini vision") as demo:
    gr.Markdown("# Chat Bot M1N9")

    chatbot = gr.Chatbot(
        [], elem_id="chatbot", height=700, show_share_button=True, show_copy_button=True
    )

    msg = gr.Textbox(show_copy_button=True, placeholder="Type your message here...")

    # Four optional image slots; empty slots arrive in the handler as None.
    with gr.Row():
        img1 = gr.Image()
        img2 = gr.Image()
        img3 = gr.Image()
        img4 = gr.Image()

    btn = gr.Button("Submit")

    clear = gr.ClearButton([msg, img1, img2, img3, img4, chatbot])

    def submit_message(msg: str, img1, img2, img3, img4, chat_history):
        """Send the textbox text plus any uploaded images to the model.

        Args:
            msg (str): Text typed by the user.
            img1: Value of the first image slot (None when empty).
            img2: Value of the second image slot (None when empty).
            img3: Value of the third image slot (None when empty).
            img4: Value of the fourth image slot (None when empty).
            chat_history: Current value of the chatbot component.

        Returns:
            tuple: New values for (textbox, img1..img4, chatbot). The textbox
            is cleared, the images are left in place, and the chatbot shows
            the updated history.
        """
        images = [img for img in (img1, img2, img3, img4) if img is not None]

        # Nothing meaningful to send: skip the API call, which would fail on
        # an empty request, and leave the conversation unchanged.
        if not (msg or "").strip() and not images:
            return "", img1, img2, img3, img4, chat_history

        _, chat_history = generate_response(msg, images, chat_history)

        return "", img1, img2, img3, img4, chat_history

    # Pressing Enter in the textbox and clicking Submit run the same handler
    # with the same inputs and outputs.
    io_components = [msg, img1, img2, img3, img4, chatbot]
    msg.submit(submit_message, io_components, io_components)
    btn.click(submit_message, io_components, io_components)
|
|
|
|
|
# Start the server (with a public share link) only when run as a script, so
# importing this module, e.g. to reuse generate_response, has no such side effect.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)
|
|