kendrickfff commited on
Commit
1d8f5ea
·
verified ·
1 Parent(s): 9e55ca0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -17
app.py CHANGED
@@ -50,23 +50,19 @@ def analyze_image(image_path):
50
  with torch.no_grad():
51
  outputs = model(**inputs)
52
 
53
- # Set a threshold for filtering low-confidence detections (e.g., 0.8)
54
- threshold = 0.8
55
  target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
56
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
57
-
58
- # Filter detections by confidence threshold
59
- if len(results["labels"]) > 0:
60
- detected_objects = []
61
- for idx, label in enumerate(results["labels"]):
62
- if results["scores"][idx] >= threshold: # Check if confidence score is high enough
63
- # Get the object label based on label index
64
- object_name = COCO_CLASSES[label.item()] # Assuming COCO_CLASSES is available
65
- detected_objects.append(object_name)
66
- if detected_objects:
67
- bot_response = f"Objects detected: {', '.join(detected_objects)}."
68
- else:
69
- bot_response = "No objects detected above the confidence threshold."
70
  else:
71
  bot_response = "No objects detected."
72
 
@@ -78,7 +74,6 @@ def analyze_image(image_path):
78
  return chat_history
79
 
80
 
81
-
82
  # Build the Gradio interface
83
  with gr.Blocks() as demo:
84
  gr.Markdown("# Ken Chatbot")
 
50
  with torch.no_grad():
51
  outputs = model(**inputs)
52
 
53
+ # Set a target size for post-processing
 
54
  target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
55
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
56
+
57
+ # Collect detected objects
58
+ detected_objects = []
59
+ for idx, label in enumerate(results["labels"]):
60
+ # Get the object label based on label index
61
+ object_name = COCO_CLASSES[label.item()] # Assuming COCO_CLASSES is available
62
+ detected_objects.append(object_name)
63
+
64
+ if detected_objects:
65
+ bot_response = f"Objects detected: {', '.join(detected_objects)}."
 
 
 
66
  else:
67
  bot_response = "No objects detected."
68
 
 
74
  return chat_history
75
 
76
 
 
77
  # Build the Gradio interface
78
  with gr.Blocks() as demo:
79
  gr.Markdown("# Ken Chatbot")