File size: 5,514 Bytes
de15782 03253cd de15782 99caaef 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd de15782 03253cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
from typing import Any, List, Optional, Tuple, Literal
import google.generativeai as genai
from dotenv import load_dotenv
import os
from google.generativeai.types.generation_types import GenerateContentResponse
import gradio as gr
from PIL import Image
import numpy as np
# Run python-dotenv's loader before the key is looked up in the environment.
load_dotenv()

# When GOOGLE_API_KEY is not set, the placeholder text below is handed to the
# client as-is; it is a reminder for the operator, not a working key.
GOOGLE_API_KEY: str = os.environ.get("GOOGLE_API_KEY", "Enter correct API key")

# Register the key with the google.generativeai module for later calls.
genai.configure(api_key=GOOGLE_API_KEY)
def save_image(file_input: Any, image_name: str) -> str:
    """Save uploaded pixel data as an image file under the ``images`` directory.

    Args:
        file_input: Pixel data that ``np.uint8`` and ``Image.fromarray`` can
            consume (it is converted as an array, not treated as a file path).
        image_name: File name, including extension, to save under ``images``.

    Returns:
        str: Path of the saved image (``images/<image_name>``).
    """
    # Force 8-bit samples so Pillow can build an image from the array.
    image_pil: Image.Image = Image.fromarray(np.uint8(file_input))

    # JPEG cannot store an alpha channel or a palette; Pillow refuses to write
    # those modes, so flatten them to RGB when the target name is a JPEG.
    is_jpeg = image_name.lower().endswith((".jpg", ".jpeg"))
    if is_jpeg and image_pil.mode in ("RGBA", "LA", "P"):
        image_pil = image_pil.convert("RGB")

    save_directory = "images"
    # exist_ok=True already tolerates an existing directory, so a separate
    # os.path.exists() check is unnecessary.
    os.makedirs(save_directory, exist_ok=True)

    image_path: str = os.path.join(save_directory, image_name)
    image_pil.save(image_path)
    return image_path
def generate_response(
    text_input: str,
    file_inputs: Optional[List[Any]] = None,
    chat_history: Optional[List[Tuple[str, str]]] = None,
) -> Tuple[str, Any | List[Any]]:
    """Send a prompt (and optional images) to the gemini-1.5-flash model.

    Args:
        text_input: The user's message.
        file_inputs: Uploaded images; each item is handed to ``save_image``,
            which writes it to disk. Defaults to None (no images).
        chat_history: Earlier ``(user_message, bot_response)`` pairs.
            Defaults to None, which starts an empty history.

    Returns:
        Tuple[str, Any | List[Any]]: The model's reply text and the chat
        history with the new ``(text_input, reply)`` pair appended. When a
        list is passed in, that same list object is extended and returned.
    """
    # Write every upload to disk as image_1.jpg, image_2.jpg, ...
    image_paths: List[str] = []
    if file_inputs is not None:
        image_paths = [
            save_image(file_input, f"image_{idx}.jpg")
            for idx, file_input in enumerate(file_inputs, start=1)
        ]

    # Choose a Gemini API model.
    model = genai.GenerativeModel(model_name="gemini-1.5-flash")

    if chat_history is None:
        chat_history = []

    # Rebuild the history in the role/parts layout passed to start_chat().
    # Only the text of earlier turns is replayed; earlier images are not.
    chat_history_content = []
    for user_message, bot_response in chat_history:
        chat_history_content.append({"role": "user", "parts": [{"text": user_message}]})
        chat_history_content.append(
            {"role": "model", "parts": [{"text": bot_response}]}
        )
    chat: genai.ChatSession = model.start_chat(history=chat_history_content)

    # Image.open() keeps the underlying file open, so the handles are released
    # explicitly once the request has been sent (or has failed).
    opened_images: List[Image.Image] = []
    try:
        for image_path in image_paths:
            opened_images.append(Image.open(image_path))
        # Images go first, followed by the text; a text-only prompt is sent
        # as a plain string.
        prompt = [*opened_images, text_input] if opened_images else text_input
        response: GenerateContentResponse = chat.send_message(prompt)
    finally:
        for opened_image in opened_images:
            opened_image.close()

    reply: str = response.text
    # Record the exchange as a (user, bot) pair.
    chat_history.append((text_input, reply))
    return reply, chat_history
# Build the Gradio Blocks UI: a chatbot pane, a text box, four image slots,
# a submit button and a clear button.
# NOTE(review): indentation was lost in the source this was recovered from;
# the nesting below (images inside the row, buttons below it, launch outside
# the Blocks context) is the conventional layout -- confirm against the app.
with gr.Blocks(title="Gemini vision") as demo:
    gr.Markdown("# Chat Bot M1N9")

    chatbot = gr.Chatbot(
        [], elem_id="chatbot", height=700, show_share_button=True, show_copy_button=True
    )
    msg = gr.Textbox(show_copy_button=True, placeholder="Type your message here...")

    # Row for multiple image inputs
    with gr.Row():
        img1 = gr.Image()
        img2 = gr.Image()
        img3 = gr.Image()
        img4 = gr.Image()

    btn = gr.Button("Submit")
    clear = gr.ClearButton([msg, img1, img2, img3, img4, chatbot])

    def submit_message(msg: str, img1, img2, img3, img4, chat_history):
        """Send the current inputs to the model and refresh the chatbot.

        Args:
            msg (str): text typed by the user
            img1: value of the first image slot, or None when empty
            img2: value of the second image slot, or None when empty
            img3: value of the third image slot, or None when empty
            img4: value of the fourth image slot, or None when empty
            chat_history: current value of the chatbot component

        Returns:
            tuple: new values for (msg, img1, img2, img3, img4, chatbot).
            The text box is emptied; the image slots keep their contents.
        """
        # Keep only the slots that actually hold an image.
        uploaded = [img for img in (img1, img2, img3, img4) if img is not None]

        # Nothing to send: leave every component untouched rather than
        # forwarding an empty prompt to the model.
        if not uploaded and not (msg or "").strip():
            return msg, img1, img2, img3, img4, chat_history

        # Only the updated history is needed here; the reply text is already
        # part of it.
        _, chat_history = generate_response(msg, uploaded, chat_history)

        # Empty the text box; the images are returned unchanged, so they stay
        # in their slots.
        return "", img1, img2, img3, img4, chat_history

    # The same components serve as inputs and outputs for both triggers
    # (pressing Enter in the text box and clicking the button).
    io_components = [msg, img1, img2, img3, img4, chatbot]
    msg.submit(submit_message, io_components, io_components)
    btn.click(submit_message, io_components, io_components)

# Launch the Gradio interface
demo.launch(debug=True, share=True)
|