Spaces:
Runtime error
Runtime error
import os
import tempfile

import cv2
import numpy as np
import onnxruntime as ort
import streamlit as st
from PIL import Image
# Load the ONNX model once per server process. Streamlit re-executes this whole
# script on every widget interaction, so without caching the session would be
# rebuilt (and the model re-read from disk) on every rerun.
@st.cache_resource
def load_model(model_path="model.onnx", providers=None):
    """Create an ONNX Runtime inference session.

    Args:
        model_path: Path of the ONNX model file to load.
        providers: Optional execution-provider list forwarded to
            ``ort.InferenceSession``; ``None`` keeps ONNX Runtime's default.

    Returns:
        An ``ort.InferenceSession`` for ``model_path``.
    """
    return ort.InferenceSession(model_path, providers=providers)


ort_session = load_model()
# Lookup from the integer class id produced by postprocessing to the label
# drawn on each detection box.
CLASS_NAMES = dict(enumerate(('car', 'license_plate')))
def preprocess_image(image, target_size=(640, 640)):
    """Convert an image into a batched, normalised model input tensor.

    Args:
        image: A PIL image, or a NumPy array in HxWxC layout.
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        float32 array of shape ``(1, C, target_size[1], target_size[0])``,
        pixel values divided by 255.
    """
    if isinstance(image, Image.Image):
        # Normalise greyscale / palette / alpha images to 3-channel RGB; an 'L'
        # image would otherwise become a 2-D array that the colour conversion
        # and the channel transpose below cannot handle.
        image = np.array(image.convert("RGB"))
    # Swaps the first and last channels (RGB <-> BGR are the same operation).
    # NOTE(review): PIL input is RGB, so it reaches the model as BGR, while the
    # frames process_image passes in are treated as BGR elsewhere in this file
    # and so reach the model as RGB. Confirm which order the model expects.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    image = cv2.resize(image, target_size)
    image = image.astype(np.float32) / 255.0
    # HWC -> CHW, then prepend a batch axis of size 1.
    image = np.transpose(image, (2, 0, 1))
    return np.expand_dims(image, axis=0)
def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_threshold=0.45):
    """Decode raw model output into thresholded, NMS-filtered detections.

    Args:
        output: The list returned by ``InferenceSession.run`` (its first element
            is used) or the prediction array itself.
        image_shape: Shape of the original image; ``image_shape[0]`` scales the
            vertical box coordinates and ``image_shape[1]`` the horizontal ones.
        confidence_threshold: Predictions must score strictly above this.
        iou_threshold: Overlap threshold used by non-maximum suppression.

    Returns:
        List of ``(x1, y1, x2, y2, score, class_id)`` tuples; coordinates are
        truncated to ints, ``score`` is a float, ``class_id`` an int.

    Raises:
        ValueError: If ``output`` is neither a list/tuple nor a NumPy array.
    """
    if isinstance(output, (list, tuple)):
        predictions = output[0]
    elif isinstance(output, np.ndarray):
        predictions = output
    else:
        raise ValueError(f"Unexpected output type: {type(output)}")

    # Drop leading singleton dims so each row is one prediction.
    if len(predictions.shape) == 4:
        predictions = predictions.squeeze((0, 1))
    elif len(predictions.shape) == 3:
        predictions = predictions.squeeze(0)

    # NOTE(review): rows are read as [x, y, w, h, score, class_id] with (x, y)
    # the top-left corner and coordinates normalised to [0, 1]. Standard YOLO
    # exports use centre coordinates and/or another layout — confirm against
    # the exported model.
    scores = predictions[:, 4]
    mask = scores > confidence_threshold
    if not np.any(mask):
        return []

    # Boolean indexing copies, so the in-place scaling below never touches the
    # model's output buffer.
    boxes = predictions[mask, :4]
    scores = scores[mask]
    class_ids = predictions[mask, 5]

    # Scale to pixels: x and w by image width, y and h by image height.
    boxes[:, [0, 2]] *= image_shape[1]
    boxes[:, [1, 3]] *= image_shape[0]

    # cv2.dnn.NMSBoxes expects [x, y, width, height]; run it before converting
    # to corner form, otherwise the overlap computation is wrong.
    keep = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), confidence_threshold, iou_threshold)
    # Depending on the OpenCV version this is a flat array, an (N, 1) array or
    # an empty tuple; normalise to a flat integer index array.
    keep = np.asarray(keep, dtype=int).reshape(-1)

    results = []
    for i in keep:
        x, y, w, h = boxes[i]
        results.append((int(x), int(y), int(x + w), int(y + h), float(scores[i]), int(class_ids[i])))
    return results
def process_image(image):
    """Run detection on one image and return an annotated copy in RGB order.

    Args:
        image: NumPy array in HxWxC layout. The annotated copy is converted
            with ``COLOR_BGR2RGB`` before returning, so the input is treated as
            BGR (as produced by ``cv2.VideoCapture.read``).

    Returns:
        A new array with detection boxes and labels drawn, in RGB order.
    """
    annotated = image.copy()
    model_input = preprocess_image(image)
    feed = {ort_session.get_inputs()[0].name: model_input}
    outputs = ort_session.run(None, feed)
    detections = postprocess_results(outputs, image.shape)
    for x1, y1, x2, y2, score, class_id in detections:
        class_name = CLASS_NAMES.get(class_id, 'unknown')
        color = (0, 255, 0) if class_name == 'car' else (255, 0, 0)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)
        label = f"{class_name}: {score:.2f}"
        # Draw the label above the box; if that would start above the top of
        # the image, draw it just inside the box so it stays visible.
        (_, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.9, 2)
        if y1 - 10 - text_height >= 0:
            text_y = y1 - 10
        else:
            text_y = y1 + text_height + 10
        cv2.putText(annotated, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
def process_video(video_path, fallback_fps=30.0):
    """Run detection on every frame of a video and write an annotated copy.

    Args:
        video_path: Path of a video file readable by ``cv2.VideoCapture``.
        fallback_fps: Frame rate used for the output when the input does not
            report a positive rate.

    Returns:
        Path of a new ``.mp4`` temporary file (created with ``delete=False``,
        so the caller is responsible for removing it).

    Raises:
        ValueError: If ``video_path`` cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates such as 29.97 (truncating makes playback drift);
        # a 0/NaN rate would leave the writer unusable, so fall back.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = fallback_fps

        # Only the name is needed: close the handle so it is not leaked and so
        # VideoWriter can reopen the path on platforms that lock open files.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
            output_path = temp_file.name
        # NOTE(review): 'mp4v' output may not play in browsers via st.video;
        # H.264 ('avc1') needs an OpenCV build that includes that encoder.
        out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame = process_image(frame)
            # process_image returns RGB; VideoWriter expects BGR.
            out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
    finally:
        cap.release()
        if out is not None:
            out.release()
    return output_path
st.title("Vehicle and License Plate Detection")
uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])
if uploaded_file is not None:
    file_type = uploaded_file.type.split('/')[0]
    if file_type == "image":
        # Normalise greyscale, palette and RGBA uploads to 3-channel RGB.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)
        if st.button("Detect Objects"):
            # process_image treats its input as BGR (like cv2 video frames) and
            # converts BGR -> RGB on return, so hand it BGR; PIL's RGB array
            # would come back with red and blue swapped.
            bgr_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            processed_image = process_image(bgr_image)
            st.image(processed_image, caption="Processed Image", use_column_width=True)
    elif file_type == "video":
        # getvalue() returns the whole upload regardless of stream position.
        video_bytes = uploaded_file.getvalue()
        # Close the temp file before OpenCV opens it by path so every byte has
        # been flushed to disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
            tfile.write(video_bytes)
            input_path = tfile.name
        try:
            st.video(video_bytes)
            if st.button("Detect Objects"):
                output_path = process_video(input_path)
                try:
                    # Pass bytes so the temporary output file can be removed.
                    with open(output_path, 'rb') as processed:
                        st.video(processed.read())
                finally:
                    os.remove(output_path)
        finally:
            # The script reruns (and rewrites this file) on every interaction,
            # so this run's copy is no longer needed.
            os.remove(input_path)
st.write("Upload an image or video to detect vehicles and license plates.")