import os

import cv2
import numpy as np
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models
from ultralytics import YOLO
from streamlit_image_coordinates import streamlit_image_coordinates

# Set page config
st.set_page_config(
    page_title="Traffic Light Detection App",
    layout="wide",
    menu_items={
        'Get Help': 'https://github.com/yourusername/traffic-light-detection',
        'Report a bug': "https://github.com/yourusername/traffic-light-detection/issues",
        'About': "# Traffic Light Detection App\nThis app detects traffic lights and monitors objects in a protection area."
    }
)

# Define allowed classes
ALLOWED_CLASSES = {
    'person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck',
    'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe'
}


def initialize_models():
    """Load the MobileNetV3 red-light classifier and the YOLO object detector."""
    try:
        # Select GPU when available, otherwise fall back to CPU
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize MobileNetV3 with a two-class classification head
        model = models.mobilenet_v3_small(weights=None)
        model.classifier = nn.Sequential(
            nn.Linear(576, 2),  # Direct mapping to output classes
            nn.Softmax(dim=1)
        )
        model = model.to(device)

        # Load classifier weights; map_location handles both CUDA and CPU
        best_model_path = "best_model_mobilenet_v3_v2.pth"
        if not os.path.exists(best_model_path):
            st.error(f"Model file not found: {best_model_path}")
            return None, None, None
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        model.eval()

        # Load the YOLO model exported to ONNX
        yolo_model_path = "yolo11s.onnx"
        if not os.path.exists(yolo_model_path):
            st.error(f"YOLO model file not found: {yolo_model_path}")
            return device, model, None
        yolo_model = YOLO(yolo_model_path)

        return device, model, yolo_model
    except Exception as e:
        st.error(f"Error initializing models: {str(e)}")
        return None, None, None
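
# Note on the return contract (worth spelling out for callers): the function
# returns (None, None, None) when the classifier cannot be loaded,
# (device, model, None) when only the YOLO file is missing, and
# (device, model, yolo_model) on success; callers must check for None entries.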


def process_image(image, model, device):
    """Classify a PIL image as red light / no red light and return probabilities."""
    # Define image transformations matching MobileNetV3's ImageNet preprocessing
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Add a batch dimension and move the tensor to the model's device
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Perform inference
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = output[0]  # Softmax probabilities for both classes

    # Class 0 is "No Red Light", class 1 is "Red Light"
    no_red_light_prob = probabilities[0].item()
    red_light_prob = probabilities[1].item()
    is_red_light = red_light_prob > no_red_light_prob

    return is_red_light, red_light_prob, no_red_light_prob
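
# Usage sketch (hypothetical "frame.jpg"; model and device come from
# initialize_models). The two probabilities come from the softmax head and sum
# to 1, so the boolean is effectively an argmax over the two classes.
# >>> frame = Image.open("frame.jpg").convert('RGB')
# >>> is_red, p_red, p_no_red = process_image(frame, model, device)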


def is_point_in_polygon(point, polygon):
    """Check if a point is inside a polygon using the ray-casting algorithm."""
    x, y = point
    n = len(polygon)
    inside = False
    p1x, p1y = polygon[0]
    for i in range(n + 1):
        p2x, p2y = polygon[i % n]
        if y > min(p1y, p2y):
            if y <= max(p1y, p2y):
                if x <= max(p1x, p2x):
                    if p1y != p2y:
                        xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                    if p1x == p2x or x <= xinters:
                        inside = not inside
        p1x, p1y = p2x, p2y
    return inside
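
# Quick sanity check (hypothetical values): the center of the unit square is
# inside it, while a point off to the right is not.
# >>> square = [(0, 0), (1, 0), (1, 1), (0, 1)]
# >>> is_point_in_polygon((0.5, 0.5), square)
# True
# >>> is_point_in_polygon((2.0, 0.5), square)
# False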


def is_bbox_in_area(bbox, protection_area, image_shape):
    """Check if the bounding-box center is in the protection area.

    image_shape is currently unused but kept for call-site compatibility.
    """
    # Get the bbox center point from (x1, y1, x2, y2)
    center_x = (bbox[0] + bbox[2]) / 2
    center_y = (bbox[1] + bbox[3]) / 2
    return is_point_in_polygon((center_x, center_y), protection_area)


def put_text_with_background(img, text, position, font_scale=0.8, thickness=2, font=cv2.FONT_HERSHEY_SIMPLEX):
    """Put text with background on image."""
    # Get text size
    (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, thickness)

    # Calculate background rectangle
    padding = 5
    bg_rect_pt1 = (position[0], position[1] - text_height - padding)
    bg_rect_pt2 = (position[0] + text_width + padding * 2, position[1] + padding)

    # Draw background rectangle, then the text on top of it
    cv2.rectangle(img, bg_rect_pt1, bg_rect_pt2, (0, 0, 0), -1)
    cv2.putText(img, text, (position[0] + padding, position[1]), font, font_scale, (255, 255, 255), thickness)
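
# Example call (on a hypothetical BGR frame), mirroring how the status banner
# is rendered in main():
# put_text_with_background(frame, "Red Light: DETECTED (97.0%)", (10, 30), font_scale=1.0)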


def calculate_iou(box1, box2):
    """Calculate Intersection over Union between two bounding boxes."""
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = box1_area + box2_area - intersection

    return intersection / union if union > 0 else 0
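
# Worked example (hypothetical boxes): two 10x10 boxes offset horizontally by
# 5 px overlap in a 5x10 strip, so IoU = 50 / (100 + 100 - 50) = 1/3.
# >>> round(calculate_iou([0, 0, 10, 10], [5, 0, 15, 10]), 3)
# 0.333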


def merge_overlapping_detections(detections, iou_threshold=0.5):
    """Merge overlapping detections of the same class, keeping the most confident."""
    if not detections:
        return []

    # Sort detections by confidence, highest first
    detections = sorted(detections, key=lambda x: x['confidence'], reverse=True)
    merged_detections = []
    while detections:
        best_detection = detections.pop(0)
        i = 0
        while i < len(detections):
            current_detection = detections[i]
            if (current_detection['class'] == best_detection['class'] and
                    calculate_iou(current_detection['bbox'], best_detection['bbox']) >= iou_threshold):
                # Remove the lower-confidence detection
                detections.pop(i)
            else:
                i += 1
        merged_detections.append(best_detection)
    return merged_detections
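
# Usage sketch (hypothetical detections): the two heavily overlapping 'car'
# boxes (IoU ≈ 0.82) collapse into the higher-confidence one, while the
# 'person' box survives untouched.
# >>> dets = [
# ...     {'class': 'car', 'confidence': 0.9, 'bbox': [0, 0, 100, 100]},
# ...     {'class': 'car', 'confidence': 0.6, 'bbox': [5, 5, 105, 105]},
# ...     {'class': 'person', 'confidence': 0.8, 'bbox': [200, 200, 260, 320]},
# ... ]
# >>> [d['confidence'] for d in merge_overlapping_detections(dets)]
# [0.9, 0.8]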


def main():
    """Streamlit entry point: collect a 4-point protection area, then run detection."""
    st.title("Traffic Light Detection with Protection Area")

    # Initialize session state for protection area points
    if 'points' not in st.session_state:
        st.session_state.points = []
    if 'processing_done' not in st.session_state:
        st.session_state.processing_done = False

    # File uploader
    uploaded_file = st.file_uploader("Choose an image", type=['jpg', 'jpeg', 'png'])

    if uploaded_file is not None:
        # Convert uploaded file to PIL Image
        image = Image.open(uploaded_file).convert('RGB')

        # Convert to OpenCV (BGR) format for drawing
        cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
        height, width = cv_image.shape[:2]

        # Create a copy for drawing
        draw_image = cv_image.copy()

        # Instructions
        st.write("👆 Click directly on the image to add points for the protection area (need 4 points)")
        st.write("🔄 Click 'Reset Points' to start over")

        # Reset button
        if st.button('Reset Points'):
            st.session_state.points = []
            st.session_state.processing_done = False
            st.rerun()

        # Display current image with points
        if len(st.session_state.points) > 0:
            # Draw existing points and lines; close the polygon once 4 points exist
            points = np.array(st.session_state.points, dtype=np.int32)
            cv2.polylines(draw_image, [points],
                          len(points) == 4,
                          (0, 255, 0), 2)
            # Draw numbered points
            for i, point in enumerate(points):
                cv2.circle(draw_image, tuple(point), 5, (0, 0, 255), -1)
                cv2.putText(draw_image, str(i + 1),
                            (point[0] + 10, point[1] + 10),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

        # Create columns for better layout
        col1, col2 = st.columns([4, 1])

        with col1:
            # Display the image and handle click events
            if len(st.session_state.points) < 4 and not st.session_state.processing_done:
                # Display the image with current points and capture clicks
                clicked = streamlit_image_coordinates(
                    cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB),
                    key=f"image_coordinates_{len(st.session_state.points)}"
                )

                # Handle click events
                if clicked is not None and clicked.get('x') is not None and clicked.get('y') is not None:
                    x, y = clicked['x'], clicked['y']
                    if 0 <= x < width and 0 <= y < height:
                        # Add the new point; the rerun redraws everything from
                        # session state, so no drawing is needed here
                        st.session_state.points = st.session_state.points + [[x, y]]
                        st.rerun()
            else:
                # Just display the image if we're done adding points
                st.image(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB), use_column_width=True)

        with col2:
            # Show progress
            st.write(f"Points: {len(st.session_state.points)}/4")

            # Show current points
            if len(st.session_state.points) > 0:
                st.write("Current Points:")
                for i, point in enumerate(st.session_state.points):
                    st.write(f"Point {i + 1}: ({point[0]}, {point[1]})")

                # Add option to remove the last point
                if st.button("Remove Last Point"):
                    st.session_state.points.pop()
                    st.rerun()

        # Process button
        if len(st.session_state.points) == 4 and not st.session_state.processing_done:
            st.write("✅ Protection area defined! Click 'Process Detection' to continue.")
            if st.button('Process Detection', type='primary'):
                st.session_state.processing_done = True

                # Initialize models
                device, model, yolo_model = initialize_models()
                if device is None or model is None:
                    st.error("Failed to initialize models. Please check the error messages above.")
                    return

                # Process image for red light detection
                is_red_light, red_light_prob, no_red_light_prob = process_image(image, model, device)

                # Display red light detection results
                st.write("\n🔥 Red Light Detection Results:")
                st.write(f"Red Light Detected: {is_red_light}")
                st.write(f"Red Light Probability: {red_light_prob:.2%}")
                st.write(f"No Red Light Probability: {no_red_light_prob:.2%}")

                if is_red_light and yolo_model is not None:
                    # Draw the protection area (int32 points, as cv2.polylines requires)
                    cv2.polylines(cv_image, [np.array(st.session_state.points, dtype=np.int32)], True, (0, 255, 0), 2)

                    # Run YOLO detection
                    results = yolo_model(cv_image, conf=0.25)

                    # Keep allowed classes whose bbox center lies in the protection area
                    detection_results = []
                    for result in results:
                        if result.boxes is not None:
                            for box in result.boxes:
                                class_id = int(box.cls[0])
                                class_name = yolo_model.names[class_id]
                                if class_name in ALLOWED_CLASSES:
                                    bbox = box.xyxy[0].cpu().numpy()
                                    if is_bbox_in_area(bbox, st.session_state.points, cv_image.shape):
                                        confidence = float(box.conf[0])
                                        detection_results.append({
                                            'class': class_name,
                                            'confidence': confidence,
                                            'bbox': bbox
                                        })

                    # Merge overlapping detections
                    detection_results = merge_overlapping_detections(detection_results, iou_threshold=0.5)

                    # Draw detections
                    for det in detection_results:
                        bbox = det['bbox']
                        # Draw detection box
                        cv2.rectangle(cv_image,
                                      (int(bbox[0]), int(bbox[1])),
                                      (int(bbox[2]), int(bbox[3])),
                                      (0, 0, 255), 2)
                        # Add label
                        text = f"{det['class']}: {det['confidence']:.2%}"
                        put_text_with_background(cv_image, text,
                                                 (int(bbox[0]), int(bbox[1]) - 10))

                    # Add status text
                    status_text = f"Red Light: DETECTED ({red_light_prob:.1%})"
                    put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
                    count_text = f"Objects in Protection Area: {len(detection_results)}"
                    put_text_with_background(cv_image, count_text, (10, 70), font_scale=0.8)

                    # Display the annotated image
                    st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))

                    # Display detections
                    if detection_results:
                        st.write("\n🎯 Detected Objects in Protection Area:")
                        for i, det in enumerate(detection_results, 1):
                            st.write(f"\nObject {i}:")
                            st.write(f"- Class: {det['class']}")
                            st.write(f"- Confidence: {det['confidence']:.2%}")
                    else:
                        st.write("\nNo objects detected in protection area")
                else:
                    # Reached when no red light is detected, or when the YOLO model
                    # failed to load; report the classifier's actual verdict either way
                    status = "DETECTED" if is_red_light else "NOT DETECTED"
                    status_text = f"Red Light: {status} ({red_light_prob:.1%})"
                    put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
                    st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))


if __name__ == "__main__":
    main()