import streamlit as st
from ultralytics import YOLO
import cv2
import numpy as np
import torch
import tempfile
import warnings

warnings.filterwarnings('ignore')


def get_direction(old_center, new_center, min_movement=10):
    """Classify movement between two box centers as left/right/up/down/stationary."""
    if old_center is None or new_center is None:
        return "stationary"
    dx = new_center[0] - old_center[0]
    dy = new_center[1] - old_center[1]
    # Ignore jitter smaller than min_movement pixels on both axes
    if abs(dx) < min_movement and abs(dy) < min_movement:
        return "stationary"
    # Report the dominant axis of motion
    if abs(dx) > abs(dy):
        return "right" if dx > 0 else "left"
    return "down" if dy > 0 else "up"


class ObjectTracker:
    """Greedy nearest-neighbor tracker: each detection is matched to the closest
    previous box of the same class within 100 px, otherwise a new track is made."""

    def __init__(self):
        self.tracked_objects = {}
        self.object_count = {}

    def update(self, detections):
        current_objects = {}
        results = []
        for detection in detections:
            x1, y1, x2, y2 = detection[0:4]
            center = ((x1 + x2) // 2, (y1 + y2) // 2)
            class_id = detection[5]
            # Provisional id for a brand-new object of this class
            object_id = f"{class_id}_{len(self.object_count.get(class_id, []))}"

            # Find the closest existing track of the same class
            min_dist = float('inf')
            closest_id = None
            for prev_id, prev_data in self.tracked_objects.items():
                if prev_id.split('_')[0] == str(class_id):
                    dist = np.sqrt((center[0] - prev_data['center'][0]) ** 2 +
                                   (center[1] - prev_data['center'][1]) ** 2)
                    if dist < min_dist and dist < 100:
                        min_dist = dist
                        closest_id = prev_id

            if closest_id:
                object_id = closest_id
            else:
                # Register a new track
                if class_id not in self.object_count:
                    self.object_count[class_id] = []
                self.object_count[class_id].append(object_id)

            prev_center = self.tracked_objects.get(object_id, {}).get('center', None)
            direction = get_direction(prev_center, center)

            current_objects[object_id] = {
                'center': center,
                'direction': direction,
                'detection': detection
            }
            results.append((detection, object_id, direction))

        self.tracked_objects = current_objects
        return results


def main():
    st.title("Real-time Object Detection with Direction")

    # File uploader for video
    uploaded_file = st.file_uploader("Choose a video file", type=['mp4', 'avi', 'mov'])

    # Start/stop buttons
    start_detection = st.button("Start Detection")
    stop_detection = st.button("Stop Detection")

    if uploaded_file is not None and start_detection:
        # (Re)arm the running flag on every Start click; initializing it only when
        # missing would leave it False after a previous stop.
        st.session_state.running = True

        # Save uploaded file temporarily; close it so cv2 can open it on all platforms
        tfile = tempfile.NamedTemporaryFile(delete=False)
        tfile.write(uploaded_file.read())
        tfile.close()

        # Load model
        with st.spinner('Loading model...'):
            model = YOLO('yolov8x.pt', verbose=False)
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            model.to(device)

        tracker = ObjectTracker()
        cap = cv2.VideoCapture(tfile.name)

        # BGR colors keyed by movement direction
        direction_colors = {
            "left": (255, 0, 0),
            "right": (0, 255, 0),
            "up": (0, 255, 255),
            "down": (0, 0, 255),
            "stationary": (128, 128, 128)
        }

        # Placeholders for the video frame and the detection summary
        frame_placeholder = st.empty()
        info_placeholder = st.empty()

        st.success("Detection Started!")

        while cap.isOpened() and st.session_state.running:
            success, frame = cap.read()
            if not success:
                break

            # Run detection
            results = model(frame, conf=0.25, iou=0.45, max_det=20, verbose=False)[0]

            detections = []
            for box in results.boxes.data:
                x1, y1, x2, y2, conf, cls = box.tolist()
                detections.append([int(x1), int(y1), int(x2), int(y2),
                                   float(conf), int(cls)])

            tracked_objects = tracker.update(detections)

            # Per-frame detection counts by class
            detection_counts = {}

            for detection, obj_id, direction in tracked_objects:
                x1, y1, x2, y2, conf, cls = detection
                color = direction_colors.get(direction, (128, 128, 128))
                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)

                label = f"{model.names[int(cls)]} {direction} {conf:.2f}"

                # Large font and thick strokes so labels stay readable in the browser
                font_scale = 1.2
                thickness = 3
                text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX,
                                            font_scale, thickness)[0]

                # Filled background behind the label, with extra vertical padding
                padding_y = 15
                cv2.rectangle(frame,
                              (int(x1), int(y1) - text_size[1] - padding_y),
                              (int(x1) + text_size[0], int(y1)),
                              color, -1)
                cv2.putText(frame, label, (int(x1), int(y1) - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, font_scale,
                            (255, 255, 255), thickness)

                # Count detections by class
                class_name = model.names[int(cls)]
                detection_counts[class_name] = detection_counts.get(class_name, 0) + 1

            # OpenCV frames are BGR; Streamlit expects RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame_placeholder.image(frame_rgb, channels="RGB", use_column_width=True)

            # Update detection info
            info_text = "Detected Objects:\n"
            for class_name, count in detection_counts.items():
                info_text += f"{class_name}: {count}\n"
            info_placeholder.text(info_text)

            # Note: stop_detection is fixed for the duration of this script run;
            # in practice clicking Stop interrupts the loop by triggering a
            # Streamlit rerun, which leaves start_detection False.
            if stop_detection:
                st.session_state.running = False
                break

        cap.release()
        st.session_state.running = False
        st.warning("Detection Stopped")

    elif uploaded_file is None and start_detection:
        st.error("Please upload a video file first!")


if __name__ == "__main__":
    main()