import cv2 |
import numpy as np |
import os |
import torch |
import onnxruntime as ort |
import time |
from functools import wraps |
import argparse |
from PIL import Image |
from io import BytesIO |
import streamlit as st |
# Processing mode for the app below: when True, the uploaded image is
# zero-padded and split into 416x416 tiles that are each run through the model
# and stitched back together; when False, only the central 416x416 crop of the
# image is processed.
mosaic = True
def center_crop(img, new_height, new_width):
    """Return the central ``new_height`` x ``new_width`` window of ``img``.

    The window is clamped to the image bounds, so a dimension that is already
    smaller than requested is kept whole rather than producing a
    negative-index slice that wraps around to the wrong region.

    Args:
        img: Array whose first two axes are height and width; any trailing
            (channel) axes are preserved, and 2-D arrays are accepted.
        new_height: Desired output height in pixels.
        new_width: Desired output width in pixels.

    Returns:
        A view into ``img`` of shape
        ``(min(height, new_height), min(width, new_width), ...)``.
    """
    height, width = img.shape[:2]
    # Clamp at 0: a negative start would be read as an offset from the end of
    # the axis and silently return the wrong pixels.
    start_y = max(0, height // 2 - new_height // 2)
    start_x = max(0, width // 2 - new_width // 2)
    return img[start_y:start_y + new_height, start_x:start_x + new_width]
def mosaic_crop(img, size):
    """Split ``img`` into a row-major list of ``size`` x ``size`` tiles.

    The image is zero-padded on the bottom and right so that both spatial
    dimensions become multiples of ``size``; every returned tile therefore has
    exactly ``size`` rows and columns.

    Args:
        img: Array whose first two axes are height and width. Trailing
            (channel) axes are carried through unpadded, so 2-D grayscale
            arrays are accepted as well as ``(H, W, C)`` images.
        size: Tile edge length in pixels; must be positive.

    Returns:
        ``(tiles, rows, cols, padding_height, padding_width)`` where ``tiles``
        lists views into the padded image ordered row by row (tile
        ``r * cols + c`` is grid cell ``(r, c)``), ``rows``/``cols`` are the
        grid dimensions, and the paddings are the number of zero rows/columns
        appended at the bottom/right.

    Raises:
        ValueError: If ``size`` is not positive.
    """
    if size <= 0:
        raise ValueError(f"size must be positive, got {size!r}")
    height, width = img.shape[:2]
    # Amount needed to round each dimension up to the next multiple of size
    # (0 when it already divides evenly).
    padding_height = (size - height % size) % size
    padding_width = (size - width % size) % size
    # Pad only the bottom/right of the two spatial axes; leave channel axes alone.
    pad_width = [(0, padding_height), (0, padding_width)] + [(0, 0)] * (img.ndim - 2)
    padded_img = np.pad(img, pad_width, mode="constant", constant_values=0)
    rows = padded_img.shape[0] // size
    cols = padded_img.shape[1] // size
    tiles = [
        padded_img[r * size:(r + 1) * size, c * size:(c + 1) * size]
        for r in range(rows)
        for c in range(cols)
    ]
    return tiles, rows, cols, padding_height, padding_width
def stitch_tiles(tiles, rows, cols, size):
    """Reassemble a row-major sequence of tiles into a single array.

    Args:
        tiles: Indexable sequence of at least ``rows * cols`` arrays, ordered
            row by row (tile ``r * cols + c`` sits at grid cell ``(r, c)``).
        rows: Number of tile rows.
        cols: Number of tiles per row.
        size: Unused by this function; accepted for interface compatibility.

    Returns:
        The tiles of each row joined along axis 1, with the resulting row
        bands stacked along axis 0.
    """
    bands = []
    for r in range(rows):
        row_tiles = [tiles[r * cols + c] for c in range(cols)]
        bands.append(np.concatenate(row_tiles, axis=1))
    return np.concatenate(bands, axis=0)
def timing_decorator(func):
    """Wrap ``func`` so that each successful call prints its duration.

    The wrapper forwards all arguments, returns ``func``'s result unchanged,
    and copies ``func``'s metadata (name, docstring) via ``functools.wraps``.
    Nothing is printed if ``func`` raises.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() follows
        # the wall clock, which can jump, so it is unsuitable for intervals.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start_time
        print(f"Function '{func.__name__}' took {duration:.6f} seconds")
        return result
    return wrapper
@timing_decorator
def process_image(session, img, colors, mosaic=False, input_size=416):
    """Segment ``img`` with the ONNX model and return a colour-blended overlay.

    Args:
        session: ONNX Runtime inference session; the image blob is fed to its
            first input.
        img: 3-channel BGR ``uint8`` image (it is blended with a ``uint8``
            3-channel mask via ``cv2.addWeighted``).
        colors: Mapping from scaled mask value (model output * 122) to a BGR
            colour tuple. Pixels whose mask value is 122 keep the original
            image instead of a colour.
        mosaic: When False, ``img`` is first centre-cropped to
            ``input_size`` x ``input_size``; when True it is used as given
            (expected to be a tile of that size).
        input_size: Square edge length, in pixels, of the model input.

    Returns:
        ``uint8`` BGR overlay of 60% image and 40% colour mask. Assumes the
        model's output mask has the same spatial size as its input — TODO
        confirm for other models.
    """
    if not mosaic:
        img = center_crop(img, input_size, input_size)
    if img.shape[:2] != (input_size, input_size):
        # Inputs smaller than the crop come back undersized, which used to
        # crash the overlay on a shape mismatch with the fixed-size mask.
        # blobFromImage performs this same bilinear resize internally, so the
        # network input is unchanged; only the overlay base now lines up.
        img = cv2.resize(img, (input_size, input_size))
    blob = cv2.dnn.blobFromImage(img, 1/255.0, (input_size, input_size), swapRB=True, crop=False)
    output = session.run(None, {session.get_inputs()[0].name: blob})
    # output[0] has a leading batch axis of 1: (1, C, H, W) -> (H, W, C).
    output_img = output[0].squeeze(0).transpose(1, 2, 0)
    # Presumably class indices 0/1/2 scaled to 0/122/244, matching the keys
    # of `colors` — confirm against the model's output definition.
    output_img = (output_img * 122).clip(0, 255).astype(np.uint8)
    output_mask = output_img.max(axis=2)
    output_mask_color = np.zeros((*output_mask.shape, 3), dtype=np.uint8)
    for value in np.unique(output_mask):
        color = colors.get(int(value))
        if color is not None:
            output_mask_color[output_mask == value] = color
    # Mask value 122 is left uncoloured: copy the source pixels through. A 2-D
    # boolean mask selects whole pixels, so no per-channel stacking is needed.
    transparent = output_mask == 122
    output_mask_color[transparent] = img[transparent]
    return cv2.addWeighted(img, 0.6, output_mask_color, 0.4, 0)
cuda = torch.cuda.is_available()
if cuda:
    print("We have a GPU!")
# List CPU after CUDA so ONNX Runtime has an explicit fallback provider.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']


@st.cache_resource
def _load_session(model_path, session_providers):
    """Create the ONNX Runtime session once and reuse it across reruns.

    Streamlit re-executes this whole script on every widget interaction
    (upload, button click); without caching the model would be reloaded
    each time.
    """
    return ort.InferenceSession(model_path, providers=list(session_providers))


session = _load_session('end2end.onnx', tuple(providers))
# BGR colour per scaled mask value; process_image leaves value 122 uncoloured.
colors = {0: (0, 0, 255), 122: (0, 0, 0), 244: (0, 255, 255)}
def load_image(uploaded_file):
    """Decode an uploaded image file into a 3-channel BGR ``uint8`` array.

    Any PIL image mode (grayscale, palette, RGBA, CMYK, ...) is normalised to
    RGB first, so the result always has three channels.

    Args:
        uploaded_file: File-like object accepted by ``PIL.Image.open``.

    Returns:
        The decoded ``(H, W, 3)`` BGR array, or ``None`` if the file could not
        be decoded (the error is shown in the app).
    """
    try:
        with Image.open(uploaded_file) as image:
            # cv2.COLOR_RGB2BGR rejects single-channel input, so grayscale and
            # palette uploads used to fail; converting to RGB fixes that and
            # also forces the lazy decode while the file is still open.
            rgb = np.array(image.convert("RGB"))
    except Exception as e:  # UI boundary: surface any decode failure to the user.
        st.error(f"Could not load image: {e}")
        return None
    return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
st.title("OpenLander ONNX app")
st.write("Upload an image to process with the ONNX OpenLander model!")
st.write("Bear in mind that this model is **much less refined** than the embedded models at the moment.")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"])
if uploaded_file is not None:
    img = load_image(uploaded_file)
    # load_image returns None (after reporting the error) when decoding fails;
    # previously the next line then raised AttributeError on None.
    if img is not None:
        # Drop an alpha channel if present; the model and overlay expect 3 channels.
        if img.ndim == 3 and img.shape[2] == 4:
            img = img[:, :, :3]
        img_processed = None
        if st.button('Process'):
            with st.spinner('Processing...'):
                start = time.perf_counter()
                if mosaic:
                    tiles, rows, cols, _, _ = mosaic_crop(img, 416)
                    processed_tiles = [process_image(session, tile, colors, mosaic=True) for tile in tiles]
                    overlay = stitch_tiles(processed_tiles, rows, cols, 416)
                    # Trim the bottom/right zero padding added by mosaic_crop.
                    img_processed = overlay[:img.shape[0], :img.shape[1]]
                else:
                    img_processed = process_image(session, img, colors)
                end = time.perf_counter()
            st.write(f"Processing time: {end - start} seconds")
        st.image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), caption='Uploaded Image.', use_column_width=True)
        if img_processed is not None:
            st.image(cv2.cvtColor(img_processed, cv2.COLOR_BGR2RGB), caption='Processed Image.', use_column_width=True)
        st.write("Red => obstacle ||| Yellow => Human obstacle ||| no color => clear for landing or delivery ")