File size: 5,457 Bytes
8afd9ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import cv2
import numpy as np
import os
import torch
import onnxruntime as ort
import time
from functools import wraps
import argparse
from PIL import Image
from io import BytesIO
import streamlit as st

# Parse command-line arguments (disabled; uncomment to drive `mosaic` from the CLI)
#parser = argparse.ArgumentParser()
#parser.add_argument("--mosaic", help="Enable mosaic processing mode", action="store_true")
#args = parser.parse_args()
#mosaic = args.mosaic  # Set this based on your command line argument

# For Streamlit use, mosaic is simply hard-coded to True; the command-line
# argument parsing is left above for anyone who wants to use it.
# mosaic=True:  the whole image is split into 416x416 tiles, each tile is
#               processed, and the results are stitched back together.
# mosaic=False: only a 416x416 center crop of the image is processed.

mosaic = True

def center_crop(img, new_height, new_width):
    """Return the central ``new_height`` x ``new_width`` region of ``img``.

    Only the first two axes are cropped, so both (H, W) and (H, W, C) arrays
    are accepted. If the image is smaller than the requested size along an
    axis, that whole axis is kept and the result is smaller than requested.
    Without this, a negative start index would wrap around to the end of the
    array and silently produce a wrong crop.

    Args:
        img: Array whose first two axes are height and width.
        new_height: Requested crop height in pixels.
        new_width: Requested crop width in pixels.

    Returns:
        A view into ``img`` covering the centered crop.
    """
    height, width = img.shape[:2]
    # Clamp to 0: slicing with a negative start would index from the end.
    start_y = max(0, height // 2 - new_height // 2)
    start_x = max(0, width // 2 - new_width // 2)
    return img[start_y:start_y + new_height, start_x:start_x + new_width]


def mosaic_crop(img, size):
    """Split ``img`` into a row-major list of ``size`` x ``size`` tiles.

    The image is first zero-padded on the bottom and right edges so that both
    its height and width become multiples of ``size``.

    Returns:
        A tuple ``(tiles, n_rows, n_cols, pad_bottom, pad_right)``. ``tiles``
        is ordered row by row, ``n_rows``/``n_cols`` give the tile grid, and
        the padding amounts let the caller crop the padding back out.
    """
    h, w, _ = img.shape
    pad_bottom = (size - h % size) % size
    pad_right = (size - w % size) % size

    padded = cv2.copyMakeBorder(img, 0, pad_bottom, 0, pad_right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    padded_h, padded_w = padded.shape[0], padded.shape[1]

    tiles = []
    for top in range(0, padded_h, size):
        for left in range(0, padded_w, size):
            tiles.append(padded[top:top + size, left:left + size])

    n_rows = padded_h // size
    n_cols = padded_w // size
    return tiles, n_rows, n_cols, pad_bottom, pad_right

def stitch_tiles(tiles, rows, cols, size):
    """Reassemble a row-major list of tiles into one image.

    Tiles in each row are joined side by side (axis 1), then the resulting row
    strips are stacked top to bottom (axis 0). ``size`` is accepted for API
    symmetry with ``mosaic_crop`` but is not needed here.
    """
    strips = []
    for r in range(rows):
        row_tiles = [tiles[r * cols + c] for c in range(cols)]
        strips.append(np.concatenate(row_tiles, axis=1))
    return np.concatenate(strips, axis=0)


def timing_decorator(func):
    """Decorate ``func`` so every call prints its wall-clock duration.

    ``time.perf_counter`` is used because it is monotonic and high-resolution;
    ``time.time`` can jump if the system clock is adjusted. The timing line
    is printed even when ``func`` raises, and the exception still propagates.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            duration = time.perf_counter() - start_time
            print(f"Function '{func.__name__}' took {duration:.6f} seconds")

    return wrapper

@timing_decorator
def process_image(session, img, colors, mosaic=False):
    """Run the segmentation model on one image and return a coloured overlay.

    Args:
        session: ONNX Runtime inference session; the blob is fed to its first
            input.
        img: BGR image (as produced by ``load_image``), presumably uint8 with
            3 channels; TODO confirm for other callers.
        colors: Mapping from scaled mask value (model output * 122, clipped to
            0..255) to a BGR colour.
        mosaic: When False, the image is first center-cropped to 416x416. When
            True, the caller is expected to pass an already-tiled image.

    Returns:
        ``img`` blended 60/40 with the colour mask. Pixels whose mask value is
        122 show the original image colours instead of a mask colour.
    """
    if not mosaic:
        # Crop the center of the image to (at most) 416x416 pixels
        img = center_crop(img, 416, 416)
    blob = cv2.dnn.blobFromImage(img, 1/255.0, (416, 416), swapRB=True, crop=False)

    # Perform inference
    output = session.run(None, {session.get_inputs()[0].name: blob})

    # NOTE(review): the output looks like a class-index map rather than a
    # probability map; scaling by 122 maps classes 0/1/2 onto the 0/122/244
    # keys used in `colors`. Confirm against the exported model.
    output_img = output[0].squeeze(0).transpose(1, 2, 0)
    output_img = (output_img * 122).clip(0, 255).astype(np.uint8)
    output_mask = output_img.max(axis=2)

    # The blob is always resized to 416x416, but `img` can be smaller (e.g. a
    # small upload in non-mosaic mode). Bring the mask to the image's size so
    # the indexing and blending below line up. Nearest-neighbour keeps the
    # discrete mask values intact.
    height, width = img.shape[:2]
    if output_mask.shape != (height, width):
        output_mask = cv2.resize(output_mask, (width, height), interpolation=cv2.INTER_NEAREST)

    output_mask_color = np.zeros((height, width, 3), dtype=np.uint8)

    # Assign specific colors to the classes in the mask
    for class_idx in np.unique(output_mask):
        if class_idx in colors:
            output_mask_color[output_mask == class_idx] = colors[class_idx]

    # Pixels of the transparent class (122) keep the input image's colours
    transparent_mask = output_mask == 122
    output_mask_color[transparent_mask] = img[transparent_mask]

    # Make the colorful mask semi-transparent
    return cv2.addWeighted(img, 0.6, output_mask_color, 0.4, 0)
 

# Use the GPU automatically when PyTorch can see one. The CPU provider is
# always listed last as a fallback, so the session can still be created if
# ONNX Runtime's CUDA provider is not usable in this environment.
cuda = torch.cuda.is_available()

if cuda:
    print("We have a GPU!")
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']

# NOTE(review): Streamlit re-executes the whole script on each interaction, so
# this reloads the model every time; consider caching the session.
session = ort.InferenceSession('end2end.onnx', providers=providers)


# BGR colours keyed by scaled mask value for classes 0, 122 and 244.
# process_image treats 122 as transparent, so its colour entry is overwritten
# by the input image pixels.
colors = {0: (0, 0, 255), 122: (0, 0, 0), 244: (0, 255, 255)}  # Red, Black, Yellow

def load_image(uploaded_file):
    """Decode an uploaded file into a 3-channel BGR array for OpenCV.

    The image is converted to RGB with Pillow first. Grayscale, palette and
    RGBA/LA uploads therefore all become 3-channel arrays before the
    RGB -> BGR swap. Without this, ``cv2.cvtColor`` rejects grayscale,
    palette and LA arrays.

    Returns:
        The BGR image array, or None if decoding failed. The error is written
        to the Streamlit page.
    """
    try:
        image = Image.open(uploaded_file).convert("RGB")
        return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    except Exception as e:
        st.write("Could not load image: ", e)
        return None


st.title("OpenLander ONNX app")
st.write("Upload an image to process with the ONNX OpenLander model!")
st.write("Bear in mind that this model is **much less refined** than the embedded models at the moment.")


# "jpeg" is listed too, so files using that extension are not rejected.
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    img = load_image(uploaded_file)
    # load_image reports decoding errors itself and returns None; skip the rest then.
    if img is not None:
        if img.ndim == 3 and img.shape[2] == 4:
            img = img[:, :, :3]  # Drop the alpha channel if it exists
        img_processed = None

        if st.button('Process'):
            with st.spinner('Processing...'):
                start = time.perf_counter()
                if mosaic:
                    tiles, rows, cols, _, _ = mosaic_crop(img, 416)
                    processed_tiles = [process_image(session, tile, colors, mosaic=True) for tile in tiles]
                    overlay = stitch_tiles(processed_tiles, rows, cols, 416)

                    # Crop the bottom/right padding back out by keeping exactly
                    # the original image extent.
                    img_processed = overlay[:img.shape[0], :img.shape[1]]
                else:
                    img_processed = process_image(session, img, colors)
                end = time.perf_counter()
                st.write(f"Processing time: {end - start} seconds")

        st.image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), caption='Uploaded Image.', use_column_width=True)

        if img_processed is not None:
            st.image(cv2.cvtColor(img_processed, cv2.COLOR_BGR2RGB), caption='Processed Image.', use_column_width=True)
            st.write("Red => obstacle ||| Yellow => Human obstacle ||| no color => clear for landing or delivery ")