Update app.py
Browse files
app.py
CHANGED
@@ -328,6 +328,7 @@
|
|
328 |
|
329 |
# demo.launch(share=True)
|
330 |
|
|
|
331 |
# imports
|
332 |
import os
|
333 |
import json
|
@@ -395,25 +396,54 @@ def detect_objects(query_text):
|
|
395 |
image_path = save_temp_image(state.current_image)
|
396 |
|
397 |
try:
|
398 |
-
# Use VisionAgent to detect objects
|
399 |
-
image = T.load_image(image_path)
|
400 |
-
|
401 |
# Clean query text to get the object name
|
402 |
object_name = query_text[0].replace("a photo of ", "").strip()
|
403 |
|
404 |
-
#
|
405 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
406 |
|
407 |
-
#
|
408 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
|
410 |
# Convert result back to numpy array for display
|
411 |
state.last_prediction = np.array(result_image)
|
412 |
|
413 |
return {
|
414 |
-
"count": len(
|
415 |
-
"confidence": [det["score"] for det in
|
416 |
-
"message": f"Detected {len(
|
417 |
}
|
418 |
except Exception as e:
|
419 |
print(f"Error in detect_objects: {str(e)}")
|
@@ -511,8 +541,9 @@ def chat(message, image, history):
|
|
511 |
# Extract objects to detect from user message
|
512 |
objects_to_detect = message.lower()
|
513 |
|
514 |
-
# Format query for object detection
|
515 |
-
|
|
|
516 |
|
517 |
messages.append({
|
518 |
"role": "user",
|
|
|
328 |
|
329 |
# demo.launch(share=True)
|
330 |
|
331 |
+
# imports
|
332 |
# imports
|
333 |
import os
|
334 |
import json
|
|
|
396 |
image_path = save_temp_image(state.current_image)
|
397 |
|
398 |
try:
|
|
|
|
|
|
|
399 |
# Clean query text to get the object name
|
400 |
object_name = query_text[0].replace("a photo of ", "").strip()
|
401 |
|
402 |
+
# Let VisionAgent handle the detection with its agent-based approach
|
403 |
+
# Create agent message for object detection
|
404 |
+
agent_message = [
|
405 |
+
AgentMessage(
|
406 |
+
role="user",
|
407 |
+
content=f"Count the number of {object_name} in this image. Only show detections with high confidence (>0.75).",
|
408 |
+
media=[image_path]
|
409 |
+
)
|
410 |
+
]
|
411 |
+
|
412 |
+
# Generate code using VisionAgent
|
413 |
+
code_context = agent.generate_code(agent_message)
|
414 |
|
415 |
+
# Load the image for visualization
|
416 |
+
image = T.load_image(image_path)
|
417 |
+
|
418 |
+
# Use multiple models for detection and get high confidence results
|
419 |
+
# First try the specialized detector
|
420 |
+
detections = T.countgd_object_detection(object_name, image, conf_threshold=0.75)
|
421 |
+
|
422 |
+
# If no high-confidence detections, try the more general object detector
|
423 |
+
if not detections:
|
424 |
+
# Try a different model with the same high threshold
|
425 |
+
try:
|
426 |
+
detections = T.grounding_dino_detection(object_name, image, box_threshold=0.75)
|
427 |
+
except:
|
428 |
+
pass
|
429 |
+
|
430 |
+
# Only keep high confidence detections
|
431 |
+
high_conf_detections = [det for det in detections if det.get("score", 0) > 0.75]
|
432 |
+
|
433 |
+
# Visualize only high confidence results with clear labeling
|
434 |
+
result_image = T.overlay_bounding_boxes(
|
435 |
+
image,
|
436 |
+
high_conf_detections,
|
437 |
+
labels=[f"{object_name}: {det['score']:.2f}" for det in high_conf_detections]
|
438 |
+
)
|
439 |
|
440 |
# Convert result back to numpy array for display
|
441 |
state.last_prediction = np.array(result_image)
|
442 |
|
443 |
return {
|
444 |
+
"count": len(high_conf_detections),
|
445 |
+
"confidence": [det["score"] for det in high_conf_detections],
|
446 |
+
"message": f"Detected {len(high_conf_detections)} {object_name}(s) with high confidence (>0.75)"
|
447 |
}
|
448 |
except Exception as e:
|
449 |
print(f"Error in detect_objects: {str(e)}")
|
|
|
541 |
# Extract objects to detect from user message
|
542 |
objects_to_detect = message.lower()
|
543 |
|
544 |
+
# Format query for object detection - keep it simple and direct
|
545 |
+
cleaned_query = objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()
|
546 |
+
query = ["a photo of " + cleaned_query]
|
547 |
|
548 |
messages.append({
|
549 |
"role": "user",
|