import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import cv2
from ultralytics import YOLO
import os
import random
from streamlit_image_coordinates import streamlit_image_coordinates
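
# Third-party packages assumed by the imports above (PyPI names):
# streamlit, numpy, torch, torchvision, pillow, opencv-python,
# ultralytics, streamlit-image-coordinates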

# Set page config
st.set_page_config(
    page_title="Traffic Light Detection App",
    layout="wide",
    menu_items={
        'Get Help': 'https://github.com/yourusername/traffic-light-detection',
        'Report a bug': "https://github.com/yourusername/traffic-light-detection/issues",
        'About': "# Traffic Light Detection App\nThis app detects traffic lights and monitors objects in a protection area."
    }
)

# Define allowed classes (a subset of the COCO label set used by standard
# YOLO detection models; detections of any other class are ignored)
ALLOWED_CLASSES = {
    'person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck',
    'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe'
}
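

# Assumption: caching the loaded models with st.cache_resource so the weights
# are read once per process rather than on every Streamlit rerun. Note that a
# failed load (None return values) is cached until the app restarts or the
# cache is cleared.
@st.cache_resource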
def initialize_models():
    try:
        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Initialize MobileNetV3 model
        model = models.mobilenet_v3_small(weights=None)
        model.classifier = nn.Sequential(
            nn.Linear(576, 2),  # Direct mapping to output classes
            nn.Softmax(dim=1)
        )
        model = model.to(device)

        # Load classifier weights (map_location handles both CPU and GPU hosts)
        best_model_path = "best_model_mobilenet_v3_v2.pth"
        if not os.path.exists(best_model_path):
            st.error(f"Model file not found: {best_model_path}")
            return None, None, None, None
        model.load_state_dict(torch.load(best_model_path, map_location=device))
        model.eval()

        # Load YOLO model for object detection
        yolo_model_path = "yolo11s.onnx"
        if not os.path.exists(yolo_model_path):
            st.error(f"YOLO model file not found: {yolo_model_path}")
            return device, model, None, None
        yolo_model = YOLO(yolo_model_path)

        # Load YOLO segmentation model
        seg_model_path = "best_segment.pt"
        if not os.path.exists(seg_model_path):
            st.error(f"YOLO segmentation model file not found: {seg_model_path}")
            return device, model, yolo_model, None
        seg_model = YOLO(seg_model_path)

        return device, model, yolo_model, seg_model
    except Exception as e:
        st.error(f"Error initializing models: {str(e)}")
        return None, None, None, None


def process_image(image, model, device):
    """Classify the image as red light / no red light with the MobileNetV3 model."""
    # Define image transformations (ImageNet normalization, as used for
    # torchvision pretraining)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Process image
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Perform inference
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = output[0]  # Softmax probabilities for both classes

    # Class 0 is "No Red Light", Class 1 is "Red Light"
    no_red_light_prob = probabilities[0].item()
    red_light_prob = probabilities[1].item()
    is_red_light = red_light_prob > no_red_light_prob

    return is_red_light, red_light_prob, no_red_light_prob


def is_point_in_polygon(point, polygon):
    """Check if a point is inside a polygon using the ray casting algorithm."""
    x, y = point
    n = len(polygon)
    inside = False
    p1x, p1y = polygon[0]
    for i in range(n + 1):
        p2x, p2y = polygon[i % n]
        if y > min(p1y, p2y):
            if y <= max(p1y, p2y):
                if x <= max(p1x, p2x):
                    if p1y != p2y:
                        xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
                    if p1x == p2x or x <= xinters:
                        inside = not inside
        p1x, p1y = p2x, p2y
    return inside


def is_bbox_in_area(bbox, protection_area):
    """Check if a bounding box is in the protection area.

    Only the box center is tested, so a box straddling the boundary counts
    as inside only when its center falls within the polygon.
    """
    center_x = (bbox[0] + bbox[2]) / 2
    center_y = (bbox[1] + bbox[3]) / 2
    return is_point_in_polygon((center_x, center_y), protection_area)


def put_text_with_background(img, text, position, font_scale=0.8, thickness=2, font=cv2.FONT_HERSHEY_SIMPLEX):
    """Put text with background on image."""
    # Get text size
    (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, thickness)

    # Calculate background rectangle
    padding = 5
    bg_rect_pt1 = (position[0], position[1] - text_height - padding)
    bg_rect_pt2 = (position[0] + text_width + padding * 2, position[1] + padding)

    # Draw background rectangle, then the text on top of it
    cv2.rectangle(img, bg_rect_pt1, bg_rect_pt2, (0, 0, 0), -1)
    cv2.putText(img, text, (position[0] + padding, position[1]), font, font_scale, (255, 255, 255), thickness)


def calculate_iou(box1, box2):
    """Calculate Intersection over Union between two bounding boxes."""
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = box1_area + box2_area - intersection

    return intersection / union if union > 0 else 0


def merge_overlapping_detections(detections, iou_threshold=0.5):
    """Merge overlapping detections of the same class."""
    if not detections:
        return []

    # Sort detections by confidence
    detections = sorted(detections, key=lambda x: x['confidence'], reverse=True)
    merged_detections = []

    while detections:
        best_detection = detections.pop(0)
        i = 0
        while i < len(detections):
            current_detection = detections[i]
            if (current_detection['class'] == best_detection['class'] and
                    calculate_iou(current_detection['bbox'], best_detection['bbox']) >= iou_threshold):
                # Remove the lower confidence detection
                detections.pop(i)
            else:
                i += 1
        merged_detections.append(best_detection)

    return merged_detections


def get_segmentation_masks(image, seg_model, conf_threshold=0.25):
    """Get segmentation masks from YOLO segmentation model."""
    results = seg_model(image, conf=conf_threshold)
    masks = []

    if results and len(results) > 0 and results[0].masks is not None:
        for i, mask in enumerate(results[0].masks.xy):
            class_id = int(results[0].boxes.cls[i])
            class_name = results[0].names[class_id]
            confidence = float(results[0].boxes.conf[i])

            # Convert mask polygon to an integer numpy array
            mask_np = np.array(mask, dtype=np.int32)
            masks.append({
                'mask': mask_np,
                'class': class_name,
                'confidence': confidence,
                'class_id': class_id
            })

    return masks, results
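

# App flow: Step 1 lets the user define a protection polygon on a setup image,
# either by clicking four points or by picking a segmentation mask; Step 2
# classifies an uploaded image for a red light and, when one is detected,
# reports which allowed objects have their box centers inside that polygon.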
def main():
    st.title("Train obstruction detection V1.2")

    # Initialize session state
    if 'points' not in st.session_state:
        st.session_state.points = []
    if 'protection_area_defined' not in st.session_state:
        st.session_state.protection_area_defined = False
    if 'current_step' not in st.session_state:
        st.session_state.current_step = 1
    if 'protection_method' not in st.session_state:
        st.session_state.protection_method = "manual"
    if 'segmentation_masks' not in st.session_state:
        st.session_state.segmentation_masks = []
    if 'selected_mask_index' not in st.session_state:
        st.session_state.selected_mask_index = -1

    # Initialize models
    device, model, yolo_model, seg_model = initialize_models()

    # Create tabs for the two steps
    step1, step2 = st.tabs(["Step 1: Define Protection Area", "Step 2: Detect Objects"])

    with step1:
        st.header("Step 1: Define Protection Area")

        # Method selection
        method = st.radio(
            "Select method to define protection area:",
            ["Manual (Click 4 points)", "Automatic Segmentation (Select a segment)"],
            index=0 if st.session_state.protection_method == "manual" else 1,
            key="method_selection"
        )

        # Update protection method in session state
        st.session_state.protection_method = "manual" if method == "Manual (Click 4 points)" else "yolo"

        # File uploader for protection area definition
        setup_image = st.file_uploader("Choose an image for protection area setup", type=['jpg', 'jpeg', 'png'], key="setup_image")

        if setup_image is not None:
            # Convert uploaded file to PIL Image, then to OpenCV format for drawing
            image = Image.open(setup_image).convert('RGB')
            cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            height, width = cv_image.shape[:2]

            # Create a copy for drawing
            draw_image = cv_image.copy()

            # Reset button
            if st.button('Reset Points/Selection'):
                st.session_state.points = []
                st.session_state.protection_area_defined = False
                st.session_state.selected_mask_index = -1
                # Clear segmentation masks to force re-detection
                st.session_state.segmentation_masks = []
                if 'mask_colors' in st.session_state:
                    del st.session_state.mask_colors
                st.rerun()

            # Manual method
            if st.session_state.protection_method == "manual":
                # Instructions
                st.write("👆 Click directly on the image to add points for the protection area (need 4 points)")

                # Display current image with points
                if len(st.session_state.points) > 0:
                    # Draw existing points and lines; close the polygon once all 4 points are set
                    points = np.array(st.session_state.points, dtype=np.int32)
                    cv2.polylines(draw_image, [points], len(points) == 4, (0, 255, 0), 2)

                    # Draw points with numbers
                    for i, point in enumerate(points):
                        cv2.circle(draw_image, tuple(point), 5, (0, 0, 255), -1)
                        cv2.putText(draw_image, str(i + 1),
                                    (point[0] + 10, point[1] + 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

                # Create columns for better layout
                col1, col2 = st.columns([4, 1])

                with col1:
                    # Display the image and handle click events
                    if len(st.session_state.points) < 4:
                        clicked = streamlit_image_coordinates(
                            cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB),
                            key=f"image_coordinates_{len(st.session_state.points)}"
                        )
                        if clicked is not None and clicked.get('x') is not None and clicked.get('y') is not None:
                            x, y = clicked['x'], clicked['y']
                            if 0 <= x < width and 0 <= y < height:
                                st.session_state.points.append([x, y])
                                if len(st.session_state.points) == 4:
                                    st.session_state.protection_area_defined = True
                                st.rerun()
                    else:
                        st.image(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB))

                with col2:
                    st.write(f"Points: {len(st.session_state.points)}/4")
                    if len(st.session_state.points) > 0:
                        st.write("Current Points:")
                        for i, point in enumerate(st.session_state.points):
                            st.write(f"Point {i + 1}: ({point[0]}, {point[1]})")
            # YOLO segmentation method
            else:
                if seg_model is None:
                    st.error("YOLO segmentation model not loaded. Please check the error messages above.")
                else:
                    # Always run segmentation when in YOLO mode to ensure fresh results
                    with st.spinner("Running segmentation..."):
                        masks, _ = get_segmentation_masks(cv_image, seg_model)
                        st.session_state.segmentation_masks = masks
                        # Generate a random display color for each mask
                        st.session_state.mask_colors = [
                            [random.randint(0, 255) for _ in range(3)]
                            for _ in range(len(masks))
                        ]

                    # Display segmentation results
                    if len(st.session_state.segmentation_masks) > 0:
                        # Create a copy of the image for drawing masks
                        mask_image = cv_image.copy()

                        # Draw all masks with transparency; the selected mask gets
                        # a higher alpha and a thicker outline
                        for i, mask_data in enumerate(st.session_state.segmentation_masks):
                            mask = mask_data['mask']
                            color = st.session_state.mask_colors[i]

                            # Draw the filled polygon on a blank overlay, then blend it in
                            mask_overlay = np.zeros_like(mask_image)
                            cv2.fillPoly(mask_overlay, [mask], color)
                            alpha = 0.7 if i == st.session_state.selected_mask_index else 0.4
                            mask_image = cv2.addWeighted(mask_image, 1, mask_overlay, alpha, 0)

                            # Draw the polygon outline
                            line_thickness = 4 if i == st.session_state.selected_mask_index else 2
                            cv2.polylines(mask_image, [mask], True, color, line_thickness)

                            # Add a class label at the first polygon vertex
                            label = f"{mask_data['class']} {mask_data['confidence']:.2f}"
                            label_pos = (int(mask[0][0]), int(mask[0][1]) - 10)
                            put_text_with_background(mask_image, label, label_pos)

                        # Display the image with masks
                        col1, col2 = st.columns([4, 1])
                        with col1:
                            st.image(cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB))
                        with col2:
                            st.write("Available Segments:")
                            for i, mask_data in enumerate(st.session_state.segmentation_masks):
                                if st.button(f"Select {mask_data['class']} #{i + 1}", key=f"select_mask_{i}"):
                                    st.session_state.selected_mask_index = i
                                    # Use the selected mask polygon as the protection area
                                    st.session_state.points = mask_data['mask'].tolist()
                                    st.session_state.protection_area_defined = True
                                    st.rerun()

                        # Add a re-detect button
                        if st.button("Re-detect Segments"):
                            st.session_state.segmentation_masks = []
                            if 'mask_colors' in st.session_state:
                                del st.session_state.mask_colors
                            st.session_state.selected_mask_index = -1
                            st.rerun()
                    else:
                        st.warning("No segmentation masks found in the image. Try another image or use the manual method.")

    with step2:
        st.header("Step 2: Detect Objects")

        if not st.session_state.protection_area_defined:
            st.warning("⚠️ Please complete Step 1 first to define the protection area.")
            return

        st.write("Upload images to detect red lights and objects in the protection area")

        # File uploader for detection
        detection_image = st.file_uploader("Choose an image for detection", type=['jpg', 'jpeg', 'png'], key="detection_image")

        if detection_image is not None:
            if device is None or model is None:
                st.error("Failed to initialize models. Please check the error messages above.")
                return

            # Load and process image
            image = Image.open(detection_image).convert('RGB')
            cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

            # Process image for red light detection
            is_red_light, red_light_prob, no_red_light_prob = process_image(image, model, device)

            # Display red light detection results
            st.write("\n🔥 Red Light Detection Results:")
            st.write(f"Red Light Detected: {is_red_light}")
            st.write(f"Red Light Probability: {red_light_prob:.2%}")
            st.write(f"No Red Light Probability: {no_red_light_prob:.2%}")

            if is_red_light and yolo_model is not None:
                # Draw protection area (cv2.polylines requires int32 points)
                cv2.polylines(cv_image, [np.array(st.session_state.points, dtype=np.int32)], True, (0, 255, 0), 2)

                # Run YOLO detection
                results = yolo_model(cv_image, conf=0.25)

                # Keep allowed-class detections whose box center lies in the protection area
                detection_results = []
                for result in results:
                    if result.boxes is not None:
                        for box in result.boxes:
                            class_id = int(box.cls[0])
                            class_name = yolo_model.names[class_id]
                            if class_name in ALLOWED_CLASSES:
                                bbox = box.xyxy[0].cpu().numpy()
                                if is_bbox_in_area(bbox, st.session_state.points):
                                    detection_results.append({
                                        'class': class_name,
                                        'confidence': float(box.conf[0]),
                                        'bbox': bbox
                                    })

                # Merge overlapping detections
                detection_results = merge_overlapping_detections(detection_results, iou_threshold=0.5)

                # Draw detections
                for det in detection_results:
                    bbox = det['bbox']
                    # Draw detection box and label
                    cv2.rectangle(cv_image,
                                  (int(bbox[0]), int(bbox[1])),
                                  (int(bbox[2]), int(bbox[3])),
                                  (0, 0, 255), 2)
                    text = f"{det['class']}: {det['confidence']:.2%}"
                    put_text_with_background(cv_image, text, (int(bbox[0]), int(bbox[1]) - 10))

                # Add status text
                status_text = f"Red Light: DETECTED ({red_light_prob:.1%})"
                put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
                count_text = f"Objects in Protection Area: {len(detection_results)}"
                put_text_with_background(cv_image, count_text, (10, 70), font_scale=0.8)

                # Display results
                st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))

                # Display detections
                if detection_results:
                    st.write("\n🎯 Detected Objects in Protection Area:")
                    for i, det in enumerate(detection_results, 1):
                        st.write(f"\nObject {i}:")
                        st.write(f"- Class: {det['class']}")
                        st.write(f"- Confidence: {det['confidence']:.2%}")
                else:
                    st.write("\nNo objects detected in protection area")
            else:
                # Report the classifier result even when object detection is skipped
                # (no red light, or the YOLO model failed to load)
                status = "DETECTED" if is_red_light else "NOT DETECTED"
                status_text = f"Red Light: {status} ({red_light_prob:.1%})"
                put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
                st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))


if __name__ == "__main__":
    main()