# Scraped file-view header (page chrome, not code): app.py uploaded by Sompote,
# commit 99b3da4 (verified), 15.8 kB.
import streamlit as st
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import cv2
from ultralytics import YOLO
import os
from streamlit_image_coordinates import streamlit_image_coordinates
# Page-level Streamlit configuration: browser title, wide layout, and the links
# shown in the app's menu.
# NOTE(review): the help / bug-report URLs still point at a "yourusername"
# placeholder repository.
st.set_page_config(
    layout="wide",
    page_title="Traffic Light Detection App",
    menu_items={
        "Get Help": "https://github.com/yourusername/traffic-light-detection",
        "Report a bug": "https://github.com/yourusername/traffic-light-detection/issues",
        "About": (
            "# Traffic Light Detection App\n"
            "This app detects traffic lights and monitors objects in a protection area."
        ),
    },
)
# Define allowed classes: YOLO class names (as looked up via ``yolo_model.names``)
# that are reported when their box centre lies in the protection area. Detections
# of any other class are ignored.
ALLOWED_CLASSES = (
    {'person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck'}  # people and vehicles
    | {'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe'}  # animals
)
@st.cache_resource
def _load_red_light_classifier(weights_path):
    """Build the MobileNetV3-small red-light classifier and load ``weights_path``.

    Cached per process by ``st.cache_resource``. Failures propagate as exceptions,
    which are not stored in the cache, so a later rerun retries the load.

    Returns:
        ``(device, model)`` with the model moved to ``device`` and set to eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = models.mobilenet_v3_small(weights=None)
    # Custom 2-class head over a 576-dim feature vector: index 0 is "No Red Light",
    # index 1 is "Red Light". Softmax is part of the model, so outputs are probabilities.
    model.classifier = nn.Sequential(
        nn.Linear(576, 2),
        nn.Softmax(dim=1),
    )
    model = model.to(device)
    # map_location=device lets a GPU-saved checkpoint load on a CPU-only host, and
    # places tensors directly on the GPU when CUDA is available.
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.eval()
    return device, model


@st.cache_resource
def _load_yolo_detector(weights_path):
    """Load the Ultralytics YOLO detector from ``weights_path`` (cached per process)."""
    return YOLO(weights_path)


def initialize_models():
    """Load the red-light classifier and the YOLO object detector.

    File-existence checks and error reporting run on every call, outside the cache.
    Only successful heavy loads are cached, in the two helpers above. This means a
    weights file that was missing (or a load that failed) is retried on the next
    rerun instead of the failure being cached until the cache is cleared.

    Returns:
        ``(device, model, yolo_model)`` on success.
        ``(None, None, None)`` if the classifier weights file is missing or any load raises.
        ``(device, model, None)`` if only the YOLO weights file is missing.
        Problems are reported to the page with ``st.error``.
    """
    best_model_path = "best_model_mobilenet_v3_v2.pth"
    # Relative path: resolved against the process's current working directory.
    yolo_model_path = "yolo11s.onnx"
    try:
        if not os.path.exists(best_model_path):
            st.error(f"Model file not found: {best_model_path}")
            return None, None, None
        device, model = _load_red_light_classifier(best_model_path)

        if not os.path.exists(yolo_model_path):
            st.error(f"YOLO model file not found: {yolo_model_path}")
            return device, model, None
        yolo_model = _load_yolo_detector(yolo_model_path)
        return device, model, yolo_model
    except Exception as e:
        # UI boundary: surface any load error on the page instead of crashing the app.
        st.error(f"Error initializing models: {str(e)}")
        return None, None, None
def process_image(image, model, device):
    """Classify ``image`` as red light / no red light with the MobileNetV3 model.

    Preprocessing steps:
    - resize to 224x224;
    - convert to a tensor;
    - normalise with the ImageNet channel mean/std.
    A single forward pass then runs under ``torch.no_grad()``.

    Args:
        image: PIL image (main() converts uploads to RGB before calling this).
        model: classifier whose first output row holds class probabilities,
            index 0 = "No Red Light", index 1 = "Red Light".
        device: torch device the model lives on.

    Returns:
        ``(is_red_light, red_light_prob, no_red_light_prob)``. The probabilities are
        Python floats. ``is_red_light`` is True when class 1 strictly outscores class 0.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    batch = preprocess(image).unsqueeze(0).to(device)  # batch of one image

    with torch.no_grad():
        class_probs = model(batch)[0]

    no_red_light_prob = class_probs[0].item()
    red_light_prob = class_probs[1].item()
    return red_light_prob > no_red_light_prob, red_light_prob, no_red_light_prob
def is_point_in_polygon(point, polygon):
    """Return True if ``point`` lies inside ``polygon`` (even-odd ray-casting test).

    A horizontal ray is cast from ``point`` towards +x. The polygon edges it crosses
    are counted, and an odd count means inside. Works for convex and concave simple
    polygons. Points lying exactly on an edge may be classified either way.

    Args:
        point: ``(x, y)`` pair.
        polygon: ordered sequence of ``(x, y)`` vertices. The closing edge from the
            last vertex back to the first is implied. Lists and NumPy arrays both work.

    Returns:
        ``True`` if the point is inside, else ``False``. An empty polygon contains
        nothing and yields ``False`` instead of raising.
    """
    x, y = point
    # len() rather than truthiness so NumPy arrays are accepted as well.
    if len(polygon) == 0:
        return False

    inside = False
    # Begin with the closing edge (last vertex -> first) so each edge is visited once.
    prev_x, prev_y = polygon[-1]
    for cur_x, cur_y in polygon:
        # Half-open y-range (lower end excluded, upper included): a ray passing through
        # a shared vertex is counted once, and horizontal edges never match.
        if min(prev_y, cur_y) < y <= max(prev_y, cur_y) and x <= max(prev_x, cur_x):
            if prev_x == cur_x:
                crosses = True  # vertical edge at or to the right of the point
            else:
                x_cross = (y - prev_y) * (cur_x - prev_x) / (cur_y - prev_y) + prev_x
                crosses = x <= x_cross
            if crosses:
                inside = not inside
        prev_x, prev_y = cur_x, cur_y
    return inside
def is_bbox_in_area(bbox, protection_area, image_shape):
    """Return True if the centre of ``bbox`` falls inside ``protection_area``.

    Args:
        bbox: box indexed as ``(x1, y1, x2, y2)``.
        protection_area: polygon vertices, passed straight to ``is_point_in_polygon``.
        image_shape: accepted for interface compatibility; not used.
    """
    x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
    centre = ((x1 + x2) / 2, (y1 + y2) / 2)
    return is_point_in_polygon(centre, protection_area)
def put_text_with_background(img, text, position, font_scale=0.8, thickness=2,
                             font=cv2.FONT_HERSHEY_SIMPLEX, text_color=(255, 255, 255),
                             bg_color=(0, 0, 0), padding=5):
    """Draw ``text`` on ``img`` in place, on top of a filled background box.

    Args:
        img: OpenCV image array, modified in place.
        text: label to render.
        position: ``(x, y)``. ``x`` is the left edge of the background box (the text
            starts ``padding`` px to its right). ``y`` is the text's bottom-left y
            as used by ``cv2.putText``.
        font_scale, thickness, font: forwarded to ``cv2.getTextSize``/``cv2.putText``.
        text_color, bg_color: colours in OpenCV channel order (default white on black).
        padding: pixels of background around the text.

    Two adjustments keep the label readable:
    - ``y`` is raised to at least ``text_height + padding``, so a label anchored
      just above a box that touches the top border is not drawn off-image.
    - The box extends below the baseline far enough to cover descenders.
    """
    (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, thickness)
    x = int(position[0])
    y = max(int(position[1]), text_height + padding)

    top_left = (x, y - text_height - padding)
    bottom_right = (x + text_width + padding * 2, y + max(baseline, padding))
    cv2.rectangle(img, top_left, bottom_right, bg_color, -1)
    cv2.putText(img, text, (x + padding, y), font, font_scale, text_color, thickness)
def calculate_iou(box1, box2):
    """Return the Intersection-over-Union of two axis-aligned boxes.

    Boxes are indexed as ``(x1, y1, x2, y2)`` and are expected to satisfy
    ``x1 <= x2`` and ``y1 <= y2``. Tuples, lists and NumPy arrays all work.
    Returns 0 when the union area is not positive (e.g. two zero-area boxes).
    """
    overlap_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
    overlap_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
    overlap = overlap_w * overlap_h

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    combined = _area(box1) + _area(box2) - overlap
    if combined > 0:
        return overlap / combined
    return 0
def merge_overlapping_detections(detections, iou_threshold=0.5):
    """Greedy per-class non-maximum suppression.

    Detections are visited from highest to lowest ``'confidence'``. Each one is kept
    unless an already-kept detection of the same ``'class'`` overlaps its ``'bbox'``
    with IoU >= ``iou_threshold``. The input list itself is not modified.

    Returns:
        The surviving detection dicts, highest confidence first (``[]`` for empty
        or ``None`` input).
    """
    if not detections:
        return []

    survivors = []
    for candidate in sorted(detections, key=lambda d: d['confidence'], reverse=True):
        suppressed = any(
            kept['class'] == candidate['class']
            and calculate_iou(candidate['bbox'], kept['bbox']) >= iou_threshold
            for kept in survivors
        )
        if not suppressed:
            survivors.append(candidate)
    return survivors
def _draw_protection_points(image, points):
    """Draw the protection-area outline and numbered vertex markers onto ``image`` in place.

    The outline is closed only once four vertices exist. Does nothing for no points.
    """
    if len(points) == 0:
        return
    vertices = np.array(points, dtype=np.int32)  # OpenCV drawing needs int32 points
    cv2.polylines(image, [vertices], len(vertices) == 4, (0, 255, 0), 2)
    for number, (px, py) in enumerate(vertices, start=1):
        px, py = int(px), int(py)
        cv2.circle(image, (px, py), 5, (0, 0, 255), -1)
        cv2.putText(image, str(number), (px + 10, py + 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)


def _detect_objects_in_area(yolo_model, image, points):
    """Run YOLO on ``image`` and return allowed-class detections centred in the protection area.

    Returns a list of ``{'class', 'confidence', 'bbox'}`` dicts. Overlapping
    same-class boxes are merged (IoU >= 0.5).
    """
    detections = []
    for result in yolo_model(image, conf=0.25):
        if result.boxes is None:
            continue
        for box in result.boxes:
            class_name = yolo_model.names[int(box.cls[0])]
            if class_name not in ALLOWED_CLASSES:
                continue
            bbox = box.xyxy[0].cpu().numpy()
            if is_bbox_in_area(bbox, points, image.shape):
                detections.append({
                    'class': class_name,
                    'confidence': float(box.conf[0]),
                    'bbox': bbox,
                })
    return merge_overlapping_detections(detections, iou_threshold=0.5)


def _run_detection(image, cv_image, points):
    """Classify the red light and, if present, report objects inside the protection area.

    Shows an error and returns early when the classifier could not be loaded.
    ``cv_image`` (BGR) is annotated in place and displayed.
    """
    device, model, yolo_model = initialize_models()
    if device is None or model is None:
        st.error("Failed to initialize models. Please check the error messages above.")
        return

    is_red_light, red_light_prob, no_red_light_prob = process_image(image, model, device)
    st.write("\n🔥 Red Light Detection Results:")
    st.write(f"Red Light Detected: {is_red_light}")
    st.write(f"Red Light Probability: {red_light_prob:.2%}")
    st.write(f"No Red Light Probability: {no_red_light_prob:.2%}")

    if not is_red_light:
        status_text = f"Red Light: NOT DETECTED ({red_light_prob:.1%})"
        put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
        st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
        return

    status_text = f"Red Light: DETECTED ({red_light_prob:.1%})"
    if yolo_model is None:
        # A red light WAS detected; report that rather than a "NOT DETECTED" banner.
        # initialize_models() has already shown why the detector is unavailable.
        st.warning("Object detection is unavailable, so the protection area was not checked.")
        put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
        st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
        return

    # Detect on the clean frame; drawing the green outline first would feed
    # those pixels to YOLO.
    detection_results = _detect_objects_in_area(yolo_model, cv_image, points)

    # cv2.polylines requires int32 vertices; np.array of Python ints is usually int64.
    cv2.polylines(cv_image, [np.array(points, dtype=np.int32)], True, (0, 255, 0), 2)
    for det in detection_results:
        bbox = det['bbox']
        x1, y1, x2, y2 = int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
        cv2.rectangle(cv_image, (x1, y1), (x2, y2), (0, 0, 255), 2)
        text = f"{det['class']}: {det['confidence']:.2%}"
        put_text_with_background(cv_image, text, (x1, y1 - 10))

    put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
    count_text = f"Objects in Protection Area: {len(detection_results)}"
    put_text_with_background(cv_image, count_text, (10, 70), font_scale=0.8)
    st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))

    if detection_results:
        st.write("\n🎯 Detected Objects in Protection Area:")
        for i, det in enumerate(detection_results, 1):
            st.write(f"\nObject {i}:")
            st.write(f"- Class: {det['class']}")
            st.write(f"- Confidence: {det['confidence']:.2%}")
    else:
        st.write("\nNo objects detected in protection area")


def main():
    """Render the page: upload an image, click a 4-point protection area, then run detection."""
    st.title("Traffic Light Detection with Protection Area")

    # Session state survives Streamlit reruns, so clicked points accumulate across clicks.
    if 'points' not in st.session_state:
        st.session_state.points = []
    if 'processing_done' not in st.session_state:
        st.session_state.processing_done = False

    uploaded_file = st.file_uploader("Choose an image", type=['jpg', 'jpeg', 'png'])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file).convert('RGB')
    cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)  # OpenCV works in BGR
    height, width = cv_image.shape[:2]
    draw_image = cv_image.copy()  # annotated preview; cv_image stays clean for detection

    st.write("👆 Click directly on the image to add points for the protection area (need 4 points)")
    st.write("🔄 Click 'Reset Points' to start over")

    if st.button('Reset Points'):
        st.session_state.points = []
        st.session_state.processing_done = False
        st.rerun()

    _draw_protection_points(draw_image, st.session_state.points)

    col1, col2 = st.columns([4, 1])
    with col1:
        if len(st.session_state.points) < 4 and not st.session_state.processing_done:
            # Keyed on the point count, presumably so each new point gets a fresh widget
            # and a previous click value is not re-added on rerun.
            clicked = streamlit_image_coordinates(
                cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB),
                key=f"image_coordinates_{len(st.session_state.points)}",
            )
            if clicked is not None and clicked.get('x') is not None and clicked.get('y') is not None:
                x, y = clicked['x'], clicked['y']
                if 0 <= x < width and 0 <= y < height:
                    st.session_state.points = st.session_state.points + [[x, y]]
                    # st.rerun() ends this run at once; the next run draws the new point.
                    st.rerun()
        else:
            st.image(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB), use_column_width=True)

    with col2:
        st.write(f"Points: {len(st.session_state.points)}/4")
        if st.session_state.points:
            st.write("Current Points:")
            for i, point in enumerate(st.session_state.points):
                st.write(f"Point {i+1}: ({point[0]}, {point[1]})")
            if st.button("Remove Last Point"):
                st.session_state.points.pop()
                # Re-enable point clicking; otherwise, after a run, only Reset would recover.
                st.session_state.processing_done = False
                st.rerun()

    if len(st.session_state.points) == 4 and not st.session_state.processing_done:
        st.write("✅ Protection area defined! Click 'Process Detection' to continue.")
        if st.button('Process Detection', type='primary'):
            st.session_state.processing_done = True
            _run_detection(image, cv_image, st.session_state.points)
# Script entry point. This file is app.py, presumably launched with `streamlit run app.py`.
if __name__ == "__main__":
    main()