|
|
|
|
|
""" |
|
------------------------------------------------- |
|
@File Name: utils.py |
|
@Author: Luyao.zhang |
|
@Date: 2023/5/16 |
|
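@Description: Streamlit helpers that run YOLOv8 object detection on uploaded images, uploaded videos and webcam frames.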
|
------------------------------------------------- |
|
""" |
|
import os
import tempfile
from collections import Counter

import cv2
import streamlit as st
from PIL import Image
from ultralytics import YOLO
|
|
|
|
|
def _display_detected_frames(conf, model, st_frame, image):
    """
    Run YOLOv8 detection on one video frame and render the annotated result.

    The frame is resized to a fixed 720x405 (16:9) resolution before it is
    passed to ``model.predict``. The plotted first result is then drawn into
    ``st_frame`` as a BGR image.

    :param conf (float): Confidence threshold for object detection.
    :param model (YOLOv8): An instance of the `YOLOv8` class containing the YOLOv8 model.
    :param st_frame (Streamlit object): Placeholder the annotated frame is written into.
    :param image (numpy array): The video frame, as returned by ``cv2.VideoCapture.read``.
    :return: None
    """
    target_width = 720
    target_height = int(target_width * (9 / 16))
    resized = cv2.resize(image, (target_width, target_height))

    annotated = model.predict(resized, conf=conf)[0].plot()

    # The annotated frame is passed through with channels="BGR" so Streamlit
    # interprets the OpenCV channel order correctly.
    st_frame.image(
        annotated,
        caption='Detected Video',
        channels="BGR",
        use_column_width=True,
    )
|
|
|
|
|
@st.cache_resource
def load_model(model_path):
    """
    Construct a YOLO object detection model from ``model_path``.

    The ``st.cache_resource`` decorator lets Streamlit reuse the constructed
    model across reruns for the same path instead of rebuilding it.

    Parameters:
        model_path (str): The path to the YOLO model file.

    Returns:
        A YOLO object detection model.
    """
    return YOLO(model_path)
|
|
|
|
|
def infer_uploaded_image(conf, model):
    """
    Execute inference for an uploaded image.

    Shows a sidebar image uploader. Once a file is uploaded, it is shown in
    the left column. Clicking "Detect" runs ``model.predict`` on it. The
    annotated image then appears in the right column, with per-class
    detection counts in an expander.

    :param conf: Confidence threshold passed to ``model.predict``.
    :param model: An instance of the `YOLOv8` class containing the YOLOv8 model.
    :return: None
    """
    source_img = st.sidebar.file_uploader(
        label="Choose an image...",
        type=("jpg", "jpeg", "png", 'bmp', 'webp')
    )

    col1, col2 = st.columns(2)

    # Nothing to preview or detect until the user uploads a file.
    if not source_img:
        return

    with col1:
        uploaded_image = Image.open(source_img)
        st.image(
            image=source_img,
            caption="Uploaded Image",
            use_column_width=True
        )

    if not st.button("Detect"):
        return

    with st.spinner("Running..."):
        result = model.predict(uploaded_image, conf=conf)[0]

        # Class ids come back as floats; cast to int before looking them up in
        # the id -> name mapping. Counter keeps first-seen order.
        class_ids = result.boxes.cls.numpy(force=True)
        ndetected = Counter(result.names[int(cid)] for cid in class_ids)

        # Reverse the channel axis of the plotted array before passing it to
        # st.image (presumably BGR -> RGB; confirm against Ultralytics docs).
        res_plotted = result.plot()[:, :, ::-1]

        with col2:
            # The caption reports distinct classes, so count keys, not boxes.
            st.image(res_plotted,
                     caption="{} classes detected".format(len(ndetected)),
                     use_column_width=True)
            try:
                with st.expander("Detection Results"):
                    for cname, count in ndetected.items():
                        st.write("{} x{}".format(cname, count))
            except Exception as ex:
                st.write("No image is uploaded yet!")
                st.write(ex)
|
|
|
|
|
def infer_uploaded_video(conf, model):
    """
    Execute inference for an uploaded video.

    Shows a sidebar uploader and previews the uploaded video. Clicking
    "Detect" does two things. First, it writes the upload to a temporary
    file, because ``cv2.VideoCapture`` is opened by path here. Second, it
    runs detection frame by frame, rendering each annotated frame into a
    single placeholder.

    :param conf: Confidence threshold passed to ``model.predict``.
    :param model: An instance of the `YOLOv8` class containing the YOLOv8 model.
    :return: None
    """
    source_video = st.sidebar.file_uploader(
        label="Choose a video..."
    )

    if not source_video:
        return

    st.video(source_video)

    if not st.button("Detect"):
        return

    with st.spinner("Running..."):
        tmp_path = None
        vid_cap = None
        try:
            # Keep the upload's extension as a hint for the video backend.
            # delete=False lets the file be closed (and thus fully flushed to
            # disk) before OpenCV reopens it by name, which also works on
            # platforms that refuse to reopen a still-open temp file.
            suffix = os.path.splitext(getattr(source_video, "name", "") or "")[1]
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tfile:
                tmp_path = tfile.name
                # getvalue() returns the whole upload regardless of the current
                # read position (in case an earlier consumer such as st.video
                # moved it).
                tfile.write(source_video.getvalue())

            vid_cap = cv2.VideoCapture(tmp_path)
            if not vid_cap.isOpened():
                st.error("Error loading video: could not open the uploaded file")
                return

            st_frame = st.empty()
            while vid_cap.isOpened():
                success, image = vid_cap.read()
                if not success:
                    break
                _display_detected_frames(conf, model, st_frame, image)
        except Exception as e:
            st.error(f"Error loading video: {e}")
        finally:
            # Release the capture and delete the temp file on every exit path.
            if vid_cap is not None:
                vid_cap.release()
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    # Best-effort cleanup; a leftover temp file is not fatal.
                    pass
|
|
|
|
|
def infer_uploaded_webcam(conf, model):
    """
    Execute inference for webcam.

    Opens camera device 0 and streams annotated frames into a single
    placeholder until a frame cannot be read.

    The "Stop running" button's value is read once per script run. If it is
    True, the camera is not opened at all. The loop itself never re-reads
    it; stopping presumably relies on Streamlit rerunning the script when the
    button is clicked.

    :param conf: Confidence threshold passed to ``model.predict``.
    :param model: An instance of the `YOLOv8` class containing the YOLOv8 model.
    :return: None
    """
    vid_cap = None
    try:
        flag = st.button(
            label="Stop running"
        )
        if flag:
            # Stop was requested: don't grab the camera just to leak it.
            return

        vid_cap = cv2.VideoCapture(0)
        st_frame = st.empty()
        while True:
            success, image = vid_cap.read()
            if not success:
                break
            _display_detected_frames(conf, model, st_frame, image)
    except Exception as e:
        st.error(f"Error loading video: {str(e)}")
    finally:
        # Release the camera on every exit path (read failure, exception, or
        # an interrupted run) so the device is not held until GC.
        if vid_cap is not None:
            vid_cap.release()
|
|