kendrickfff commited on
Commit
9e55ca0
·
verified ·
1 Parent(s): 85dbb04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -42,30 +42,34 @@ def chat_with_gemini(message):
42
  def analyze_image(image_path):
43
  global chat_history
44
  try:
45
- # Open the image file
46
  image = Image.open(image_path).convert("RGB")
47
- # Preprocess the image for DETR
48
  inputs = processor(images=image, return_tensors="pt")
49
 
50
  # Perform inference
51
  with torch.no_grad():
52
  outputs = model(**inputs)
53
 
54
- # Get predictions
 
55
  target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
56
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
57
 
58
- # If any objects detected, display labels
59
  if len(results["labels"]) > 0:
60
- # Decode the labels to human-readable objects
61
  detected_objects = []
62
- for label in results["labels"]:
63
- # Map the label to the COCO dataset categories (Detr uses COCO labels by default)
64
- detected_objects.append(str(label.item())) # Get the integer value (COCO class id)
65
- bot_response = f"Objects detected: {', '.join(detected_objects)}."
 
 
 
 
 
66
  else:
67
  bot_response = "No objects detected."
68
-
69
  chat_history.append(("Uploaded an image for analysis", bot_response))
70
  return chat_history
71
  except Exception as e:
@@ -74,6 +78,7 @@ def analyze_image(image_path):
74
  return chat_history
75
 
76
 
 
77
  # Build the Gradio interface
78
  with gr.Blocks() as demo:
79
  gr.Markdown("# Ken Chatbot")
 
42
  def analyze_image(image_path):
43
  global chat_history
44
  try:
45
+ # Open and preprocess the image
46
  image = Image.open(image_path).convert("RGB")
 
47
  inputs = processor(images=image, return_tensors="pt")
48
 
49
  # Perform inference
50
  with torch.no_grad():
51
  outputs = model(**inputs)
52
 
53
+ # Set a threshold for filtering low-confidence detections (e.g., 0.8)
54
+ threshold = 0.8
55
  target_sizes = torch.tensor([image.size[::-1]]) # (height, width)
56
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
57
 
58
+ # Filter detections by confidence threshold
59
  if len(results["labels"]) > 0:
 
60
  detected_objects = []
61
+ for idx, label in enumerate(results["labels"]):
62
+ if results["scores"][idx] >= threshold: # Check if confidence score is high enough
63
+ # Get the object label based on label index
64
+ object_name = COCO_CLASSES[label.item()] # Assuming COCO_CLASSES is available
65
+ detected_objects.append(object_name)
66
+ if detected_objects:
67
+ bot_response = f"Objects detected: {', '.join(detected_objects)}."
68
+ else:
69
+ bot_response = "No objects detected above the confidence threshold."
70
  else:
71
  bot_response = "No objects detected."
72
+
73
  chat_history.append(("Uploaded an image for analysis", bot_response))
74
  return chat_history
75
  except Exception as e:
 
78
  return chat_history
79
 
80
 
81
+
82
  # Build the Gradio interface
83
  with gr.Blocks() as demo:
84
  gr.Markdown("# Ken Chatbot")