import cv2
import streamlit as st
import tempfile
import os
import av
from ultralytics import YOLO
from streamlit_webrtc import webrtc_streamer

from turn import get_ice_servers
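# NOTE: `turn` is a local helper module (not shown here). Following the
# streamlit-webrtc sample apps, get_ice_servers() is assumed to return a list
# of ICE server dicts; a hypothetical minimal fallback might look like:
#
#     def get_ice_servers():
#         # Public STUN server; swap in TURN credentials for strict NATs
#         return [{"urls": ["stun:stun.l.google.com:19302"]}]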
# Load the YOLOv8 model
model = YOLO('yolov8n.pt')
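# 'yolov8n.pt' is the nano (smallest) YOLOv8 checkpoint; ultralytics downloads
# it automatically on first use if it is not already cached locally.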
# Global state shared with the frame callback
cached_frame = None  # Latest annotated frame, reused while frames are skipped
frame_skip = 5       # Number of frames to skip between detections
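# Caveat: module-level globals are shared across sessions and the webrtc worker
# thread; for multiple concurrent viewers, per-session state (e.g. a
# VideoProcessorBase subclass holding its own counter and cache) would be safer.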
def recv(frame: av.VideoFrame) -> av.VideoFrame:
    """Frame callback for webrtc_streamer: annotates every nth frame with YOLOv8."""
    global frame_skip, cached_frame
    # Convert the incoming frame to OpenCV's BGR format
    frame_bgr = frame.to_ndarray(format="bgr24")
    # Resize the frame to reduce processing time
    frame_resized = cv2.resize(frame_bgr, (160, 120))  # Instead of 640x480
    # Run detection only on every nth frame; otherwise reuse the cached result
    if frame_skip == 0:
        # Reset the skip counter
        frame_skip = 5
        # Detect and track objects using YOLOv8 (persist=True keeps track IDs
        # across frames)
        results = model.track(frame_resized, persist=True)
        # Draw bounding boxes and track IDs onto the frame
        frame_annotated = results[0].plot()
        # Cache the annotated frame for the skipped frames that follow
        cached_frame = frame_annotated
    else:
        # Use the cached frame for skipped frames (fall back to the raw frame
        # until the first detection has run)
        frame_annotated = cached_frame if cached_frame is not None else frame_resized
        frame_skip -= 1
    # Convert back to RGB for the outgoing WebRTC frame
    frame_rgb = cv2.cvtColor(frame_annotated, cv2.COLOR_BGR2RGB)
    return av.VideoFrame.from_ndarray(frame_rgb, format="rgb24")
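# Design note: frame skipping trades detection freshness for throughput; the
# stream still renders at full rate because skipped frames reuse the cached
# annotated image, and track IDs survive the gaps since model.track() is
# called with persist=True on every processed frame.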
# Streamlit web app
def main():
    # Set the page title
    st.set_page_config(page_title="Object Tracking with Streamlit")
    st.title("Object Tracking")
    # Radio button to choose between a live stream and an uploaded video
    option = st.radio("Choose an option:", ("Live Stream", "Upload Video"))
    if option == "Live Stream":
        # Start the WebRTC stream with object tracking
        webrtc_streamer(
            key="live-stream",
            video_frame_callback=recv,
            rtc_configuration={"iceServers": get_ice_servers()},
            media_stream_constraints={"video": True, "audio": False},
            async_processing=True,
        )
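        # async_processing=True runs the frame callback asynchronously, so slow
        # inference degrades to dropped/reused frames rather than stalling the
        # connection; the ICE servers let peers connect across NATs/firewalls.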
    elif option == "Upload Video":
        # File uploader for the video
        uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
        # Button to start tracking
        start_button_pressed = st.button("Start Tracking")
        # Placeholder that the processed frames are rendered into
        frame_placeholder = st.empty()
        # Button to stop tracking
        stop_button_pressed = st.button("Stop")
        # Run tracking once a file has been uploaded and Start is pressed
        if start_button_pressed and uploaded_file is not None:
            track_uploaded_video(uploaded_file, stop_button_pressed, frame_placeholder)
        # Release the uploaded file handle
        if uploaded_file:
            uploaded_file.close()
# Function to perform object tracking on an uploaded video
def track_uploaded_video(video_file, stop_button, frame_placeholder):
    # Save the uploaded video to a temporary file so OpenCV can read it by name
    temp_video = tempfile.NamedTemporaryFile(delete=False)
    temp_video.write(video_file.read())
    temp_video.close()
    # OpenCV's VideoCapture for reading the video file
    cap = cv2.VideoCapture(temp_video.name)
    frame_count = 0
    # Note: stop_button holds the state captured at call time; pressing Stop
    # still works because it triggers a Streamlit rerun, interrupting this loop.
    while cap.isOpened() and not stop_button:
        ret, frame = cap.read()
        if not ret:
            st.write("The video capture has ended.")
            break
        # Process every 5th frame
        if frame_count % 5 == 0:
            # Resize the frame to reduce processing time
            frame_resized = cv2.resize(frame, (640, 480))
            # Detect and track objects using YOLOv8
            results = model.track(frame_resized, persist=True)
            # Draw bounding boxes and track IDs
            frame_annotated = results[0].plot()
            # Display the annotated frame (plot() returns BGR)
            frame_placeholder.image(frame_annotated, channels="BGR")
        frame_count += 1
    # Release resources
    cap.release()
    # Remove the temporary file
    os.remove(temp_video.name)
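# Design note: NamedTemporaryFile(delete=False) keeps the file on disk so that
# cv2.VideoCapture can open it by path (reopening a still-open temp file fails
# on some platforms, notably Windows); it is removed manually afterwards.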
# Run the app
if __name__ == "__main__":
    main()