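"""Streamlit app for real-time object tracking with YOLOv8.

Supports two modes: a live WebRTC camera stream and an uploaded video file.
"""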
import cv2
import streamlit as st
import numpy as np
import tempfile
import os
from ultralytics import YOLO
from streamlit_webrtc import webrtc_streamer
import av
from turn import get_ice_servers
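# `turn` is a local helper module (not included in this file); get_ice_servers()
# is assumed to return a list of ICE server dicts for streamlit-webrtc, e.g.
# [{"urls": ["stun:stun.l.google.com:19302"]}].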
# Load the pretrained YOLOv8 nano model (weights are downloaded on first use)
model = YOLO("yolov8n.pt")

# Module-level state shared with the WebRTC frame callback
FRAME_SKIP = 5       # Run detection on every 5th frame
cached_frame = None  # Latest annotated frame, reused for skipped frames
frame_skip = FRAME_SKIP  # Countdown until the next processed frame
# Frame callback invoked by streamlit-webrtc for each incoming video frame
def recv(frame: av.VideoFrame) -> av.VideoFrame:
    global frame_skip, cached_frame

    # Convert the frame to OpenCV's BGR format
    frame_bgr = frame.to_ndarray(format="bgr24")

    # Downscale (160x120 instead of 640x480) to reduce processing time
    frame_resized = cv2.resize(frame_bgr, (160, 120))

    # Run detection only on every FRAME_SKIP-th frame; reuse the cached
    # annotated frame in between to keep the stream responsive
    if frame_skip == 0:
        # Reset the frame-skip counter
        frame_skip = FRAME_SKIP
        # Detect and track objects using YOLOv8
        results = model.track(frame_resized, persist=True)
        # Draw bounding boxes and track IDs
        frame_annotated = results[0].plot()
        # Cache the annotated frame for the skipped frames that follow
        cached_frame = frame_annotated
    else:
        # Use the cached frame for skipped frames
        frame_annotated = cached_frame if cached_frame is not None else frame_resized
        frame_skip -= 1

    # Convert back to RGB for the outgoing WebRTC frame
    frame_rgb = cv2.cvtColor(frame_annotated, cv2.COLOR_BGR2RGB)
    return av.VideoFrame.from_ndarray(frame_rgb, format="rgb24")
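# Note: streamlit-webrtc invokes video_frame_callback on a worker thread,
# separate from the Streamlit script run, so the module-level counters above
# are a simple demo-level approach rather than a thread-safe design.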
# Streamlit web app
def main():
    # Set page title
    st.set_page_config(page_title="Object Tracking with Streamlit")
    st.title("Object Tracking")

    # Radio button for user selection
    option = st.radio("Choose an option:", ("Live Stream", "Upload Video"))

    if option == "Live Stream":
        # Start the WebRTC stream with the object-tracking frame callback
        webrtc_streamer(
            key="live-stream",
            video_frame_callback=recv,
            rtc_configuration={"iceServers": get_ice_servers()},
            media_stream_constraints={"video": True, "audio": False},
            async_processing=True,
        )
    elif option == "Upload Video":
        # File uploader for video upload
        uploaded_file = st.file_uploader("Upload a video file", type=["mp4", "avi", "mov"])
        # Button to start tracking
        start_button_pressed = st.button("Start Tracking")
        # Placeholder for the video frames
        frame_placeholder = st.empty()
        # Button to stop tracking
        stop_button_pressed = st.button("Stop")

        # Start tracking once a file is uploaded and the button is pressed
        if start_button_pressed and uploaded_file is not None:
            track_uploaded_video(uploaded_file, stop_button_pressed, frame_placeholder)

        # Release resources
        if uploaded_file:
            uploaded_file.close()
# Function to perform object tracking on an uploaded video
def track_uploaded_video(video_file, stop_button, frame_placeholder):
    # Save the uploaded video to a temporary file so OpenCV can read it
    temp_video = tempfile.NamedTemporaryFile(delete=False)
    temp_video.write(video_file.read())
    temp_video.close()

    # OpenCV VideoCapture for reading the video file
    cap = cv2.VideoCapture(temp_video.name)
    frame_count = 0

    while cap.isOpened() and not stop_button:
        ret, frame = cap.read()
        if not ret:
            st.write("The video capture has ended.")
            break

        # Process every 5th frame to reduce load
        if frame_count % 5 == 0:
            # Resize frame to reduce processing time
            frame_resized = cv2.resize(frame, (640, 480))
            # Detect and track objects using YOLOv8
            results = model.track(frame_resized, persist=True)
            # Draw bounding boxes and track IDs
            frame_annotated = results[0].plot()
            # Display the annotated frame
            frame_placeholder.image(frame_annotated, channels="BGR")

        frame_count += 1

    # Release resources and remove the temporary file
    cap.release()
    os.remove(temp_video.name)
# Run the app
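# Launch with Streamlit from a terminal, e.g.:
#   streamlit run app.py   # "app.py" stands in for this file's actual name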
if __name__ == "__main__":
    main()