obichimav committed on
Commit
0f05e4c
·
verified ·
1 Parent(s): 43dab8d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -12
app.py CHANGED
@@ -328,6 +328,7 @@
328
 
329
  # demo.launch(share=True)
330
 
 
331
  # imports
332
  import os
333
  import json
@@ -395,25 +396,54 @@ def detect_objects(query_text):
395
  image_path = save_temp_image(state.current_image)
396
 
397
  try:
398
- # Use VisionAgent to detect objects
399
- image = T.load_image(image_path)
400
-
401
  # Clean query text to get the object name
402
  object_name = query_text[0].replace("a photo of ", "").strip()
403
 
404
- # Detect objects using CountGD
405
- detections = T.countgd_object_detection(object_name, image)
 
 
 
 
 
 
 
 
 
 
406
 
407
- # Visualize results
408
- result_image = T.overlay_bounding_boxes(image, detections)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
409
 
410
  # Convert result back to numpy array for display
411
  state.last_prediction = np.array(result_image)
412
 
413
  return {
414
- "count": len(detections),
415
- "confidence": [det["score"] for det in detections],
416
- "message": f"Detected {len(detections)} {object_name}(s)"
417
  }
418
  except Exception as e:
419
  print(f"Error in detect_objects: {str(e)}")
@@ -511,8 +541,9 @@ def chat(message, image, history):
511
  # Extract objects to detect from user message
512
  objects_to_detect = message.lower()
513
 
514
- # Format query for object detection
515
- query = ["a photo of " + objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()]
 
516
 
517
  messages.append({
518
  "role": "user",
 
328
 
329
  # demo.launch(share=True)
330
 
331
+ # imports
332
  # imports
333
  import os
334
  import json
 
396
  image_path = save_temp_image(state.current_image)
397
 
398
  try:
 
 
 
399
  # Clean query text to get the object name
400
  object_name = query_text[0].replace("a photo of ", "").strip()
401
 
402
+ # Let VisionAgent handle the detection with its agent-based approach
403
+ # Create agent message for object detection
404
+ agent_message = [
405
+ AgentMessage(
406
+ role="user",
407
+ content=f"Count the number of {object_name} in this image. Only show detections with high confidence (>0.75).",
408
+ media=[image_path]
409
+ )
410
+ ]
411
+
412
+ # Generate code using VisionAgent
413
+ code_context = agent.generate_code(agent_message)
414
 
415
+ # Load the image for visualization
416
+ image = T.load_image(image_path)
417
+
418
+ # Use multiple models for detection and get high confidence results
419
+ # First try the specialized detector
420
+ detections = T.countgd_object_detection(object_name, image, conf_threshold=0.75)
421
+
422
+ # If no high-confidence detections, try the more general object detector
423
+ if not detections:
424
+ # Try a different model with the same high threshold
425
+ try:
426
+ detections = T.grounding_dino_detection(object_name, image, box_threshold=0.75)
427
+ except:
428
+ pass
429
+
430
+ # Only keep high confidence detections
431
+ high_conf_detections = [det for det in detections if det.get("score", 0) > 0.75]
432
+
433
+ # Visualize only high confidence results with clear labeling
434
+ result_image = T.overlay_bounding_boxes(
435
+ image,
436
+ high_conf_detections,
437
+ labels=[f"{object_name}: {det['score']:.2f}" for det in high_conf_detections]
438
+ )
439
 
440
  # Convert result back to numpy array for display
441
  state.last_prediction = np.array(result_image)
442
 
443
  return {
444
+ "count": len(high_conf_detections),
445
+ "confidence": [det["score"] for det in high_conf_detections],
446
+ "message": f"Detected {len(high_conf_detections)} {object_name}(s) with high confidence (>0.75)"
447
  }
448
  except Exception as e:
449
  print(f"Error in detect_objects: {str(e)}")
 
541
  # Extract objects to detect from user message
542
  objects_to_detect = message.lower()
543
 
544
+ # Format query for object detection - keep it simple and direct
545
+ cleaned_query = objects_to_detect.replace("count", "").replace("detect", "").replace("show", "").strip()
546
+ query = ["a photo of " + cleaned_query]
547
 
548
  messages.append({
549
  "role": "user",