# Object-tracking demo: Streamlit + YOLOv8 + streamlit-webrtc (live stream or uploaded video).
import cv2
import streamlit as st
import numpy as np
import tempfile
import os
from ultralytics import YOLO
from streamlit_webrtc import (webrtc_streamer, VideoProcessorBase, WebRtcMode, RTCConfiguration)
import av
from turn import get_ice_servers
# Load the YOLOv8 nano model once at import time; shared by the WebRTC
# callback and the uploaded-video tracker.
model = YOLO('yolov8n.pt')

# Most recent annotated frame, reused while detection is being skipped
# (written by recv()).
cached_frame = None

# Countdown until the next frame is run through the detector.
frame_skip = 5  # Process every 5th frame
def recv(frame: av.VideoFrame) -> av.VideoFrame:
    """Per-frame WebRTC callback: run YOLOv8 tracking on every Nth frame.

    Skipped frames reuse the most recently annotated frame so the stream
    stays smooth while keeping the detection workload low.

    Args:
        frame: Incoming video frame from the browser.

    Returns:
        An RGB av.VideoFrame with detection boxes drawn (or the raw
        downscaled frame until the first detection has run).
    """
    # frame_skip and cached_frame are module-level state shared across
    # calls. Without this declaration the assignments below would make
    # them locals and reading frame_skip would raise UnboundLocalError.
    global frame_skip, cached_frame

    # Convert the incoming frame to an OpenCV-style BGR ndarray.
    frame_bgr = frame.to_ndarray(format="bgr24")

    # Downscale aggressively to keep inference cheap (instead of 640x480).
    frame_resized = cv2.resize(frame_bgr, (160, 120))

    # Run detection only when the skip counter has hit zero.
    if frame_skip == 0:
        # Reset the counter: the next 10 frames reuse this annotation.
        frame_skip = 10
        # persist=True keeps track IDs stable across frames.
        results = model.track(frame_resized, persist=True)
        frame_annotated = results[0].plot()
        # Cache so skipped frames can reuse the annotation.
        cached_frame = frame_annotated
    else:
        # Reuse the cached annotation; fall back to the raw resized frame
        # before the first detection has happened.
        frame_annotated = cached_frame if cached_frame is not None else frame_resized
        frame_skip -= 1

    # Convert BGR back to RGB for the outgoing VideoFrame.
    frame_rgb = cv2.cvtColor(frame_annotated, cv2.COLOR_BGR2RGB)
    return av.VideoFrame.from_ndarray(frame_rgb, format="rgb24")
# Streamlit web app
def main():
    """Streamlit entry point: choose live WebRTC tracking or uploaded-video tracking."""
    st.set_page_config(page_title="Object Tracking with Streamlit")
    st.title("Object Tracking")

    # Radio button for mode selection.
    option = st.radio("Choose an option:", ("Live Stream", "Upload Video"))

    if option == "Live Stream":
        # Stream from the browser camera; recv() annotates each frame.
        # ICE servers come from the project's turn helper (STUN/TURN config).
        webrtc_streamer(key="live-stream",
                        video_frame_callback=recv,
                        rtc_configuration={"iceServers": get_ice_servers()},
                        media_stream_constraints={"video": True, "audio": False},
                        async_processing=True)
    elif option == "Upload Video":
        # File uploader for the video to track.
        uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
        # Controls and a placeholder the tracker draws frames into.
        start_button_pressed = st.button("Start Tracking")
        frame_placeholder = st.empty()
        stop_button_pressed = st.button("Stop")

        # Only start once a file is present and the user pressed Start.
        if start_button_pressed and uploaded_file is not None:
            track_uploaded_video(uploaded_file, stop_button_pressed, frame_placeholder)

        # Release the upload's underlying buffer.
        if uploaded_file:
            uploaded_file.close()
# Function to perform object tracking on uploaded video
def track_uploaded_video(video_file, stop_button, frame_placeholder):
    """Run YOLOv8 tracking over an uploaded video, showing every 5th frame.

    Args:
        video_file: File-like upload whose bytes are a video container.
        stop_button: Truthy when the user pressed Stop (halts the loop).
        frame_placeholder: Streamlit placeholder to render frames into.
    """
    # cv2.VideoCapture needs a filesystem path, so persist the upload to a
    # temp file first (delete=False: we remove it ourselves below).
    temp_video = tempfile.NamedTemporaryFile(delete=False)
    try:
        temp_video.write(video_file.read())
        temp_video.close()

        cap = cv2.VideoCapture(temp_video.name)
        try:
            frame_count = 0
            while cap.isOpened() and not stop_button:
                ret, frame = cap.read()
                if not ret:
                    st.write("The video capture has ended.")
                    break
                # Detect/display only every 5th frame to limit CPU load.
                if frame_count % 5 == 0:
                    # Resize to a fixed working resolution before inference.
                    frame_resized = cv2.resize(frame, (640, 480))
                    # persist=True keeps track IDs stable across frames.
                    results = model.track(frame_resized, persist=True)
                    frame_ = results[0].plot()
                    # Annotated frame is BGR (OpenCV convention).
                    frame_placeholder.image(frame_, channels="BGR")
                frame_count += 1
        finally:
            # Always release the capture, even if inference/display raised.
            cap.release()
    finally:
        # Always remove the temp file so repeated runs don't leak disk.
        os.remove(temp_video.name)
# Launch the Streamlit app when executed directly (not when imported).
if __name__ == "__main__":
    main()