File size: 4,271 Bytes
6d35603
 
 
 
 
 
 
 
 
 
 
 
 
 
9f0b3a7
 
6907110
6d35603
ab634f0
 
 
6d35603
 
 
 
 
 
 
ab634f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d35603
 
 
 
 
 
ab634f0
 
 
 
6d35603
 
 
 
 
 
9f0b3a7
ab634f0
9f0b3a7
6d35603
 
 
 
 
 
 
 
 
 
 
 
 
9f0b3a7
6907110
9f0b3a7
6907110
6d35603
ab634f0
6d35603
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ab634f0
6d35603
 
 
 
 
 
6907110
6d35603
 
 
 
 
 
 
 
 
 
6907110
ab634f0
6d35603
 
 
 
 
 
 
 
6907110
6d35603
 
 
6907110
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import streamlit as st
import cv2
import numpy as np
import onnxruntime as ort
from PIL import Image
import tempfile

# Load the ONNX model. st.cache_resource returns the same session object on
# every rerun instead of re-reading the model file each time.
@st.cache_resource
def load_model(model_path="model.onnx"):
    """Create the ONNX Runtime inference session for the detector.

    Args:
        model_path: Path to the ONNX model file. Defaults to ``"model.onnx"``
            relative to the current working directory.

    Returns:
        An ``onnxruntime.InferenceSession`` for ``model_path``.
    """
    # Since ONNX Runtime 1.9, builds with more than one execution provider
    # (e.g. onnxruntime-gpu) raise unless ``providers`` is passed explicitly.
    # Prefer CUDA when this build offers it, and always keep the CPU fallback.
    available = ort.get_available_providers()
    providers = [
        name
        for name in ("CUDAExecutionProvider", "CPUExecutionProvider")
        if name in available
    ]
    return ort.InferenceSession(model_path, providers=providers)


ort_session = load_model()

# Map each class index emitted by the model to a human-readable label;
# the position in this tuple is the class id.
CLASS_NAMES = dict(enumerate(("car", "license_plate")))

def preprocess_image(image, target_size=(640, 640)):
    """Turn an image into a normalised NCHW float32 batch for the ONNX model.

    Args:
        image: A PIL image or a NumPy array of shape (H, W, C). The colour
            channels are swapped with ``cv2.COLOR_RGB2BGR``, which is a pure
            R<->B swap: a BGR frame from OpenCV comes out in RGB order and an
            RGB array comes out in BGR order.
        target_size: ``(width, height)`` handed to ``cv2.resize``.

    Returns:
        A float32 array of shape (1, 3, height, width) with values in [0, 1].
    """
    pixels = np.array(image) if isinstance(image, Image.Image) else image
    swapped = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)
    resized = cv2.resize(swapped, target_size)
    scaled = resized.astype(np.float32) / 255.0
    # HWC -> CHW, then prepend the batch axis.
    chw = scaled.transpose(2, 0, 1)
    return chw[np.newaxis, ...]

def postprocess_results(output, image_shape, confidence_threshold=0.25, iou_threshold=0.45):
    """Filter, de-duplicate and scale raw model detections.

    Each prediction row is assumed to be ``[x, y, w, h, score, class_id, ...]``.
    ``(x, y)`` is assumed to be the top-left corner, and the four box values
    are assumed to be normalised to [0, 1] relative to the image size.
    NOTE(review): this layout is inferred from the slicing and scaling below;
    confirm it against the exported model.

    Args:
        output: The list returned by ``InferenceSession.run`` (its first element
            is used) or the predictions array itself. Shapes ``(N, C)``,
            ``(1, N, C)`` and ``(1, 1, N, C)`` with ``C >= 6`` are accepted.
        image_shape: Shape of the original image. ``image_shape[0]`` is the
            height and ``image_shape[1]`` the width, in pixels.
        confidence_threshold: Detections scoring at or below this are dropped.
        iou_threshold: IoU above which overlapping boxes are suppressed by NMS.

    Returns:
        A list of ``(x1, y1, x2, y2, score, class_id)`` tuples in pixel
        coordinates. The list is empty when nothing passes the threshold.

    Raises:
        ValueError: If ``output`` is neither a list/tuple nor an ndarray.
    """
    if isinstance(output, (list, tuple)):
        predictions = output[0]
    elif isinstance(output, np.ndarray):
        predictions = output
    else:
        raise ValueError(f"Unexpected output type: {type(output)}")

    predictions = np.asarray(predictions)
    if predictions.ndim == 4:
        predictions = predictions.squeeze((0, 1))
    elif predictions.ndim == 3:
        predictions = predictions.squeeze(0)

    keep = predictions[:, 4] > confidence_threshold
    if not keep.any():
        # Nothing to suppress; also avoids calling NMSBoxes with empty inputs.
        return []

    # Boolean indexing copies, so the in-place scaling below never touches
    # the model's output buffer.
    boxes = predictions[keep, :4].astype(np.float32)
    scores = predictions[keep, 4]
    class_ids = predictions[keep, 5]

    height, width = image_shape[0], image_shape[1]
    boxes[:, [0, 2]] *= width   # x and w
    boxes[:, [1, 3]] *= height  # y and h

    # cv2.dnn.NMSBoxes expects (x, y, w, h) rectangles, so run it BEFORE
    # converting to corner coordinates. Suppression here is class-agnostic.
    indices = cv2.dnn.NMSBoxes(boxes.tolist(), scores.tolist(), confidence_threshold, iou_threshold)
    # Older OpenCV releases return a (K, 1) array, newer ones a flat sequence.
    indices = np.asarray(indices, dtype=int).reshape(-1)

    results = []
    for i in indices:
        x, y, w, h = boxes[i]
        results.append((int(x), int(y), int(x + w), int(y + h), float(scores[i]), int(class_ids[i])))
    return results

def process_image(image):
    """Run detection on one frame and draw labelled boxes on a copy of it.

    Args:
        image: An (H, W, 3) NumPy array, expected in OpenCV's BGR channel
            order. ``process_video`` feeds raw ``cv2.VideoCapture`` frames,
            and the result is converted BGR->RGB on return.

    Returns:
        The annotated frame as an RGB array, ready for ``st.image``.
    """
    annotated = image.copy()
    model_input = preprocess_image(image)

    inputs = {ort_session.get_inputs()[0].name: model_input}
    outputs = ort_session.run(None, inputs)

    detections = postprocess_results(outputs, image.shape)

    for x1, y1, x2, y2, score, class_id in detections:
        class_name = CLASS_NAMES.get(class_id, 'unknown')
        # BGR colours: green for cars, blue for everything else.
        color = (0, 255, 0) if class_name == 'car' else (255, 0, 0)
        cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2)

        label = f"{class_name}: {score:.2f}"
        # putText anchors at the text baseline. Without clamping, a box that
        # touches the top edge would have its label drawn above the canvas.
        (_, text_height), _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.9, 2)
        text_y = max(y1 - 10, text_height)
        cv2.putText(annotated, label, (x1, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)

    return cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)

def process_video(video_path, default_fps=30.0):
    """Annotate every frame of a video and write the result to a temp .mp4 file.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.
        default_fps: Output frame rate used when the container does not
            report a positive ``CAP_PROP_FPS``.

    Returns:
        Path of the annotated ``.mp4`` file. It is created with
        ``delete=False``, so the caller owns its cleanup.

    Raises:
        ValueError: If ``video_path`` cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep fractional rates such as 29.97; int() would truncate them and the
    # output would drift. Some containers report 0 when the rate is unknown.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):
        fps = default_fps

    # Only the path is needed. Closing the handle avoids leaking it and lets
    # VideoWriter open the path on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
        output_path = temp_file.name

    # NOTE(review): 'mp4v' (MPEG-4 Part 2) is not played by most browsers, so
    # st.video may not render this file. An H.264 fourcc such as 'avc1' is
    # browser-friendly but needs an OpenCV build with an H.264 encoder.
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            processed_frame = process_image(frame)
            # process_image returns RGB; VideoWriter expects BGR.
            out.write(cv2.cvtColor(processed_frame, cv2.COLOR_RGB2BGR))
    finally:
        cap.release()
        out.release()

    return output_path

# --- Streamlit UI: upload an image or video, then run detection on demand. ---
st.title("Vehicle and License Plate Detection")

uploaded_file = st.file_uploader("Choose an image or video file", type=["jpg", "jpeg", "png", "mp4"])

if uploaded_file is not None:
    # MIME type such as "image/png" or "video/mp4"; the prefix picks the branch.
    file_type = uploaded_file.type.split('/')[0]

    if file_type == "image":
        # Normalise palette, greyscale and RGBA uploads to 3-channel RGB so
        # the OpenCV colour conversions downstream always see 3 channels.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_column_width=True)

        if st.button("Detect Objects"):
            # process_image expects OpenCV's BGR order, the same order as the
            # video frames, and returns RGB. Passing the PIL RGB array directly
            # would swap red and blue in the displayed result.
            bgr_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
            processed_image = process_image(bgr_image)
            st.image(processed_image, caption="Processed Image", use_column_width=True)

    elif file_type == "video":
        # Close the temp file so the whole upload is flushed to disk before
        # Streamlit and OpenCV reopen it by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
            tfile.write(uploaded_file.read())

        st.video(tfile.name)

        if st.button("Detect Objects"):
            processed_video = process_video(tfile.name)
            st.video(processed_video)

st.write("Upload an image or video to detect vehicles and license plates.")