saq1b committed on
Commit
67811fe
·
verified ·
1 Parent(s): 53f2aac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +139 -139
app.py CHANGED
@@ -1,139 +1,139 @@
1
- import google.generativeai as genai
2
- from google.generativeai.types import HarmBlockThreshold, HarmCategory
3
- import gradio as gr
4
- from PIL import Image, ImageDraw, ImageFont
5
- import json
6
-
7
- # Fetch bounding boxes and labels
8
- async def get_bounding_boxes(prompt: str, image: str, api_key: str):
9
- system_prompt = """
10
- You are a helpful assistant, who always responds with the bounding box and label with the explanation JSON based on the user input, and nothing else.
11
- Your response can also include multiple bounding boxes and their labels in the list.
12
- The values in the list should be integers.
13
- Here are some example responses:
14
- {
15
- "explanation": "User asked for the bounding box of the dragon, so I will provide the bounding box of the dragon.",
16
- "bounding_boxes": [
17
- {"label": "dragon", "box": [ymin, xmin, ymax, xmax]}
18
- ]
19
- }
20
- {
21
- "explanation": "User asked for the bounding box of the fruits which are red in color, so I will provide the bounding box of the Apple and the Tomato.",
22
- "bounding_boxes": [
23
- {"label": "apple", "box": [ymin, xmin, ymax, xmax]},
24
- {"label": "tomato", "box": [ymin, xmin, ymax, xmax]}
25
- ]
26
- }
27
- """.strip()
28
-
29
- prompt = f"Return the bounding boxes and labels of: {prompt}"
30
-
31
- messages = [
32
- {"role": "user", "parts": [prompt, image]},
33
- ]
34
-
35
- genai.configure(api_key=api_key)
36
-
37
- generation_config = {
38
- "temperature": 1,
39
- "max_output_tokens": 8192,
40
- "response_mime_type": "application/json",
41
- }
42
-
43
- model = genai.GenerativeModel(
44
- model_name="gemini-1.5-flash",
45
- generation_config=generation_config,
46
- safety_settings={
47
- HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
48
- HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
49
- HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
50
- HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE
51
- },
52
- system_instruction=system_prompt
53
- )
54
-
55
- try:
56
- response = await model.generate_content_async(messages)
57
- except Exception as e:
58
- if "API key not valid" in str(e):
59
- raise gr.Error(
60
- "Invalid API key. Please provide a valid Gemini API key.")
61
- elif "rate limit" in str(e).lower():
62
- raise gr.Error("Rate limit exceeded for the API key.")
63
- else:
64
- raise gr.Error(f"Failed to generate content: {str(e)}")
65
-
66
- response_json = json.loads(response.text)
67
-
68
- explanation = response_json["explanation"]
69
- bounding_boxes = response_json["bounding_boxes"]
70
-
71
- return bounding_boxes, explanation
72
-
73
- # Adjust bounding boxes based on image size
74
- async def adjust_bounding_box(bounding_boxes, image):
75
- width, height = image.size
76
- adjusted_boxes = []
77
- for item in bounding_boxes:
78
- label = item["label"]
79
- ymin, xmin, ymax, xmax = [coord / 1000 for coord in item["box"]]
80
- xmin *= width
81
- xmax *= width
82
- ymin *= height
83
- ymax *= height
84
- adjusted_boxes.append({"label": label, "box": [xmin, ymin, xmax, ymax]})
85
- return adjusted_boxes
86
-
87
- # Process the image and draw bounding boxes and labels
88
- async def process_image(image, text, api_key):
89
- if not api_key:
90
- raise gr.Error("Please provide a Gemini API key.")
91
-
92
- # Open the image using PIL
93
- image = Image.open(image)
94
-
95
- # Call the async bounding box function
96
- bounding_boxes, explanation = await get_bounding_boxes(text, image, api_key)
97
-
98
- # Adjust the bounding box based on the image dimensions
99
- adjusted_boxes = await adjust_bounding_box(bounding_boxes, image)
100
-
101
- # Draw the bounding boxes and labels on the image
102
- draw = ImageDraw.Draw(image)
103
- font = ImageFont.load_default(size=20)
104
-
105
- for item in adjusted_boxes:
106
- box = item["box"]
107
- label = item["label"]
108
- draw.rectangle(box, outline="red", width=3)
109
- # Draw the label above the bounding box
110
- draw.text((box[0], box[1] - 25), label, fill="red", font=font)
111
-
112
- # Format adjusted boxes for display
113
- adjusted_boxes_str = "\n".join(f"{item['label']}: {item['box']}" for item in adjusted_boxes)
114
-
115
- return explanation, image, adjusted_boxes_str
116
-
117
- # Gradio app
118
- async def gradio_app(image, text, api_key):
119
- return await process_image(image, text, api_key)
120
-
121
- # Launch the Gradio interface
122
- iface = gr.Interface(
123
- fn=gradio_app,
124
- inputs=[
125
- gr.Image(type="filepath"),
126
- gr.Textbox(label="Object(s) to detect", value="person"),
127
- gr.Textbox(label="Your Gemini API Key", type="password")
128
- ],
129
- outputs=[
130
- gr.Textbox(label="Explanation"),
131
- gr.Image(type="pil", label="Output Image"),
132
- gr.Textbox(label="Coordinates and Labels of the Bounding Box(es)")
133
- ],
134
- title="Gemini Object Detection ✨",
135
- description="Detect objects in images using the Gemini 1.5 Flash model.",
136
- allow_flagging="never"
137
- )
138
-
139
- iface.launch()
 
1
+ import google.generativeai as genai
2
+ from google.generativeai.types import HarmBlockThreshold, HarmCategory
3
+ import gradio as gr
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import json
6
+
7
+ # Fetch bounding boxes and labels
8
async def get_bounding_boxes(prompt: str, image: "Image.Image", api_key: str):
    """Ask Gemini for labelled bounding boxes of the requested objects.

    Args:
        prompt: Free-text description of the object(s) to locate.
        image: The image to analyse; ``process_image`` passes an opened
            PIL image.
        api_key: Gemini API key used to configure the client.

    Returns:
        A ``(bounding_boxes, explanation)`` tuple. ``bounding_boxes`` is the
        list stored under ``"bounding_boxes"`` in the model's JSON reply and
        ``explanation`` is the value under ``"explanation"`` (an empty
        string when the model omits it).

    Raises:
        gr.Error: If the request fails, or if the reply cannot be read,
            is not valid JSON, or lacks a ``"bounding_boxes"`` list.
    """
    system_prompt = """
    You are a helpful assistant, who always responds with the bounding box and label with the explanation JSON based on the user input, and nothing else.
    Your response can also include multiple bounding boxes and their labels in the list.
    The values in the list should be integers.
    Here are some example responses:
    {
        "explanation": "User asked for the bounding box of the dragon, so I will provide the bounding box of the dragon.",
        "bounding_boxes": [
            {"label": "dragon", "box": [ymin, xmin, ymax, xmax]}
        ]
    }
    {
        "explanation": "User asked for the bounding box of the fruits which are red in color, so I will provide the bounding box of the Apple and the Tomato.",
        "bounding_boxes": [
            {"label": "apple", "box": [ymin, xmin, ymax, xmax]},
            {"label": "tomato", "box": [ymin, xmin, ymax, xmax]}
        ]
    }
    """.strip()

    prompt = f"Return the bounding boxes and labels of: {prompt}"

    messages = [
        {"role": "user", "parts": [prompt, image]},
    ]

    # NOTE(review): configure() is called on the genai module itself, so the
    # key is presumably shared process-wide; confirm that concurrent requests
    # with different keys cannot interfere with each other.
    genai.configure(api_key=api_key)

    generation_config = {
        "temperature": 1,
        "max_output_tokens": 8192,
        # Ask the API for a JSON reply so it can be parsed below.
        "response_mime_type": "application/json",
    }

    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=generation_config,
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
        },
        system_instruction=system_prompt,
    )

    try:
        response = await model.generate_content_async(messages)
    except Exception as e:
        # Translate SDK failures into user-facing Gradio errors, keeping the
        # original exception chained for debugging.
        message = str(e)
        if "API key not valid" in message:
            raise gr.Error(
                "Invalid API key. Please provide a valid Gemini API key.") from e
        if "rate limit" in message.lower():
            raise gr.Error("Rate limit exceeded for the API key.") from e
        raise gr.Error(f"Failed to generate content: {message}") from e

    try:
        # json.JSONDecodeError is a ValueError; reading response.text is kept
        # inside the try so a ValueError raised there is reported the same way.
        response_json = json.loads(response.text)
    except ValueError as e:
        raise gr.Error(f"Could not read the model response: {e}") from e

    if not isinstance(response_json, dict):
        raise gr.Error("The model response was not a JSON object.")

    bounding_boxes = response_json.get("bounding_boxes")
    if not isinstance(bounding_boxes, list):
        raise gr.Error("The model response did not contain a list of bounding boxes.")

    explanation = response_json.get("explanation", "")

    return bounding_boxes, explanation
72
+
73
+ # Adjust bounding boxes based on image size
74
async def adjust_bounding_box(bounding_boxes, image):
    """Scale model bounding boxes to pixel coordinates of ``image``.

    Each input box is ``[ymin, xmin, ymax, xmax]``; every value is divided by
    1000 and multiplied by the image height (y values) or width (x values),
    i.e. the code treats the incoming numbers as lying on a 0-1000 scale.

    Args:
        bounding_boxes: Iterable of dicts with a ``"label"`` and a four-value
            ``"box"`` entry.
        image: Object exposing ``size`` as ``(width, height)``.

    Returns:
        A list of ``{"label": ..., "box": [xmin, ymin, xmax, ymax]}`` dicts in
        pixel units (note the x-first order, unlike the input). Within each
        axis the smaller value always comes first.
    """
    width, height = image.size
    adjusted_boxes = []
    for item in bounding_boxes:
        label = item["label"]
        ymin, xmin, ymax, xmax = (coord / 1000 for coord in item["box"])
        # The model occasionally swaps min and max; recent Pillow releases
        # refuse to draw a rectangle whose second corner precedes the first,
        # so put each axis in ascending order.
        left, right = sorted((xmin * width, xmax * width))
        top, bottom = sorted((ymin * height, ymax * height))
        adjusted_boxes.append({"label": label, "box": [left, top, right, bottom]})
    return adjusted_boxes
86
+
87
+ # Process the image and draw bounding boxes and labels
88
async def process_image(image, text, api_key):
    """Detect the requested objects in an image and draw them onto it.

    Args:
        image: Path of the image file to open (the interface below declares
            its image input with ``type="filepath"``).
        text: Description of the object(s) to detect.
        api_key: Gemini API key; must be non-empty.

    Returns:
        A ``(explanation, image, adjusted_boxes_str)`` tuple: the model's
        explanation, the PIL image with boxes and labels drawn in red, and
        one ``"label: [xmin, ymin, xmax, ymax]"`` line per detected box.

    Raises:
        gr.Error: If the API key or the image is missing; errors raised by
            ``get_bounding_boxes`` propagate unchanged.
    """
    if not api_key:
        raise gr.Error("Please provide a Gemini API key.")
    if not image:
        # Without this guard Image.open(None) fails with an opaque traceback.
        raise gr.Error("Please upload an image.")

    # Open the image using PIL
    image = Image.open(image)

    # Query the model, then convert its boxes to pixel coordinates
    bounding_boxes, explanation = await get_bounding_boxes(text, image, api_key)
    adjusted_boxes = await adjust_bounding_box(bounding_boxes, image)

    # Draw the bounding boxes and labels on the image
    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default(size=20)

    for item in adjusted_boxes:
        box = item["box"]
        # The label comes from model JSON and may not be a string.
        label = str(item["label"])
        draw.rectangle(box, outline="red", width=3)
        # Draw the label above the box, but keep it on the canvas when the
        # box touches the top or left edge.
        label_x = max(box[0], 0)
        label_y = max(box[1] - 25, 0)
        draw.text((label_x, label_y), label, fill="red", font=font)

    # Format adjusted boxes for display
    adjusted_boxes_str = "\n".join(f"{item['label']}: {item['box']}" for item in adjusted_boxes)

    return explanation, image, adjusted_boxes_str
116
+
117
+ # Gradio app
118
async def gradio_app(image, text, api_key):
    """Gradio callback: forward the three UI inputs to ``process_image``.

    Returns whatever ``process_image`` returns, unchanged.
    """
    outputs = await process_image(image, text, api_key)
    return outputs
120
+
121
+ # Launch the Gradio interface
122
# Build the web UI: three inputs are passed positionally to gradio_app, whose
# three return values fill the three outputs in the same order.
iface = gr.Interface(
    title="Gemini Object Detection ✨",
    description="Detect objects in images using the Gemini 1.5 Flash model.",
    fn=gradio_app,
    inputs=[
        # process_image opens this value with Image.open, hence a file path.
        gr.Image(type="filepath"),
        gr.Textbox(label="Object(s) to detect", value="person"),
        # Password-type textbox for the user's API key.
        gr.Textbox(label="Your Gemini API Key", type="password"),
    ],
    outputs=[
        gr.Textbox(label="Explanation"),
        gr.Image(type="pil", label="Output Image"),
        gr.Textbox(label="Coordinates of the detected objects"),
    ],
    allow_flagging="never",
)

# Start serving the interface.
iface.launch()