import streamlit as st
import cv2
import numpy as np
import onnxruntime as ort
from PIL import Image
import tempfile

@st.cache_resource
def load_model():
    # Cache the ONNX Runtime session so it is created only once per server process.
    return ort.InferenceSession("model.onnx")


ort_session = load_model()

def preprocess_image(image, target_size=(640, 640)):
    # Resize to the model's expected input resolution.
    image = cv2.resize(image, target_size)
    # Scale pixel values to [0, 1].
    image = image.astype(np.float32) / 255.0
    # HWC -> CHW, then add a batch dimension: (1, 3, H, W).
    image = np.transpose(image, (2, 0, 1))
    image = np.expand_dims(image, axis=0)
    return image

def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_threshold=0.45):
    # The model is assumed to return three arrays: boxes, confidence scores, and class ids.
    boxes = output[0]
    scores = output[1]
    class_ids = output[2]

    # Drop low-confidence detections before running NMS.
    mask = scores > confidence_threshold
    boxes = boxes[mask]
    scores = scores[mask]
    class_ids = class_ids[mask]

    if len(boxes) == 0:
        return []

    # Non-maximum suppression; flattening the result keeps this compatible with
    # OpenCV versions that return the kept indices as a nested array.
    indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), confidence_threshold, iou_threshold)

    results = []
    for i in np.array(indices).flatten():
        box = boxes[i]
        score = scores[i]
        class_id = class_ids[i]
        # Boxes are assumed to be normalized (x, y, w, h); scale them to pixel coordinates.
        x, y, w, h = box
        x1 = int(x * image_shape[1])
        y1 = int(y * image_shape[0])
        x2 = int((x + w) * image_shape[1])
        y2 = int((y + h) * image_shape[0])
        results.append((x1, y1, x2, y2, score, class_id))

    return results

def process_image(image):
    # Note: PIL uploads arrive as RGB while OpenCV video frames are BGR; convert
    # here if the model expects a specific channel order.
    orig_image = image.copy()
    processed_image = preprocess_image(image)

    # Run the ONNX model on the preprocessed image.
    inputs = {ort_session.get_inputs()[0].name: processed_image}
    outputs = ort_session.run(None, inputs)

    results = postprocess_results(outputs, image.shape)

    # Draw each detection on the original-resolution image.
    for x1, y1, x2, y2, score, class_id in results:
        cv2.rectangle(orig_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        label = f"License Plate: {score:.2f}"
        cv2.putText(orig_image, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)

    return orig_image

def process_video(video_path):
    cap = cv2.VideoCapture(video_path)

    # Mirror the input video's resolution and frame rate in the output file.
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))

    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    out = cv2.VideoWriter(temp_file.name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))

    # Run detection frame by frame and write the annotated frames out.
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        processed_frame = process_image(frame)
        out.write(processed_frame)

    cap.release()
    out.release()

    return temp_file.name

st.title("License Plate Detection")

uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])

if uploaded_file is not None:
    file_type = uploaded_file.type.split('/')[0]

    if file_type == "image":
        # Force RGB so PNGs with an alpha channel don't break preprocessing.
        image = Image.open(uploaded_file).convert("RGB")
        image = np.array(image)

        st.image(image, caption="Uploaded Image", use_column_width=True)

        if st.button("Detect License Plates"):
            processed_image = process_image(image)
            st.image(processed_image, caption="Processed Image", use_column_width=True)

    elif file_type == "video":
        # Write the upload to a named temporary file so OpenCV can open it by path.
        tfile = tempfile.NamedTemporaryFile(delete=False)
        tfile.write(uploaded_file.read())
        tfile.flush()

        st.video(tfile.name)

        if st.button("Detect License Plates"):
            processed_video = process_video(tfile.name)
            st.video(processed_video)
else:
    st.write("Upload an image or video to detect license plates.")
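

# Usage note (a sketch under assumptions): assuming this script is saved as
# app.py and that model.onnx sits in the working directory, the app can be
# started with Streamlit's CLI:
#
#     streamlit run app.py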