File size: 4,271 Bytes
6d35603 9f0b3a7 6907110 6d35603 ab634f0 6d35603 ab634f0 6d35603 ab634f0 6d35603 9f0b3a7 ab634f0 9f0b3a7 6d35603 9f0b3a7 6907110 9f0b3a7 6907110 6d35603 ab634f0 6d35603 ab634f0 6d35603 6907110 6d35603 6907110 ab634f0 6d35603 6907110 6d35603 6907110 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import streamlit as st
import cv2
import numpy as np
import onnxruntime as ort
from PIL import Image
import tempfile
# Build the ONNX Runtime session behind Streamlit's resource cache.
@st.cache_resource
def load_model():
    """Create and return the inference session for ``model.onnx``."""
    session = ort.InferenceSession("model.onnx")
    return session
# Module-level inference session; process_image reads it to run the model.
ort_session = load_model()
# Define class names and their corresponding indices.
# Keys are the integer class ids produced by postprocess_results; ids that are
# missing here are rendered as 'unknown' by process_image.
CLASS_NAMES = {0: 'car', 1: 'license_plate'}
def preprocess_image(image, target_size=(640, 640)):
    """Turn an image into the batched float tensor fed to the ONNX session.

    Args:
        image: A ``PIL.Image.Image`` (converted with ``np.array``) or an array
            that ``cv2.cvtColor`` accepts.
        target_size: Size tuple handed straight to ``cv2.resize``.

    Returns:
        A ``float32`` array with the channel axis moved first and a leading
        batch axis of length 1; pixel values are divided by 255.
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    # Swap the red and blue channels.
    # NOTE(review): the swap is applied to every input, so a BGR frame reaches
    # the model as RGB and an RGB array as BGR - confirm the expected order.
    swapped = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
    resized = cv2.resize(swapped, target_size)
    scaled = resized.astype(np.float32) / 255.0
    # HWC -> CHW, then prepend the batch axis.
    channels_first = np.transpose(scaled, (2, 0, 1))
    return channels_first[np.newaxis, ...]
def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_threshold=0.45):
    """Convert raw model output into pixel-space detections.

    Each prediction row is read as ``[x, y, w, h, score, class_id]`` with the
    box scaled by the image width/height.
    NOTE(review): this assumes normalised top-left/width/height boxes - confirm
    against the exported model's output layout.

    Args:
        output: Either the list/tuple returned by ``InferenceSession.run``
            (the first element is used) or the prediction array itself. Arrays
            with 3 or 4 dimensions have their leading singleton axes squeezed.
        image_shape: Shape of the image the boxes are mapped onto; index 0 is
            the height and index 1 the width.
        confidence_threshold: Rows with a score at or below this are dropped.
        iou_threshold: Overlap threshold passed to non-maximum suppression.

    Returns:
        A list of ``(x1, y1, x2, y2, score, class_id)`` tuples with integer
        corner coordinates, a float score and an integer class id. Empty when
        nothing passes the confidence threshold.

    Raises:
        ValueError: If ``output`` is not a list, tuple or ndarray, or if the
            predictions do not reduce to a 2-D array with at least 6 columns.
    """
    if isinstance(output, (list, tuple)):
        predictions = output[0]
    elif isinstance(output, np.ndarray):
        predictions = output
    else:
        raise ValueError(f"Unexpected output type: {type(output)}")

    predictions = np.asarray(predictions)
    if predictions.ndim == 4:
        predictions = predictions.squeeze((0, 1))
    elif predictions.ndim == 3:
        predictions = predictions.squeeze(0)
    if predictions.ndim != 2 or predictions.shape[1] < 6:
        raise ValueError(f"Unexpected prediction shape: {predictions.shape}")

    # Boolean indexing copies, so the caller's output array is never mutated.
    detections = predictions[predictions[:, 4] > confidence_threshold]
    if detections.shape[0] == 0:
        # Nothing survived the confidence filter; skip NMS entirely.
        return []

    img_h, img_w = image_shape[0], image_shape[1]
    scale = np.array([img_w, img_h, img_w, img_h], dtype=np.float32)
    xywh = detections[:, :4].astype(np.float32) * scale
    scores = detections[:, 4].astype(np.float32)
    class_ids = detections[:, 5]

    # cv2.dnn.NMSBoxes expects [x, y, width, height] rectangles, so NMS runs on
    # the xywh boxes and corners are derived only for the returned results.
    kept = cv2.dnn.NMSBoxes(xywh.tolist(), scores.tolist(), confidence_threshold, iou_threshold)
    # Older OpenCV builds return an Nx1 array, newer ones a flat sequence (or an
    # empty tuple); flatten so each entry is a plain index.
    kept = np.asarray(kept, dtype=int).reshape(-1)

    results = []
    for idx in kept:
        x, y, w, h = xywh[idx]
        results.append(
            (int(x), int(y), int(x + w), int(y + h), float(scores[idx]), int(class_ids[idx]))
        )
    return results
def process_image(image):
    """Run detection on one image array and return an annotated copy.

    The input is passed through ``preprocess_image``, the module-level ONNX
    session and ``postprocess_results``. Every detection is drawn on a copy of
    the input as a rectangle plus a ``"<class>: <score>"`` label, and the copy
    is returned after a ``COLOR_BGR2RGB`` conversion.

    Args:
        image: Image array exposing ``.copy()`` and ``.shape``.

    Returns:
        The annotated image with its channel order converted by
        ``cv2.COLOR_BGR2RGB``.
    """
    annotated = image.copy()
    tensor = preprocess_image(image)
    input_name = ort_session.get_inputs()[0].name
    raw_outputs = ort_session.run(None, {input_name: tensor})
    detections = postprocess_results(raw_outputs, image.shape)

    for x1, y1, x2, y2, score, class_id in detections:
        class_name = CLASS_NAMES.get(class_id, 'unknown')
        # Cars get one colour, every other class the second one.
        if class_name == 'car':
            color = (0, 255, 0)
        else:
            color = (255, 0, 0)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
        caption = f"{class_name}: {score:.2f}"
        # Label is anchored 10 px above the box's top-left corner.
        cv2.putText(annotated, caption, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
def process_video(video_path):
    """Annotate every frame of a video and write the result to a temp file.

    Each frame is run through ``process_image`` and written, converted back
    with ``COLOR_RGB2BGR``, to a new ``.mp4`` file encoded with the ``mp4v``
    fourcc.
    NOTE(review): some browsers cannot play ``mp4v`` streams - confirm that
    ``st.video`` renders the result in the target environment.

    Args:
        video_path: Path of the input video, as accepted by
            ``cv2.VideoCapture``.

    Returns:
        Path of the annotated video. The file is created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        ValueError: If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")

    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the fractional frame rate (e.g. 29.97) so the output duration
        # matches the input; fall back when the container reports 0 or NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0

        # Only the name is needed; close the handle right away so the writer
        # can open the path itself and no descriptor is leaked.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
            output_path = temp_file.name

        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame = process_image(frame)
            # process_image returns the BGR2RGB-converted frame; the writer
            # gets it converted back.
            out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
    finally:
        # Release native handles even if inference fails mid-video.
        cap.release()
        if out is not None:
            out.release()
    return output_path
# ---- Streamlit UI: upload an image or a video and run detection on demand ----
st.title("Vehicle and License Plate Detection")
uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])

if uploaded_file is not None:
    # The MIME type looks like "image/png" or "video/mp4"; branch on the major part.
    file_type = uploaded_file.type.split('/')[0]
    if file_type == "image":
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)
        if st.button("Detect Objects"):
            # Force 3-channel RGB so grayscale, palette and RGBA uploads all
            # produce an HxWx3 array.
            rgb_pixels = np.array(image.convert("RGB"))
            # process_image is fed BGR frames by the video path and applies
            # BGR2RGB to its result, so hand it BGR here too; otherwise the
            # displayed colours come out swapped.
            bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
            processed_image = process_image(bgr_pixels)
            st.image(processed_image, caption="Processed Image", use_column_width=True)
    elif file_type == "video":
        # Close the temp file before anything reads it so the upload is fully
        # flushed to disk; delete=False keeps it available by path afterwards.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tfile:
            tfile.write(uploaded_file.read())
        video_path = tfile.name
        st.video(video_path)
        if st.button("Detect Objects"):
            processed_video = process_video(video_path)
            st.video(processed_video)

st.write("Upload an image or video to detect vehicles and license plates.")