|
import cv2 |
|
import numpy as np |
|
import os |
|
import torch |
|
import onnxruntime as ort |
|
import time |
|
from functools import wraps |
|
import argparse |
|
from PIL import Image |
|
from io import BytesIO |
|
import streamlit as st |
|
# Process the whole upload as a grid of 416x416 tiles instead of a single centre crop
mosaic = True
|
|
|
def center_crop(img, new_height, new_width): |
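    """Return the central new_height x new_width region of an HWC image."""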
|
height, width, _ = img.shape |
|
start_x = width//2 - new_width//2 |
|
start_y = height//2 - new_height//2 |
|
return img[start_y:start_y+new_height, start_x:start_x+new_width] |
|
|
|
|
|
def mosaic_crop(img, size): |
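    """Pad img with black pixels to a multiple of `size`, then split it into size x size tiles.

    Returns the row-major list of tiles, the tile grid shape (rows, cols) and the
    padding that was added, so the caller can stitch and trim the result afterwards.
    """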
|
height, width, _ = img.shape |
|
padding_height = (size - height % size) % size |
|
padding_width = (size - width % size) % size |
|
|
|
padded_img = cv2.copyMakeBorder(img, 0, padding_height, 0, padding_width, cv2.BORDER_CONSTANT, value=[0, 0, 0]) |
|
tiles = [padded_img[x:x+size, y:y+size] for x in range(0, padded_img.shape[0], size) for y in range(0, padded_img.shape[1], size)] |
|
|
|
return tiles, padded_img.shape[0] // size, padded_img.shape[1] // size, padding_height, padding_width |
|
|
|
def stitch_tiles(tiles, rows, cols, size): |
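    """Reassemble a row-major list of tiles into one image (inverse of mosaic_crop); `size` is unused."""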
|
return np.concatenate([np.concatenate([tiles[i*cols + j] for j in range(cols)], axis=1) for i in range(rows)], axis=0) |
|
|
|
|
|
def timing_decorator(func): |
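    """Print how long the wrapped function takes (appears in the server console, not on the page)."""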
|
@wraps(func) |
|
def wrapper(*args, **kwargs): |
|
start_time = time.time() |
|
result = func(*args, **kwargs) |
|
end_time = time.time() |
|
|
|
duration = end_time - start_time |
|
print(f"Function '{func.__name__}' took {duration:.6f} seconds") |
|
return result |
|
|
|
return wrapper |
|
|
|
@timing_decorator |
|
def process_image(session, img, colors, mosaic=False): |
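    """Run the ONNX model on one 416x416 crop or tile and return a blended overlay.

    The class map is scaled onto the keys of `colors`; pixels of the "clear" class (122)
    keep the original image so only obstacles end up tinted.
    """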
|
    if not mosaic:

        # Single-image mode: the model expects a fixed 416x416 input
        img = center_crop(img, 416, 416)

    # HWC uint8 BGR -> NCHW float32 RGB, scaled to [0, 1]
    blob = cv2.dnn.blobFromImage(img, 1/255.0, (416, 416), swapRB=True, crop=False)
|
|
|
|
|
output = session.run(None, {session.get_inputs()[0].name: blob}) |
|
|
|
|
|
    # The model output is NCHW; drop the batch dimension and move channels last
    output_img = output[0].squeeze(0).transpose(1, 2, 0)

    # Scale the predicted class indices onto the keys of `colors` (0, 122, 244)
    output_img = (output_img * 122).clip(0, 255).astype(np.uint8)

    # Collapse the channel axis into a single 2-D mask
    output_mask = output_img.max(axis=2)
|
|
|
output_mask_color = np.zeros((416, 416, 3), dtype=np.uint8) |
|
|
|
|
|
for class_idx in np.unique(output_mask): |
|
if class_idx in colors: |
|
output_mask_color[output_mask == class_idx] = colors[class_idx] |
|
|
|
|
|
    # Class 122 is "clear": keep the original pixels there so only obstacles get tinted
    transparent_mask = (output_mask == 122)
|
|
|
|
|
transparent_mask = np.stack([transparent_mask]*3, axis=-1) |
|
|
|
|
|
output_mask_color[transparent_mask] = img[transparent_mask] |
|
|
|
|
|
overlay = cv2.addWeighted(img, 0.6, output_mask_color, 0.4, 0) |
|
|
|
return overlay |
|
|
|
|
|
|
|
# PyTorch is only used here to detect whether a CUDA-capable GPU is available
cuda = torch.cuda.is_available()

if cuda:

    print("We have a GPU!")

# Keep the CPU provider as a fallback in case the CUDA provider cannot be initialised
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
|
|
|
# Load the exported OpenLander segmentation model from the working directory
session = ort.InferenceSession('end2end.onnx', providers=providers)
|
|
|
|
|
|
|
# BGR colors keyed by the scaled class ids:
# 0 -> red (obstacle), 122 -> black (clear, later replaced by the original pixels), 244 -> yellow (human obstacle)
colors = {0: (0, 0, 255), 122: (0, 0, 0), 244: (0, 255, 255)}
|
|
|
def load_image(uploaded_file): |
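    """Decode an uploaded file into a 3-channel BGR numpy array, or return None on failure."""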
|
try: |
|
        # Force 3-channel RGB so RGBA or greyscale uploads do not break the BGR conversion
        image = Image.open(uploaded_file).convert("RGB")

        return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
|
except Exception as e: |
|
st.write("Could not load image: ", e) |
|
return None |
|
|
|
|
|
st.title("OpenLander ONNX app") |
|
st.write("Upload an image to process with the ONNX OpenLander model!") |
|
st.write("Bear in mind that this model is **much less refined** than the embedded models at the moment.") |
|
|
|
|
|
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"]) |
|
if uploaded_file is not None: |
|
    img = load_image(uploaded_file)

    if img is None:

        # Halt this Streamlit run cleanly if the upload could not be decoded
        st.stop()
|
img_processed = None |
|
|
|
if st.button('Process'): |
|
with st.spinner('Processing...'): |
|
start = time.time() |
|
if mosaic: |
|
tiles, rows, cols, padding_height, padding_width = mosaic_crop(img, 416) |
|
processed_tiles = [process_image(session, tile, colors, mosaic=True) for tile in tiles] |
|
overlay = stitch_tiles(processed_tiles, rows, cols, 416) |
|
|
|
|
|
overlay = overlay[:overlay.shape[0]-padding_height, :overlay.shape[1]-padding_width] |
|
img_processed = overlay |
|
else: |
|
img_processed = process_image(session, img, colors) |
|
end = time.time() |
|
            st.write(f"Processing time: {end - start:.2f} seconds")
|
|
|
st.image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), caption='Uploaded Image.', use_column_width=True) |
|
|
|
if img_processed is not None: |
|
st.image(cv2.cvtColor(img_processed, cv2.COLOR_BGR2RGB), caption='Processed Image.', use_column_width=True) |
|
        st.write("Red => obstacle ||| Yellow => human obstacle ||| No color => clear for landing or delivery")
|
|