obichimav committed
Commit f8cecaf · verified · 1 Parent(s): f8358fa

Update app.py


Updated code to the vision-agent framework
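
Below is a minimal sketch of the detection flow this commit adopts, assuming the vision_agent tools API exactly as it is used in the new detect_objects(); the image path and the "dog" prompt are placeholders, not part of the commit.

import numpy as np
import vision_agent.tools as T

image = T.load_image("example.jpg")                     # hypothetical local test image
detections = T.countgd_object_detection("dog", image)   # text-prompted detection/counting via CountGD
annotated = np.array(T.overlay_bounding_boxes(image, detections))  # boxes drawn, as a NumPy array for display

print(f"Detected {len(detections)} dog(s); scores:",
      [round(det["score"], 2) for det in detections])   # each detection dict carries a "score"

The Gradio app in the diff wraps this same flow behind an OpenAI tool call (detect_objects), alongside the existing PlantNet-based identify_plant tool.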

Files changed (1)
  1. app.py +386 -106
app.py CHANGED
@@ -1,3 +1,333 @@
 # imports
 import os
 import json
@@ -9,39 +339,27 @@ import gradio as gr
 import numpy as np
 from PIL import Image, ImageDraw
 import requests
- import torch
- from transformers import (
- AutoProcessor,
- Owlv2ForObjectDetection,
- AutoModelForZeroShotObjectDetection
- )
- # from transformers import AutoProcessor, Owlv2ForObjectDetection
- from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD

 # Initialization
 load_dotenv()
 os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
 PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
 MODEL = "gpt-4o"
 openai = OpenAI()

- # Initialize models
- device = "cuda" if torch.cuda.is_available() else "cpu"
- # Owlv2
- owlv2_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
- owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
- # DINO
- dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
- dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)

 system_message = """You are an expert in object detection. When users mention:
- 1. "count [object(s)]" - Use detect_objects with proper format based on model
 2. "detect [object(s)]" - Same as count
 3. "show [object(s)]" - Same as count

- For DINO model: Format queries as "a [object]." (e.g., "a frog.")
- For Owlv2 model: Format as [["a photo of [object]", "a photo of [object2]"]]
-
 Always use object detection tool when counting/detecting is mentioned."""

 system_message += "Always be accurate. If you don't know the answer, say so."
@@ -51,17 +369,9 @@ class State:
 def __init__(self):
 self.current_image = None
 self.last_prediction = None
- self.current_model = "owlv2" # Default model

 state = State()

- def get_preprocessed_image(pixel_values):
- pixel_values = pixel_values.squeeze().numpy()
- unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
- unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
- unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
- return unnormalized_image
-
 def encode_image_to_base64(image_array):
 if image_array is None:
 return None
@@ -70,66 +380,44 @@ def encode_image_to_base64(image_array):
 image.save(buffered, format="JPEG")
 return base64.b64encode(buffered.getvalue()).decode('utf-8')

-
- def format_query_for_model(text_input, model_type="owlv2"):
- """Format query based on model requirements"""
- # Extract objects (e.g., "detect a lion" -> "lion")
- text = text_input.lower()
- words = [w.strip('.,?!') for w in text.split()
- if w not in ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']]
-
- if model_type == "owlv2":
- # Return just the list of queries for Owlv2, not nested list
- queries = ["a photo of " + obj for obj in words]
- print("Owlv2 queries:", queries)
- return queries
- else: # DINO
- # DINO query format
- query = f"a {words[:]}."
- print("DINO query:", query)
- return query
-

 def detect_objects(query_text):
 if state.current_image is None:
 return {"count": 0, "message": "No image provided"}

- image = Image.fromarray(state.current_image)
- draw = ImageDraw.Draw(image)
-
- if state.current_model == "owlv2":
- # For Owlv2, pass the text queries directly
- inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device)
- with torch.no_grad():
- outputs = owlv2_model(**inputs)
- results = owlv2_processor.post_process_object_detection(
- outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([image.size[::-1]])
- )
- else: # DINO
- # For DINO, pass the single text query
- inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device)
- with torch.no_grad():
- outputs = dino_model(**inputs)
- results = dino_processor.post_process_grounded_object_detection(
- outputs, inputs.input_ids, box_threshold=0.1, text_threshold=0.3,
- target_sizes=[image.size[::-1]]
- )
-
- # Draw detection boxes
- boxes = results[0]["boxes"]
- scores = results[0]["scores"]
-
- for box, score in zip(boxes, scores):
- box = [round(i) for i in box.tolist()]
- draw.rectangle(box, outline="red", width=3)
- draw.text((box[0], box[1]), f"Score: {score:.2f}", fill="red")

- state.last_prediction = np.array(image)
- return {
- "count": len(boxes),
- "confidence": scores.tolist(),
- "message": f"Detected {len(boxes)} objects"
- }

 def identify_plant():
 if state.current_image is None:
@@ -221,9 +509,10 @@ def chat(message, image, history):
 messages.append({"role": "assistant", "content": assistant})

 # Extract objects to detect from user message
- # This could be enhanced with better NLP
 objects_to_detect = message.lower()
- formatted_query = format_query_for_model(objects_to_detect, state.current_model)

 messages.append({
 "role": "user",
@@ -246,7 +535,7 @@ def chat(message, image, history):

 for tool_call in message.tool_calls:
 if tool_call.function.name == "detect_objects":
- results = detect_objects(formatted_query)
 else:
 results = identify_plant()

@@ -265,27 +554,16 @@

 return response.choices[0].message.content, state.last_prediction

- def update_model(choice):
- print(f"Model switched to: {choice}")
- state.current_model = choice.lower()
- return f"Model switched to {choice}"
-
 # Create Gradio interface
 with gr.Blocks() as demo:
- gr.Markdown("# Object Detection and Plant Analysis System")

 with gr.Row():
 with gr.Column():
- model_choice = gr.Radio(
- choices=["Owlv2", "DINO"],
- value="Owlv2",
- label="Select Detection Model",
- interactive=True
- )
 image_input = gr.Image(type="numpy", label="Upload Image")
 text_input = gr.Textbox(
 label="Ask about the image",
- placeholder="e.g., 'What objects do you see?' or 'What species is this plant?'"
 )
 with gr.Row():
 submit_btn = gr.Button("Analyze")
@@ -293,8 +571,7 @@ with gr.Blocks() as demo:

 with gr.Column():
 chatbot = gr.Chatbot()
- # output_image = gr.Image(label="Detected Objects")
- output_image = gr.Image(type="numpy", label="Detected Objects")

 def process_interaction(message, image, history):
 response, pred_image = chat(message, image, history)
@@ -306,8 +583,6 @@ with gr.Blocks() as demo:
 state.last_prediction = None
 return None, None, None, []

- model_choice.change(fn=update_model, inputs=[model_choice], outputs=[gr.Textbox(visible=False)])
-
 submit_btn.click(
 fn=process_interaction,
 inputs=[text_input, image_input, chatbot],
@@ -321,9 +596,14 @@ with gr.Blocks() as demo:
 )

 gr.Markdown("""## Instructions
- 1. Select the detection model (Owlv2 or DINO)
- 2. Upload an image
- 3. Ask specific questions about objects or plants
- 4. Click Analyze to get results""")

 demo.launch(share=True)
 
+ # # imports
+ # import os
+ # import json
+ # import base64
+ # from io import BytesIO
+ # from dotenv import load_dotenv
+ # from openai import OpenAI
+ # import gradio as gr
+ # import numpy as np
+ # from PIL import Image, ImageDraw
+ # import requests
+ # import torch
+ # from transformers import (
+ # AutoProcessor,
+ # Owlv2ForObjectDetection,
+ # AutoModelForZeroShotObjectDetection
+ # )
+ # # from transformers import AutoProcessor, Owlv2ForObjectDetection
+ # from transformers.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+
+ # # Initialization
+ # load_dotenv()
+ # os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
+ # PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
+ # MODEL = "gpt-4o"
+ # openai = OpenAI()
+
+ # # Initialize models
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
+ # # Owlv2
+ # owlv2_processor = AutoProcessor.from_pretrained("google/owlv2-base-patch16")
+ # owlv2_model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16").to(device)
+ # # DINO
+ # dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
+ # dino_model = AutoModelForZeroShotObjectDetection.from_pretrained("IDEA-Research/grounding-dino-base").to(device)
+
+ # system_message = """You are an expert in object detection. When users mention:
+ # 1. "count [object(s)]" - Use detect_objects with proper format based on model
+ # 2. "detect [object(s)]" - Same as count
+ # 3. "show [object(s)]" - Same as count
+
+ # For DINO model: Format queries as "a [object]." (e.g., "a frog.")
+ # For Owlv2 model: Format as [["a photo of [object]", "a photo of [object2]"]]
+
+ # Always use object detection tool when counting/detecting is mentioned."""
+
+ # system_message += "Always be accurate. If you don't know the answer, say so."
+
+
+ # class State:
+ # def __init__(self):
+ # self.current_image = None
+ # self.last_prediction = None
+ # self.current_model = "owlv2" # Default model
+
+ # state = State()
+
+ # def get_preprocessed_image(pixel_values):
+ # pixel_values = pixel_values.squeeze().numpy()
+ # unnormalized_image = (pixel_values * np.array(OPENAI_CLIP_STD)[:, None, None]) + np.array(OPENAI_CLIP_MEAN)[:, None, None]
+ # unnormalized_image = (unnormalized_image * 255).astype(np.uint8)
+ # unnormalized_image = np.moveaxis(unnormalized_image, 0, -1)
+ # return unnormalized_image
+
+ # def encode_image_to_base64(image_array):
+ # if image_array is None:
+ # return None
+ # image = Image.fromarray(image_array)
+ # buffered = BytesIO()
+ # image.save(buffered, format="JPEG")
+ # return base64.b64encode(buffered.getvalue()).decode('utf-8')
+
+
+ # def format_query_for_model(text_input, model_type="owlv2"):
+ # """Format query based on model requirements"""
+ # # Extract objects (e.g., "detect a lion" -> "lion")
+ # text = text_input.lower()
+ # words = [w.strip('.,?!') for w in text.split()
+ # if w not in ['count', 'detect', 'show', 'me', 'the', 'and', 'a', 'an']]
+
+ # if model_type == "owlv2":
+ # # Return just the list of queries for Owlv2, not nested list
+ # queries = ["a photo of " + obj for obj in words]
+ # print("Owlv2 queries:", queries)
+ # return queries
+ # else: # DINO
+ # # DINO query format
+ # query = f"a {words[:]}."
+ # print("DINO query:", query)
+ # return query
+
+
+ # def detect_objects(query_text):
+ # if state.current_image is None:
+ # return {"count": 0, "message": "No image provided"}
+
+ # image = Image.fromarray(state.current_image)
+ # draw = ImageDraw.Draw(image)
+
+ # if state.current_model == "owlv2":
+ # # For Owlv2, pass the text queries directly
+ # inputs = owlv2_processor(text=query_text, images=image, return_tensors="pt").to(device)
+ # with torch.no_grad():
+ # outputs = owlv2_model(**inputs)
+ # results = owlv2_processor.post_process_object_detection(
+ # outputs=outputs, threshold=0.2, target_sizes=torch.Tensor([image.size[::-1]])
+ # )
+ # else: # DINO
+ # # For DINO, pass the single text query
+ # inputs = dino_processor(images=image, text=query_text, return_tensors="pt").to(device)
+ # with torch.no_grad():
+ # outputs = dino_model(**inputs)
+ # results = dino_processor.post_process_grounded_object_detection(
+ # outputs, inputs.input_ids, box_threshold=0.1, text_threshold=0.3,
+ # target_sizes=[image.size[::-1]]
+ # )
+
+ # # Draw detection boxes
+ # boxes = results[0]["boxes"]
+ # scores = results[0]["scores"]
+
+ # for box, score in zip(boxes, scores):
+ # box = [round(i) for i in box.tolist()]
+ # draw.rectangle(box, outline="red", width=3)
+ # draw.text((box[0], box[1]), f"Score: {score:.2f}", fill="red")
+
+ # state.last_prediction = np.array(image)
+ # return {
+ # "count": len(boxes),
+ # "confidence": scores.tolist(),
+ # "message": f"Detected {len(boxes)} objects"
+ # }
+
+ # def identify_plant():
+ # if state.current_image is None:
+ # return {"error": "No image provided"}
+
+ # image = Image.fromarray(state.current_image)
+ # img_byte_arr = BytesIO()
+ # image.save(img_byte_arr, format='JPEG')
+ # img_byte_arr = img_byte_arr.getvalue()
+
+ # api_endpoint = f"https://my-api.plantnet.org/v2/identify/all?api-key={PLANTNET_API_KEY}"
+ # files = [('images', ('image.jpg', img_byte_arr))]
+ # data = {'organs': ['leaf']}
+
+ # try:
+ # response = requests.post(api_endpoint, files=files, data=data)
+ # if response.status_code == 200:
+ # result = response.json()
+ # best_match = result['results'][0]
+ # return {
+ # "scientific_name": best_match['species']['scientificName'],
+ # "common_names": best_match['species'].get('commonNames', []),
+ # "family": best_match['species']['family']['scientificName'],
+ # "genus": best_match['species']['genus']['scientificName'],
+ # "confidence": f"{best_match['score']*100:.1f}%"
+ # }
+ # else:
+ # return {"error": f"API Error: {response.status_code}"}
+ # except Exception as e:
+ # return {"error": f"Error: {str(e)}"}
+
+ # # Tool definitions
+ # object_detection_function = {
+ # "name": "detect_objects",
+ # "description": "Use this function to detect and count objects in images based on text queries.",
+ # "parameters": {
+ # "type": "object",
+ # "properties": {
+ # "query_text": {
+ # "type": "array",
+ # "description": "List of text queries describing objects to detect",
+ # "items": {"type": "string"}
+ # }
+ # }
+ # }
+ # }
+
+ # plant_identification_function = {
+ # "name": "identify_plant",
+ # "description": "Use this when asked about plant species identification or botanical classification.",
+ # "parameters": {
+ # "type": "object",
+ # "properties": {},
+ # "required": []
+ # }
+ # }
+
+ # tools = [
+ # {"type": "function", "function": object_detection_function},
+ # {"type": "function", "function": plant_identification_function}
+ # ]
+
+ # def format_tool_response(tool_response_content):
+ # data = json.loads(tool_response_content)
+ # if "error" in data:
+ # return f"Error: {data['error']}"
+ # elif "scientific_name" in data:
+ # return f"""📋 Plant Identification Results:
+
+ # 🌿 Scientific Name: {data['scientific_name']}
+ # 👥 Common Names: {', '.join(data['common_names']) if data['common_names'] else 'Not available'}
+ # 👪 Family: {data['family']}
+ # 🎯 Confidence: {data['confidence']}"""
+ # else:
+ # return f"I detected {data['count']} objects in the image."
+
+ # def chat(message, image, history):
+ # if image is not None:
+ # state.current_image = image
+
+ # if state.current_image is None:
+ # return "Please upload an image first.", None
+
+ # base64_image = encode_image_to_base64(state.current_image)
+ # messages = [{"role": "system", "content": system_message}]
+
+ # for human, assistant in history:
+ # messages.append({"role": "user", "content": human})
+ # messages.append({"role": "assistant", "content": assistant})
+
+ # # Extract objects to detect from user message
+ # # This could be enhanced with better NLP
+ # objects_to_detect = message.lower()
+ # formatted_query = format_query_for_model(objects_to_detect, state.current_model)
+
+ # messages.append({
+ # "role": "user",
+ # "content": [
+ # {"type": "text", "text": message},
+ # {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
+ # ]
+ # })
+
+ # response = openai.chat.completions.create(
+ # model=MODEL,
+ # messages=messages,
+ # tools=tools,
+ # max_tokens=300
+ # )
+
+ # if response.choices[0].finish_reason == "tool_calls":
+ # message = response.choices[0].message
+ # messages.append(message)
+
+ # for tool_call in message.tool_calls:
+ # if tool_call.function.name == "detect_objects":
+ # results = detect_objects(formatted_query)
+ # else:
+ # results = identify_plant()
+
+ # tool_response = {
+ # "role": "tool",
+ # "content": json.dumps(results),
+ # "tool_call_id": tool_call.id
+ # }
+ # messages.append(tool_response)
+
+ # response = openai.chat.completions.create(
+ # model=MODEL,
+ # messages=messages,
+ # max_tokens=300
+ # )
+
+ # return response.choices[0].message.content, state.last_prediction
+
+ # def update_model(choice):
+ # print(f"Model switched to: {choice}")
+ # state.current_model = choice.lower()
+ # return f"Model switched to {choice}"
+
+ # # Create Gradio interface
+ # with gr.Blocks() as demo:
+ # gr.Markdown("# Object Detection and Plant Analysis System")
+
+ # with gr.Row():
+ # with gr.Column():
+ # model_choice = gr.Radio(
+ # choices=["Owlv2", "DINO"],
+ # value="Owlv2",
+ # label="Select Detection Model",
+ # interactive=True
+ # )
+ # image_input = gr.Image(type="numpy", label="Upload Image")
+ # text_input = gr.Textbox(
+ # label="Ask about the image",
+ # placeholder="e.g., 'What objects do you see?' or 'What species is this plant?'"
+ # )
+ # with gr.Row():
+ # submit_btn = gr.Button("Analyze")
+ # reset_btn = gr.Button("Reset")
+
+ # with gr.Column():
+ # chatbot = gr.Chatbot()
+ # # output_image = gr.Image(label="Detected Objects")
+ # output_image = gr.Image(type="numpy", label="Detected Objects")
+
+ # def process_interaction(message, image, history):
+ # response, pred_image = chat(message, image, history)
+ # history.append((message, response))
+ # return "", pred_image, history
+
+ # def reset_interface():
+ # state.current_image = None
+ # state.last_prediction = None
+ # return None, None, None, []
+
+ # model_choice.change(fn=update_model, inputs=[model_choice], outputs=[gr.Textbox(visible=False)])
+
+ # submit_btn.click(
+ # fn=process_interaction,
+ # inputs=[text_input, image_input, chatbot],
+ # outputs=[text_input, output_image, chatbot]
+ # )
+
+ # reset_btn.click(
+ # fn=reset_interface,
+ # inputs=[],
+ # outputs=[image_input, output_image, text_input, chatbot]
+ # )
+
+ # gr.Markdown("""## Instructions
+ # 1. Select the detection model (Owlv2 or DINO)
+ # 2. Upload an image
+ # 3. Ask specific questions about objects or plants
+ # 4. Click Analyze to get results""")
+
+ # demo.launch(share=True)
+
 # imports
 import os
 import json
 import numpy as np
 from PIL import Image, ImageDraw
 import requests
+ import matplotlib.pyplot as plt
+ from vision_agent.agent import VisionAgentCoderV2
+ from vision_agent.models import AgentMessage
+ import vision_agent.tools as T

 # Initialization
 load_dotenv()
 os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY', 'your-key-here')
+ os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-anthropic-key-here')
 PLANTNET_API_KEY = os.getenv('PLANTNET_API_KEY', 'your-plantnet-key-here')
 MODEL = "gpt-4o"
 openai = OpenAI()

+ # Initialize VisionAgent
+ agent = VisionAgentCoderV2(verbose=False)

 system_message = """You are an expert in object detection. When users mention:
+ 1. "count [object(s)]" - Use detect_objects to count them
 2. "detect [object(s)]" - Same as count
 3. "show [object(s)]" - Same as count

 Always use object detection tool when counting/detecting is mentioned."""

 system_message += "Always be accurate. If you don't know the answer, say so."
 def __init__(self):
 self.current_image = None
 self.last_prediction = None

 state = State()

 def encode_image_to_base64(image_array):
 if image_array is None:
 return None
 image.save(buffered, format="JPEG")
 return base64.b64encode(buffered.getvalue()).decode('utf-8')

+ def save_temp_image(image_array):
+ """Save the image to a temporary file for VisionAgent to process"""
+ temp_path = "temp_image.jpg"
+ image = Image.fromarray(image_array)
+ image.save(temp_path)
+ return temp_path

 def detect_objects(query_text):
 if state.current_image is None:
 return {"count": 0, "message": "No image provided"}

+ # Save the current image to a temporary file
+ image_path = save_temp_image(state.current_image)

+ try:
+ # Use VisionAgent to detect objects
+ image = T.load_image(image_path)
+
+ # Clean query text to get the object name
+ object_name = query_text[0].replace("a photo of ", "").strip()
+
+ # Detect objects using CountGD
+ detections = T.countgd_object_detection(object_name, image)
+
+ # Visualize results
+ result_image = T.overlay_bounding_boxes(image, detections)
+
+ # Convert result back to numpy array for display
+ state.last_prediction = np.array(result_image)
+
+ return {
+ "count": len(detections),
+ "confidence": [det["score"] for det in detections],
+ "message": f"Detected {len(detections)} {object_name}(s)"
+ }
+ except Exception as e:
+ print(f"Error in detect_objects: {str(e)}")
+ return {"count": 0, "message": f"Error: {str(e)}"}

 def identify_plant():
 if state.current_image is None:
 messages.append({"role": "assistant", "content": assistant})

 # Extract objects to detect from user message
 objects_to_detect = message.lower()
+
+ # Format query for object detection
+ query = ["a photo of " + objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()]

 messages.append({
 "role": "user",

 for tool_call in message.tool_calls:
 if tool_call.function.name == "detect_objects":
+ results = detect_objects(query)
 else:
 results = identify_plant()


 return response.choices[0].message.content, state.last_prediction

 # Create Gradio interface
 with gr.Blocks() as demo:
+ gr.Markdown("# Object Detection and Plant Analysis System using VisionAgent")

 with gr.Row():
 with gr.Column():
 image_input = gr.Image(type="numpy", label="Upload Image")
 text_input = gr.Textbox(
 label="Ask about the image",
+ placeholder="e.g., 'Count dogs in this image' or 'What species is this plant?'"
 )
 with gr.Row():
 submit_btn = gr.Button("Analyze")

 with gr.Column():
 chatbot = gr.Chatbot()
+ output_image = gr.Image(type="numpy", label="Detection Results")

 def process_interaction(message, image, history):
 response, pred_image = chat(message, image, history)
 state.last_prediction = None
 return None, None, None, []

 submit_btn.click(
 fn=process_interaction,
 inputs=[text_input, image_input, chatbot],
 )

 gr.Markdown("""## Instructions
+ 1. Upload an image
+ 2. Ask specific questions about objects or plants
+ 3. Click Analyze to get results
+
+ Examples:
+ - "Count the number of people in this image"
+ - "Detect cats and dogs"
+ - "What species is this plant?"
+ """)

 demo.launch(share=True)
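
For reference, a hedged usage sketch of the new detection path outside the Gradio UI, mirroring how chat() builds the query before dispatching the detect_objects tool; the test image path is an assumption, not part of the commit.

import numpy as np
from PIL import Image

# Mirror what gr.Image(type="numpy") supplies: the uploaded image as a NumPy array.
state.current_image = np.array(Image.open("test.jpg"))  # "test.jpg" is a placeholder path

# chat() strips the verb and prefixes "a photo of " before calling the tool.
message = "count dogs"
query = ["a photo of " + message.replace("count", "").replace("detect", "").replace("show", "").strip()]

result = detect_objects(query)      # returns {"count": ..., "confidence": [...], "message": ...}
print(result["message"])            # summary string, e.g. "Detected 2 dogs(s)"
annotated = state.last_prediction   # NumPy array with overlaid boxes (shown in the Detection Results panel)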