TheM1N9 committed on
Commit
de15782
·
1 Parent(s): 775a8e1

Added type annotations and docstrings

Browse files
Files changed (2) hide show
  1. .gitignore +56 -0
  2. app.py +81 -26
.gitignore ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .env
2
+ uploads
3
+ chroma
4
+ instance
5
+ .venv
6
+
7
+ # Byte-compiled / optimized / DLL files
8
+ __pycache__/
9
+ *.py[cod]
10
+ *$py.class
11
+
12
+ # C extensions
13
+ *.so
14
+
15
+ # Distribution / packaging
16
+ .Python
17
+ env/
18
+ build/
19
+ develop-eggs/
20
+ dist/
21
+ downloads/
22
+ eggs/
23
+ .eggs/
24
+ lib/
25
+ lib64/
26
+ parts/
27
+ sdist/
28
+ var/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *,cover
52
+ .hypothesis/
53
+ venv/
54
+ .python-version
55
+
56
+ *.log
app.py CHANGED
@@ -1,41 +1,68 @@
 
1
  import google.generativeai as genai
2
  from dotenv import load_dotenv
3
  import os
 
4
  import gradio as gr
5
  from PIL import Image
6
  import numpy as np
7
 
8
  load_dotenv()
9
 
10
- GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
11
  genai.configure(api_key=GOOGLE_API_KEY)
12
 
13
- def save_image(file_input, image_name):
 
 
 
 
 
 
 
 
 
 
14
  # Convert the input to a PIL image
15
- image_pil = Image.fromarray(np.uint8(file_input))
16
-
17
  # Define the directory where the image will be saved
18
  save_directory = "images"
19
-
20
  # Check if the directory exists, create it if not
21
  if not os.path.exists(save_directory):
22
  os.makedirs(save_directory, exist_ok=True)
23
-
24
  # Define the full path to save the image
25
- image_path = os.path.join(save_directory, image_name)
26
-
27
  # Save the image
28
  image_pil.save(image_path)
29
 
30
  return image_path
31
 
32
- def generate_response(text_input, file_inputs=None, chat_history=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # Upload the files (images) and print a confirmation.
34
- image_paths = []
35
  if file_inputs is not None:
36
  for idx, file_input in enumerate(file_inputs):
37
- image_name = f"image_{idx + 1}.jpg"
38
- image_path = save_image(file_input, image_name)
39
  image_paths.append(image_path)
40
 
41
  # Choose a Gemini API model.
@@ -49,30 +76,37 @@ def generate_response(text_input, file_inputs=None, chat_history=None):
49
  chat_history_content = []
50
  for user_message, bot_response in chat_history:
51
  chat_history_content.append({"role": "user", "parts": [{"text": user_message}]})
52
- chat_history_content.append({"role": "model", "parts": [{"text": bot_response}]})
 
 
53
 
54
- chat = model.start_chat(history=chat_history_content)
55
 
56
  # Open images and pass them with text_input if available
57
- images = [Image.open(image_path) for image_path in image_paths] if image_paths else None
 
 
58
 
59
  # Prompt the model with text and the uploaded images if available
60
  if images:
61
- response = chat.send_message([*images, text_input])
62
  else:
63
- response = chat.send_message(text_input)
64
 
65
  # Append the new message to chat history in Gradio format (user, bot)
66
  chat_history.append((text_input, response.text))
67
 
68
- return "", chat_history
 
69
 
70
  # Create a Gradio interface with Blocks
71
  with gr.Blocks(title="Gemini vision") as demo:
72
  gr.Markdown("# Chat Bot M1N9")
73
 
74
  # Define the Chatbot component
75
- chatbot = gr.Chatbot([], elem_id="chatbot", height=700, show_share_button=True, show_copy_button=True)
 
 
76
 
77
  # Define the Textbox and Image components
78
  msg = gr.Textbox(show_copy_button=True, placeholder="Type your message here...")
@@ -83,28 +117,49 @@ with gr.Blocks(title="Gemini vision") as demo:
83
  img2 = gr.Image()
84
  img3 = gr.Image()
85
  img4 = gr.Image()
86
-
87
  btn = gr.Button("Submit")
88
 
89
  # Define the ClearButton component
90
  clear = gr.ClearButton([msg, img1, img2, img3, img4, chatbot])
91
 
92
  # Set the submit function for the Textbox and Image
93
- def submit_message(msg, img1, img2, img3, img4, chat_history):
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  # Collect all images into a list
95
  image_list = [img1, img2, img3, img4]
96
  # Filter out None values in case fewer than 4 images are uploaded
97
  image_list = [img for img in image_list if img is not None]
98
-
99
  # Call the generate_response with the list of images
100
  response, chat_history = generate_response(msg, image_list, chat_history)
101
-
102
  # Return the updated chat history and clear input fields
103
- return "", None, None, None, None, chat_history
104
 
105
  # Bind the submit function to both the submit action of Textbox and the button click
106
- msg.submit(submit_message, [msg, img1, img2, img3, img4, chatbot], [msg, img1, img2, img3, img4, chatbot])
107
- btn.click(submit_message, [msg, img1, img2, img3, img4, chatbot], [msg, img1, img2, img3, img4, chatbot])
 
 
 
 
 
 
 
 
108
 
109
  # Launch the Gradio interface
110
  demo.launch(debug=True, share=True)
 
1
+ from typing import Any, List, Optional, Tuple, Literal
2
  import google.generativeai as genai
3
  from dotenv import load_dotenv
4
  import os
5
+ from google.generativeai.types.generation_types import GenerateContentResponse
6
  import gradio as gr
7
  from PIL import Image
8
  import numpy as np
9
 
10
  load_dotenv()
11
 
12
+ GOOGLE_API_KEY: str = os.getenv("GOOGLE_API_KEY", "Enter correct API key")
13
  genai.configure(api_key=GOOGLE_API_KEY)
14
 
15
+
16
def save_image(file_input: np.ndarray, image_name: str) -> str:
    """Save an uploaded image to the local ``images`` directory on disk.

    Args:
        file_input (np.ndarray): pixel data of the image; it is cast to
            ``uint8`` before being handed to Pillow. (Presumably this is the
            array delivered by the Gradio image inputs -- confirm against the
            caller.)
        image_name (str): file name to save under; Pillow picks the output
            format from its extension.

    Returns:
        str: path of the saved image
    """
    # Convert the input to a PIL image
    image_pil: Image.Image = Image.fromarray(np.uint8(file_input))

    # JPEG cannot store an alpha channel, so an RGBA/LA array would make
    # save() raise. Flatten such images to RGB when the target is a JPEG.
    is_jpeg = image_name.lower().endswith((".jpg", ".jpeg"))
    if is_jpeg and image_pil.mode not in ("RGB", "L", "CMYK"):
        image_pil = image_pil.convert("RGB")

    # Define the directory where the image will be saved
    save_directory = "images"

    # exist_ok=True already makes this a no-op when the directory exists,
    # so no separate os.path.exists() check is needed.
    os.makedirs(save_directory, exist_ok=True)

    # Define the full path to save the image
    image_path: str = os.path.join(save_directory, image_name)

    # Save the image
    image_pil.save(image_path)

    return image_path
43
 
44
+
45
+ def generate_response(
46
+ text_input: str,
47
+ file_inputs: Optional[List[str]] = None,
48
+ chat_history: Optional[List[Tuple[str, str]]] = None,
49
+ ) -> Tuple[str, Any | List[Any]]:
50
+ """Generates response using gemini-1.5-flash model.
51
+
52
+ Args:
53
+ text_input (str): user input
54
+ file_inputs (List[str], optional): file paths of the uploaded images. Defaults to None.
55
+ chat_history (List[Tuple[str, str]], optional): chat history of the user. Defaults to None.
56
+
57
+ Returns:
58
+ Tuple[str, Any | List[Any]]: returns response and chat history
59
+ """
60
  # Upload the files (images) and print a confirmation.
61
+ image_paths: List[str] = []
62
  if file_inputs is not None:
63
  for idx, file_input in enumerate(file_inputs):
64
+ image_name: str = f"image_{idx + 1}.jpg"
65
+ image_path: str = save_image(file_input, image_name)
66
  image_paths.append(image_path)
67
 
68
  # Choose a Gemini API model.
 
76
  chat_history_content = []
77
  for user_message, bot_response in chat_history:
78
  chat_history_content.append({"role": "user", "parts": [{"text": user_message}]})
79
+ chat_history_content.append(
80
+ {"role": "model", "parts": [{"text": bot_response}]}
81
+ )
82
 
83
+ chat: genai.ChatSession = model.start_chat(history=chat_history_content)
84
 
85
  # Open images and pass them with text_input if available
86
+ images = (
87
+ [Image.open(image_path) for image_path in image_paths] if image_paths else None
88
+ )
89
 
90
  # Prompt the model with text and the uploaded images if available
91
  if images:
92
+ response: GenerateContentResponse = chat.send_message([*images, text_input])
93
  else:
94
+ response: GenerateContentResponse = chat.send_message(text_input)
95
 
96
  # Append the new message to chat history in Gradio format (user, bot)
97
  chat_history.append((text_input, response.text))
98
 
99
+ return response.text, chat_history
100
+
101
 
102
  # Create a Gradio interface with Blocks
103
  with gr.Blocks(title="Gemini vision") as demo:
104
  gr.Markdown("# Chat Bot M1N9")
105
 
106
  # Define the Chatbot component
107
+ chatbot = gr.Chatbot(
108
+ [], elem_id="chatbot", height=700, show_share_button=True, show_copy_button=True
109
+ )
110
 
111
  # Define the Textbox and Image components
112
  msg = gr.Textbox(show_copy_button=True, placeholder="Type your message here...")
 
117
  img2 = gr.Image()
118
  img3 = gr.Image()
119
  img4 = gr.Image()
120
+
121
  btn = gr.Button("Submit")
122
 
123
  # Define the ClearButton component
124
  clear = gr.ClearButton([msg, img1, img2, img3, img4, chatbot])
125
 
126
  # Set the submit function for the Textbox and Image
127
def submit_message(msg: str, img1, img2, img3, img4, chat_history):
    """Send the user's message and any supplied images to generate_response.

    Args:
        msg (str): user input
        img1: first image input, or None when that slot is empty
        img2: second image input, or None when that slot is empty
        img3: third image input, or None when that slot is empty
        img4: fourth image input, or None when that slot is empty
        chat_history: chat history of the user, passed through to
            generate_response

    Returns:
        tuple: an empty string, the four image inputs unchanged, and the chat
        history returned by generate_response.
    """
    # Keep only the image slots that were actually filled in
    image_list = [img for img in (img1, img2, img3, img4) if img is not None]

    # The response text is already appended to the returned history, so the
    # first element of the result is not needed here.
    _, chat_history = generate_response(msg, image_list, chat_history)

    # Empty string for the message; images are handed back as they came in
    return "", img1, img2, img3, img4, chat_history
151
 
152
  # Bind the submit function to both the submit action of Textbox and the button click
153
+ msg.submit(
154
+ submit_message,
155
+ [msg, img1, img2, img3, img4, chatbot],
156
+ [msg, img1, img2, img3, img4, chatbot],
157
+ )
158
+ btn.click(
159
+ submit_message,
160
+ [msg, img1, img2, img3, img4, chatbot],
161
+ [msg, img1, img2, img3, img4, chatbot],
162
+ )
163
 
164
  # Launch the Gradio interface
165
  demo.launch(debug=True, share=True)