obichimav committed
Commit 70a1336 · verified · 1 Parent(s): f9ff8f6

Create app.py

Files changed (1)
  1. app.py +323 -0
app.py ADDED
@@ -0,0 +1,323 @@
+ # imports
+ import os
+ import json
+ import base64
+ from io import BytesIO
+ from dotenv import load_dotenv
+ from openai import OpenAI
+ import gradio as gr
+ import cv2
+ import numpy as np
+ from PIL import Image, ImageDraw
+ import requests
+ import torch
+ from transformers import (
+     AutoProcessor,
+     Owlv2ForObjectDetection,
+     AutoModelForZeroShotObjectDetection
+ )
+ # from transformers import AutoProcessor, Owlv2ForObjectDetection
+ from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+
+ # Initialization
+ load_dotenv()
+ os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
+ PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
+ MODEL = "gpt-4o"
+ openai = OpenAI()
+
+ # Initialize models
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ # Owlv2
+ owlv2_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
+ owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
+ # DINO
+ dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
+ dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
+
+ system_message = """You are an expert in object detection. When users mention:
+ 1. "count [object(s)]" - use detect_objects with the proper format for the current model
+ 2. "detect [object(s)]" - same as count
+ 3. "show [object(s)]" - same as count
+
+ For the DINO model: format queries as "a [object]." (e.g., "a frog.")
+ For the Owlv2 model: format queries as [["a photo of [object]", "a photo of [object2]"]]
+
+ Always use the object detection tool when counting/detecting is mentioned."""
+
+ system_message += "\nAlways be accurate. If you don't know the answer, say so."
+
+
+ class State:
+     def __init__(self):
+         self.current_image = None
+         self.last_prediction = None
+         self.current_model = "owlv2"  # Default model
+
+ state = State()
+
+ def get_preprocessed_image(pixel_values):
+     pixel_values = pixel_values.squeeze().numpy()
+     unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
+     unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
+     unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
+     return unnormalized_image
+
+ def encode_image_to_base64(image_array):
+     if image_array is None:
+         return None
+     image = Image.fromarray(image_array)
+     buffered = BytesIO()
+     image.save(buffered, format="JPEG")
+     return base64.b64encode(buffered.getvalue()).decode('utf-8')
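+ # Example (illustrative): the base64 string returned above is later embedded as a data
+ # URL, e.g. "data:image/jpeg;base64,<encoded>", when the image is passed to GPT-4o in chat().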
+
+
+ def format_query_for_model(text_input, model_type="owlv2"):
+     """Format the query based on model requirements."""
+     # Extract objects (e.g., "count frogs and horses" -> ["frogs", "horses"])
+     text = text_input.lower()
+     words = [w.strip('.,?!') for w in text.split()
+              if w not in ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']]
+
+     if model_type == "owlv2":
+         return [["a photo of " + obj for obj in words]]
+     else:  # DINO
+         # DINO only works with single-object queries in the format "a object."
+         return f"a {words[0]}."
+
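+ # Example (illustrative): format_query_for_model("count frogs and horses") returns
+ # [["a photo of frogs", "a photo of horses"]] for Owlv2, and "a frogs." for DINO
+ # (DINO gets only the first extracted object).
+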
+ def detect_objects(query_text):
+     if state.current_image is None:
+         return {"count": 0, "message": "No image provided"}
+
+     image = Image.fromarray(state.current_image)
+     draw = ImageDraw.Draw(image)
+
+     if state.current_model == "owlv2":
+         inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device)
+         with torch.no_grad():
+             outputs = owlv2_model(**inputs)
+         results = owlv2_processor.post_process_object_detection(
+             outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([image.size[::-1]])
+         )
+     else:  # DINO
+         inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device)
+         with torch.no_grad():
+             outputs = dino_model(**inputs)
+         results = dino_processor.post_process_grounded_object_detection(
+             outputs, inputs.input_ids, box_threshold=0.1, text_threshold=0.3,
+             target_sizes=[image.size[::-1]]
+         )
+
+     # Draw detection boxes
+     boxes = results[0]["boxes"]
+     scores = results[0]["scores"]
+
+     for box, score in zip(boxes, scores):
+         box = [round(i) for i in box.tolist()]
+         draw.rectangle(box, outline="red", width=3)
+         draw.text((box[0], box[1]), f"Score: {score:.2f}", fill="red")
+
+     state.last_prediction = np.array(image)
+     return {
+         "count": len(boxes),
+         "confidence": scores.tolist(),
+         "message": f"Detected {len(boxes)} objects"
+     }
+
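+ # Example (illustrative) return value from detect_objects:
+ # {"count": 2, "confidence": [0.41, 0.27], "message": "Detected 2 objects"}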
+
+ def identify_plant():
+     if state.current_image is None:
+         return {"error": "No image provided"}
+
+     image = Image.fromarray(state.current_image)
+     img_byte_arr = BytesIO()
+     image.save(img_byte_arr, format='JPEG')
+     img_byte_arr = img_byte_arr.getvalue()
+
+     api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
+     files = [('images', ('image.jpg', img_byte_arr))]
+     data = {'organs': ['leaf']}
+
+     try:
+         response = requests.post(api_endpoint, files=files, data=data)
+         if response.status_code == 200:
+             result = response.json()
+             best_match = result['results'][0]
+             return {
+                 "scientific_name": best_match['species']['scientificName'],
+                 "common_names": best_match['species'].get('commonNames', []),
+                 "family": best_match['species']['family']['scientificName'],
+                 "genus": best_match['species']['genus']['scientificName'],
+                 "confidence": f"{best_match['score']*100:.1f}%"
+             }
+         else:
+             return {"error": f"API Error: {response.status_code}"}
+     except Exception as e:
+         return {"error": f"Error: {str(e)}"}
+
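+ # Example (illustrative) return value from identify_plant on success:
+ # {"scientific_name": "Quercus robur L.", "common_names": ["English oak"],
+ #  "family": "Fagaceae", "genus": "Quercus", "confidence": "87.3%"}
+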
+ # Tool definitions
+ object_detection_function = {
+     "name": "detect_objects",
+     "description": "Use this function to detect and count objects in images based on text queries.",
+     "parameters": {
+         "type": "object",
+         "properties": {
+             "query_text": {
+                 "type": "array",
+                 "description": "List of text queries describing objects to detect",
+                 "items": {"type": "string"}
+             }
+         }
+     }
+ }
+
+ plant_identification_function = {
+     "name": "identify_plant",
+     "description": "Use this when asked about plant species identification or botanical classification.",
+     "parameters": {
+         "type": "object",
+         "properties": {},
+         "required": []
+     }
+ }
+
+ tools = [
+     {"type": "function", "function": object_detection_function},
+     {"type": "function", "function": plant_identification_function}
+ ]
+
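+ # Note (illustrative): when GPT-4o decides to use a tool, the response comes back with
+ # finish_reason == "tool_calls"; chat() below inspects tool_call.function.name and runs
+ # detect_objects() or identify_plant() accordingly, then sends the results back as a
+ # "tool" message before requesting the final answer.
+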
+ def format_tool_response(tool_response_content):
+     data = json.loads(tool_response_content)
+     if "error" in data:
+         return f"Error: {data['error']}"
+     elif "scientific_name" in data:
+         return f"""📋 Plant Identification Results:
+
+ 🌿 Scientific Name: {data['scientific_name']}
+ 👥 Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
+ 👪 Family: {data['family']}
+ 🎯 Confidence: {data['confidence']}"""
+     else:
+         return f"I detected {data['count']} objects in the image."
+
+ def chat(message, image, history):
+     if image is not None:
+         state.current_image = image
+
+     if state.current_image is None:
+         return "Please upload an image first.", None
+
+     base64_image = encode_image_to_base64(state.current_image)
+     messages = [{"role": "system", "content": system_message}]
+
+     for human, assistant in history:
+         messages.append({"role": "user", "content": human})
+         messages.append({"role": "assistant", "content": assistant})
+
+     # Extract objects to detect from the user message
+     # (this could be enhanced with better NLP)
+     objects_to_detect = message.lower()
+     formatted_query = format_query_for_model(objects_to_detect, state.current_model)
+
+     messages.append({
+         "role": "user",
+         "content": [
+             {"type": "text", "text": message},
+             {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+         ]
+     })
+
+     response = openai.chat.completions.create(
+         model=MODEL,
+         messages=messages,
+         tools=tools,
+         max_tokens=300
+     )
+
+     if response.choices[0].finish_reason == "tool_calls":
+         assistant_message = response.choices[0].message
+         messages.append(assistant_message)
+
+         for tool_call in assistant_message.tool_calls:
+             if tool_call.function.name == "detect_objects":
+                 results = detect_objects(formatted_query)
+             else:
+                 results = identify_plant()
+
+             tool_response = {
+                 "role": "tool",
+                 "content": json.dumps(results),
+                 "tool_call_id": tool_call.id
+             }
+             messages.append(tool_response)
+
+         response = openai.chat.completions.create(
+             model=MODEL,
+             messages=messages,
+             max_tokens=300
+         )
+
+     return response.choices[0].message.content, state.last_prediction
+
+ def update_model(choice):
+     print(f"Model switched to: {choice}")
+     state.current_model = choice.lower()
+     return f"Model switched to {choice}"
+
+ # Create Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Object Detection and Plant Analysis System")
+
+     with gr.Row():
+         with gr.Column():
+             model_choice = gr.Radio(
+                 choices=["Owlv2", "DINO"],
+                 value="Owlv2",
+                 label="Select Detection Model",
+                 interactive=True
+             )
+             image_input = gr.Image(type="numpy", label="Upload Image")
+             text_input = gr.Textbox(
+                 label="Ask about the image",
+                 placeholder="e.g., 'What objects do you see?' or 'What species is this plant?'"
+             )
+             with gr.Row():
+                 submit_btn = gr.Button("Analyze")
+                 reset_btn = gr.Button("Reset")
+
+         with gr.Column():
+             chatbot = gr.Chatbot()
+             # output_image = gr.Image(label="Detected Objects")
+             output_image = gr.Image(type="numpy", label="Detected Objects")
+
+     def process_interaction(message, image, history):
+         response, pred_image = chat(message, image, history)
+         history.append((message, response))
+         return "", pred_image, history
+
+     def reset_interface():
+         state.current_image = None
+         state.last_prediction = None
+         return None, None, None, []
+
+     model_choice.change(fn=update_model, inputs=[model_choice], outputs=[gr.Textbox(visible=False)])
+
+     submit_btn.click(
+         fn=process_interaction,
+         inputs=[text_input, image_input, chatbot],
+         outputs=[text_input, output_image, chatbot]
+     )
+
+     reset_btn.click(
+         fn=reset_interface,
+         inputs=[],
+         outputs=[image_input, output_image, text_input, chatbot]
+     )
+
+     gr.Markdown("""## Instructions
+ 1. Select the detection model (Owlv2 or DINO)
+ 2. Upload an image
+ 3. Ask specific questions about objects or plants
+ 4. Click Analyze to get results""")
+
+ demo.launch(share=True)