Spaces:
Runtime error
Runtime error
import streamlit as st | |
import cv2 | |
import numpy as np | |
import onnxruntime as ort | |
from PIL import Image | |
import tempfile | |
import torch | |
from ultralytics import YOLO | |
# Load models
@st.cache_resource
def load_models():
    """Load the plate detector, the vehicle detector and the ONNX OCR session.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the weight files would be re-read from disk on every
    click. ``st.cache_resource`` (Streamlit >= 1.18) keeps the loaded objects
    alive and returns the same instances on later calls.

    NOTE(review): cached resources are shared between all sessions of the
    app; confirm YOLO / onnxruntime objects are safe to use concurrently if
    several users run detections at once.

    Returns:
        tuple: ``(license_plate_detector, vehicle_detector, ort_session)`` —
        YOLO models loaded from ``license_plate_detector.pt`` and
        ``yolov8n.pt``, and an onnxruntime session for ``model.onnx``.
    """
    license_plate_detector = YOLO('license_plate_detector.pt')
    vehicle_detector = YOLO('yolov8n.pt')
    ort_session = ort.InferenceSession("model.onnx")
    return license_plate_detector, vehicle_detector, ort_session
def draw_border(img, top_left, bottom_right, color=(0, 255, 0), thickness=10, line_length_x=200, line_length_y=200):
    """Draw corner brackets (not a full rectangle) on ``img`` in place.

    From each of the four corners of the box spanned by ``top_left`` and
    ``bottom_right`` two strokes are drawn: a vertical one of length
    ``line_length_y`` and a horizontal one of length ``line_length_x``, both
    heading toward the opposite side of the box. When the box is smaller
    than a stroke length, the strokes run past the opposite edge.

    Args:
        img: image array passed to ``cv2.line`` and modified in place.
        top_left: ``(x, y)`` of the box's top-left corner.
        bottom_right: ``(x, y)`` of the box's bottom-right corner.
        color: line colour, in the image's channel order.
        thickness: line thickness in pixels.
        line_length_x: length of each horizontal stroke in pixels.
        line_length_y: length of each vertical stroke in pixels.

    Returns:
        ``img`` itself.
    """
    left, top = top_left
    right, bottom = bottom_right

    # (start, end) of every stroke: top-left, bottom-left, top-right, bottom-right.
    strokes = (
        ((left, top), (left, top + line_length_y)),
        ((left, top), (left + line_length_x, top)),
        ((left, bottom), (left, bottom - line_length_y)),
        ((left, bottom), (left + line_length_x, bottom)),
        ((right, top), (right - line_length_x, top)),
        ((right, top), (right, top + line_length_y)),
        ((right, bottom), (right, bottom - line_length_y)),
        ((right, bottom), (right - line_length_x, bottom)),
    )
    for start, end in strokes:
        cv2.line(img, start, end, color, thickness)
    return img
def process_frame(frame, license_plate_detector, vehicle_detector, ort_session):
    """Detect vehicles and their licence plates in ``frame`` and annotate it in place.

    Every vehicle detected with confidence > 0.5 gets green corner brackets.
    Every plate found inside such a vehicle with confidence > 0.5 is outlined
    in red, passed to the ONNX model, and shown enlarged above the vehicle
    together with its text. Per-plate failures are reported with
    ``st.error`` and do not abort the frame.

    Args:
        frame: image array (assumed HxWx3 BGR as produced by OpenCV —
            TODO confirm for all callers); annotated in place.
        license_plate_detector: YOLO model run on each vehicle crop.
        vehicle_detector: YOLO model run on the whole frame.
        ort_session: onnxruntime session used for plate OCR.

    Returns:
        ``frame`` itself, after annotation.
    """
    # Detectors and OCR read from this untouched copy. The overlays drawn on
    # ``frame`` below (thick green corners, red plate boxes, pasted labels)
    # would otherwise leak into the crops of this and later vehicles.
    clean = frame.copy()

    vehicle_results = vehicle_detector(clean, classes=[2, 3, 5, 7])  # cars, motorcycles, bus, trucks
    for vehicle in vehicle_results[0].boxes.data:
        x1, y1, x2, y2, score, _ = vehicle.tolist()
        if score <= 0.5:  # Confidence threshold
            continue

        draw_border(frame,
                    (int(x1), int(y1)),
                    (int(x2), int(y2)),
                    color=(0, 255, 0),
                    thickness=25,
                    line_length_x=200,
                    line_length_y=200)

        vehicle_crop = clean[int(y1):int(y2), int(x1):int(x2)]
        if vehicle_crop.size == 0:
            continue  # degenerate box; nothing to search for plates
        license_results = license_plate_detector(vehicle_crop)
        center_x = int((x1 + x2) / 2)

        for license_plate in license_results[0].boxes.data:
            lp_x1, lp_y1, lp_x2, lp_y2, lp_score, _ = license_plate.tolist()
            if lp_score <= 0.5:
                continue

            # Plate boxes are relative to the vehicle crop; shift to frame coordinates.
            abs_lp_x1 = int(x1 + lp_x1)
            abs_lp_y1 = int(y1 + lp_y1)
            abs_lp_x2 = int(x1 + lp_x2)
            abs_lp_y2 = int(y1 + lp_y2)

            cv2.rectangle(frame,
                          (abs_lp_x1, abs_lp_y1),
                          (abs_lp_x2, abs_lp_y2),
                          (0, 0, 255), 12)

            license_crop = clean[abs_lp_y1:abs_lp_y2, abs_lp_x1:abs_lp_x2]
            if license_crop.size > 0:
                _annotate_plate(frame, license_crop, ort_session, center_x, int(y1))
    return frame


def _annotate_plate(frame, license_crop, ort_session, center_x, vehicle_top):
    """OCR one plate crop and draw the result above its vehicle, reporting failures via ``st.error``."""
    try:
        license_number = _read_plate_text(license_crop, ort_session)
        H, W = license_crop.shape[:2]
        # Enlarge to 400 px tall, keeping the aspect ratio; a very tall, thin
        # crop would otherwise round to width 0, which cv2.resize rejects.
        license_crop_display = cv2.resize(license_crop, (max(1, int(W * 400 / H)), 400))
    except Exception as e:
        st.error(f"Error in OCR processing: {str(e)}")
        return
    try:
        _draw_plate_label(frame, license_crop_display, license_number, center_x, vehicle_top)
    except Exception as e:
        st.error(f"Error displaying results: {str(e)}")


def _read_plate_text(plate_crop, ort_session):
    """Run the ONNX model on ``plate_crop`` and return the plate text.

    The crop is resized to 640x640, reordered HWC -> CHW, scaled to [0, 1] as
    float32 and given a leading batch axis before inference.
    NOTE(review): the crop is in OpenCV (BGR) channel order; confirm that is
    what the model expects.
    """
    resized = cv2.resize(plate_crop, (640, 640))
    tensor = np.transpose(resized, (2, 0, 1)).astype(np.float32) / 255.0
    tensor = np.expand_dims(tensor, axis=0)
    # The model output is currently discarded.
    ort_session.run(None, {ort_session.get_inputs()[0].name: tensor})
    # TODO: decode the model output into text. Until then every plate is
    # labelled with this placeholder string.
    return "ABC123"


def _draw_plate_label(frame, plate_display, license_number, center_x, vehicle_top):
    """Draw ``plate_display`` and ``license_number`` horizontally centred on ``center_x``.

    Layout, going upwards from ``vehicle_top``: a 100 px gap, the enlarged
    plate crop, then a 300 px white band with the plate text centred in it.
    """
    h_crop, w_crop = plate_display.shape[:2]
    # Same left edge as int(center_x - w_crop / 2) for non-negative positions.
    left = center_x - (w_crop + 1) // 2
    right = left + w_crop
    crop_top = vehicle_top - h_crop - 100

    _paste_clipped(frame, plate_display, crop_top, left)

    # White background for the text.
    cv2.rectangle(frame, (left, crop_top - 300), (right, crop_top), (255, 255, 255), -1)

    (text_width, text_height), _ = cv2.getTextSize(
        license_number,
        cv2.FONT_HERSHEY_SIMPLEX,
        4.3,
        17)
    cv2.putText(frame,
                license_number,
                (int(center_x - text_width / 2), int(crop_top - 150 + text_height / 2)),
                cv2.FONT_HERSHEY_SIMPLEX,
                4.3,
                (0, 0, 0),
                17)


def _paste_clipped(dst, patch, top, left):
    """Copy ``patch`` into ``dst`` with its top-left corner at (``top``, ``left``).

    Any part of ``patch`` outside ``dst`` is dropped. Slicing with the raw
    (possibly negative) coordinates would instead wrap around to the opposite
    edge and fail with a shape mismatch for objects near the frame border.
    """
    dst_h, dst_w = dst.shape[:2]
    patch_h, patch_w = patch.shape[:2]
    y0, x0 = max(top, 0), max(left, 0)
    y1, x1 = min(top + patch_h, dst_h), min(left + patch_w, dst_w)
    if y0 >= y1 or x0 >= x1:
        return  # patch lies entirely outside dst
    dst[y0:y1, x0:x1] = patch[y0 - top:y1 - top, x0 - left:x1 - left]
def process_video(video_path, license_plate_detector, vehicle_detector, ort_session):
    """Annotate every frame of a video with ``process_frame`` and write an MP4.

    A Streamlit progress bar is shown while processing and removed afterwards.

    Args:
        video_path: path of the input video, opened with ``cv2.VideoCapture``.
        license_plate_detector, vehicle_detector, ort_session: forwarded to
            ``process_frame`` for each frame.

    Returns:
        Path of the annotated video: a temporary ``.mp4`` created with
        ``delete=False``, so the caller is responsible for removing it.

    Raises:
        ValueError: if the input cannot be opened or the output writer
            cannot be created.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates (e.g. 29.97) so the output stays in sync.
        # The reader returns 0 when the rate is unknown; 30 is an assumed fallback.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        # The frame count may be an estimate and may be 0.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Only the path is needed: close our handle so the writer owns the file.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
            output_path = temp_file.name

        # NOTE(review): 'mp4v' (MPEG-4 Part 2) may not play in every browser
        # via st.video; consider H.264 if the OpenCV build supports it.
        out = cv2.VideoWriter(output_path,
                              cv2.VideoWriter_fourcc(*'mp4v'),
                              fps,
                              (width, height))
        progress_bar = st.progress(0)
        try:
            if not out.isOpened():
                raise ValueError(f"Could not create output video: {output_path}")
            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                processed_frame = process_frame(frame, license_plate_detector, vehicle_detector, ort_session)
                out.write(processed_frame)
                frame_count += 1
                if total_frames > 0:
                    # Clamp: st.progress rejects values above 1.0.
                    progress_bar.progress(min(frame_count / total_frames, 1.0))
        finally:
            out.release()
            progress_bar.empty()
    finally:
        cap.release()
    return output_path
# Streamlit UI
st.title("Advanced Vehicle and License Plate Detection")

# Load models on their own so that only genuine loading failures are reported
# as such, and nothing else on the page runs without models.
try:
    license_plate_detector, vehicle_detector, ort_session = load_models()
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
    st.stop()

uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])
if uploaded_file is not None:
    file_type = uploaded_file.type.split('/')[0]
    try:
        if file_type == "image":
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)
            if st.button("Detect"):
                with st.spinner("Processing image..."):
                    # Normalise greyscale / palette / CMYK uploads to 3-channel RGB
                    # before converting to OpenCV's BGR order.
                    image_cv = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
                    processed_image = process_frame(image_cv, license_plate_detector, vehicle_detector, ort_session)
                    processed_image = cv2.cvtColor(processed_image, cv2.COLOR_BGR2RGB)
                    st.image(processed_image, caption="Processed Image", use_column_width=True)
        elif file_type == "video":
            # Closing the temp file flushes it, so OpenCV and st.video see the
            # whole upload when they open it by name.
            # NOTE(review): a new temp file is written on every rerun and never
            # deleted; consider caching it per upload in st.session_state.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tfile:
                tfile.write(uploaded_file.getvalue())
            st.video(tfile.name)
            if st.button("Detect"):
                with st.spinner("Processing video..."):
                    processed_video = process_video(tfile.name, license_plate_detector, vehicle_detector, ort_session)
                    st.video(processed_video)
    except Exception as e:
        st.error(f"Error processing file: {str(e)}")