File size: 5,514 Bytes
de15782
03253cd
 
 
de15782
99caaef
03253cd
 
 
 
 
de15782
03253cd
 
de15782
 
 
 
 
 
 
 
 
 
 
03253cd
de15782
 
03253cd
 
de15782
03253cd
 
 
de15782
03253cd
de15782
 
03253cd
 
 
 
 
de15782
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03253cd
de15782
03253cd
 
de15782
 
03253cd
 
 
 
 
 
 
 
 
 
 
 
 
de15782
 
 
03253cd
de15782
03253cd
 
de15782
 
 
03253cd
 
 
de15782
03253cd
de15782
03253cd
 
 
 
de15782
 
03253cd
 
 
 
 
 
de15782
 
 
03253cd
 
 
 
 
 
 
 
 
 
de15782
03253cd
 
 
 
 
 
de15782
 
 
 
 
 
 
 
 
 
 
 
 
 
03253cd
 
 
 
de15782
03253cd
 
de15782
03253cd
de15782
03253cd
 
de15782
 
 
 
 
 
 
 
 
 
03253cd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
from typing import Any, List, Optional, Tuple, Literal
import google.generativeai as genai
from dotenv import load_dotenv
import os
from google.generativeai.types.generation_types import GenerateContentResponse
import gradio as gr
from PIL import Image
import numpy as np

# Pull variables from a local .env file (if one exists) into the process environment.
load_dotenv()

# NOTE(review): if GOOGLE_API_KEY is not set, a placeholder string is handed to
# genai.configure, so the misconfiguration only surfaces on the first API call.
GOOGLE_API_KEY: str = os.environ.get("GOOGLE_API_KEY", "Enter correct API key")
genai.configure(api_key=GOOGLE_API_KEY)


def save_image(
    file_input: Any, image_name: str, save_directory: str = "images"
) -> str:
    """Write image data to a file on disk and return the file's path.

    Args:
        file_input: Image data convertible with ``np.uint8(...)`` — presumably
            the numpy array produced by ``gr.Image`` in its default mode.
        image_name: File name (with extension) to write inside
            ``save_directory``; PIL picks the output format from the extension.
        save_directory: Directory to write into; created if it does not exist.
            Defaults to ``"images"``.

    Returns:
        str: Path of the saved image, i.e. ``save_directory/image_name``.
    """
    # Convert the raw array to a PIL image (mode chosen from the array shape).
    image_pil: Image.Image = Image.fromarray(np.uint8(file_input))

    # JPEG cannot hold an alpha channel; PIL raises OSError when asked to save
    # RGBA/LA data as JPEG, so flatten such images to RGB first.
    is_jpeg = os.path.splitext(image_name)[1].lower() in (".jpg", ".jpeg")
    if is_jpeg and image_pil.mode not in ("RGB", "L"):
        image_pil = image_pil.convert("RGB")

    # exist_ok=True already tolerates an existing directory, so no separate
    # existence check is needed (that check would also be racy).
    os.makedirs(save_directory, exist_ok=True)

    image_path: str = os.path.join(save_directory, image_name)
    image_pil.save(image_path)

    return image_path


def generate_response(
    text_input: str,
    file_inputs: Optional[List[Any]] = None,
    chat_history: Optional[List[Tuple[str, str]]] = None,
) -> Tuple[str, Any | List[Any]]:
    """Send one chat turn (text plus optional images) to gemini-1.5-flash.

    Args:
        text_input (str): The user's message for this turn.
        file_inputs (List[Any], optional): Image data to attach, in the form
            ``save_image`` accepts (array-like pixel data). Defaults to None.
        chat_history (List[Tuple[str, str]], optional): Previous
            ``(user, bot)`` message pairs in Gradio Chatbot format. The new
            turn is appended to this list in place. Defaults to None.

    Returns:
        Tuple[str, Any | List[Any]]: The model's reply text and the updated
        chat history.
    """
    images: List[Image.Image] = []
    for idx, file_input in enumerate(file_inputs or []):
        # Keep writing a copy to images/image_<n>.jpg, as before.
        save_image(file_input, f"image_{idx + 1}.jpg")
        # Build the image sent to Gemini from the in-memory data rather than
        # re-opening the file just written: that path is fixed and shared by
        # every request, so a concurrent request (the app is launched with
        # share=True) could overwrite it in between. This also avoids leaving
        # Image.open file handles around and a lossy JPEG round-trip.
        images.append(Image.fromarray(np.uint8(file_input)))

    model = genai.GenerativeModel(model_name="gemini-1.5-flash")

    if chat_history is None:
        chat_history = []

    # Gemini expects alternating user/model turns, each with a list of parts.
    chat_history_content = []
    for user_message, bot_response in chat_history:
        chat_history_content.append({"role": "user", "parts": [{"text": user_message}]})
        chat_history_content.append(
            {"role": "model", "parts": [{"text": bot_response}]}
        )

    chat: genai.ChatSession = model.start_chat(history=chat_history_content)

    # Images (if any) go before the text in a single multi-part message.
    prompt = [*images, text_input] if images else text_input
    response: GenerateContentResponse = chat.send_message(prompt)

    reply = response.text
    # Record the turn in Gradio's (user, bot) format.
    chat_history.append((text_input, reply))

    return reply, chat_history


# Create a Gradio interface with Blocks
with gr.Blocks(title="Gemini vision") as demo:
    gr.Markdown("# Chat Bot M1N9")

    # Conversation display; its value is the list of (user, bot) message pairs.
    chatbot = gr.Chatbot(
        [], elem_id="chatbot", height=700, show_share_button=True, show_copy_button=True
    )

    # Free-text input for the user's message.
    msg = gr.Textbox(show_copy_button=True, placeholder="Type your message here...")

    # Up to four optional image inputs, laid out side by side.
    with gr.Row():
        img1 = gr.Image()
        img2 = gr.Image()
        img3 = gr.Image()
        img4 = gr.Image()

    btn = gr.Button("Submit")

    # Resets the text box, every image slot, and the conversation.
    clear = gr.ClearButton([msg, img1, img2, img3, img4, chatbot])

    def submit_message(msg: str, img1, img2, img3, img4, chat_history):
        """Handle one chat turn triggered from the UI.

        Args:
            msg (str): Text typed by the user.
            img1, img2, img3, img4: Current values of the four image inputs;
                None for an empty slot.
            chat_history: Current value of the Chatbot component.

        Returns:
            tuple: New values for (msg, img1, img2, img3, img4, chatbot), in the
            same order as the event's outputs.
        """
        # Keep only the slots the user actually filled.
        images = [img for img in (img1, img2, img3, img4) if img is not None]

        # Nothing to send: skip the API call (an empty prompt is a wasted
        # request and presumably rejected by Gemini) and leave the UI as is.
        if not images and not (msg or "").strip():
            return msg, img1, img2, img3, img4, chat_history

        _, chat_history = generate_response(msg, images, chat_history)

        # Only the text box is cleared. The images are returned unchanged so
        # they stay on screen and are re-attached to follow-up questions — the
        # history sent to Gemini carries text only.
        return "", img1, img2, img3, img4, chat_history

    # Pressing Enter in the text box and clicking Submit share the same
    # handler, inputs, and outputs.
    io_components = [msg, img1, img2, img3, img4, chatbot]
    msg.submit(submit_message, io_components, io_components)
    btn.click(submit_message, io_components, io_components)

# Start the web server. debug=True keeps the call blocking so errors are printed
# to the console; share=True additionally creates a public Gradio share link.
# NOTE(review): anyone holding that link can chat through this server's API key.
demo.launch(share=True, debug=True)