# OpenLanderONNX / app.py
# StephanST - first commit (8afd9ad) - 5.46 kB
import cv2
import numpy as np
import os
import torch
import onnxruntime as ort
import time
from functools import wraps
import argparse
from PIL import Image
from io import BytesIO
import streamlit as st
# Parse command-line arguments
#parser = argparse.ArgumentParser()
#parser.add_argument("--mosaic", help="Enable mosaic processing mode", action="store_true")
#args = parser.parse_args()
#mosaic = args.mosaic # Set this based on your command line argument
# For Streamlit use we simply set mosaic to True; the command-line parsing above is
# left in place (commented out) for anyone who wants to run this as a plain script.
# When True, the uploaded image is split into 416x416 tiles that are processed one
# by one and stitched back together; when False only a single crop is processed.
mosaic = True
def center_crop(img, new_height, new_width):
    """Return the centered ``new_height`` x ``new_width`` window of ``img``.

    Args:
        img: Image array whose first two axes are (height, width). Any
            trailing axes (e.g. channels) are passed through untouched.
        new_height: Requested crop height in pixels.
        new_width: Requested crop width in pixels.

    Returns:
        A slice of ``img`` covering the centered window. Along any axis
        where the image is smaller than the requested size, the full extent
        of that axis is returned instead.
    """
    height, width = img.shape[:2]
    # Clamp at 0: a negative start would be interpreted by NumPy as an index
    # counted from the end of the axis and silently produce a tiny,
    # off-center crop for images smaller than the requested window.
    start_y = max(0, height // 2 - new_height // 2)
    start_x = max(0, width // 2 - new_width // 2)
    return img[start_y:start_y + new_height, start_x:start_x + new_width]
def mosaic_crop(img, size):
    """Zero-pad ``img`` to a multiple of ``size`` and cut it into square tiles.

    Padding is added on the bottom and right edges only, so the original
    pixels keep their coordinates inside the padded image.

    Args:
        img: Three-axis image array (height, width, channels).
        size: Edge length of each square tile in pixels.

    Returns:
        A tuple ``(tiles, rows, cols, padding_height, padding_width)`` where
        ``tiles`` lists the tiles row by row (left to right, top to bottom),
        ``rows``/``cols`` give the tile grid dimensions, and the two padding
        values are the number of pixels added at the bottom and on the right.
    """
    img_h, img_w, _ = img.shape
    # Distance from each dimension up to the next multiple of ``size``
    # (0 when the dimension is already a multiple).
    pad_bottom = -img_h % size
    pad_right = -img_w % size
    padded = cv2.copyMakeBorder(
        img, 0, pad_bottom, 0, pad_right, cv2.BORDER_CONSTANT, value=[0, 0, 0]
    )
    padded_h, padded_w = padded.shape[:2]

    tiles = []
    for top in range(0, padded_h, size):
        for left in range(0, padded_w, size):
            tiles.append(padded[top:top + size, left:left + size])
    return tiles, padded_h // size, padded_w // size, pad_bottom, pad_right
def stitch_tiles(tiles, rows, cols, size):
    """Reassemble a row-major list of tiles into one image.

    Args:
        tiles: Sequence of equally sized tile arrays, ordered row by row.
        rows: Number of tile rows in the grid.
        cols: Number of tile columns in the grid.
        size: Tile edge length; accepted for interface symmetry with the
            tiling step but not used here.

    Returns:
        A single array with the tiles of each row joined along axis 1 and
        the resulting strips joined along axis 0.
    """
    strips = []
    for row in range(rows):
        first = row * cols
        strip = np.concatenate([tiles[first + col] for col in range(cols)], axis=1)
        strips.append(strip)
    return np.concatenate(strips, axis=0)
def timing_decorator(func):
    """Wrap ``func`` so that every call prints how long it took.

    Args:
        func: The callable to instrument.

    Returns:
        A wrapper with the same metadata as ``func`` (via ``functools.wraps``)
        that forwards all arguments, prints the elapsed time in seconds to
        stdout and returns ``func``'s result unchanged. Nothing is printed
        if ``func`` raises.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter() is a monotonic, high-resolution clock meant for
        # measuring intervals; time.time() follows the wall clock and can
        # jump when the system time is adjusted.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start_time
        print(f"Function '{func.__name__}' took {duration:.6f} seconds")
        return result
    return wrapper
@timing_decorator
def process_image(session, img, colors, mosaic=False):
    """Run the segmentation model on one image and return a colored overlay.

    Args:
        session: ONNX Runtime inference session; the blob is fed to its
            first input and its first output is used as the prediction.
        img: Three-channel image array, expected in OpenCV BGR order
            (``swapRB=True`` swaps the first and last channels for the model).
        colors: Mapping from mask value to the BGR color painted for it.
        mosaic: When False the image is center-cropped to 416x416 first.
            When True the image is used as given (callers pass tiles).

    Returns:
        The 416x416 input blended 60/40 with the color mask. Pixels whose
        mask value is 122 show the unmodified input.
    """
    side = 416
    if not mosaic:
        # Crop the center of the image to 416x416 pixels
        img = center_crop(img, side, side)
    if img.shape[:2] != (side, side):
        # The mask and the blend below are indexed against ``img`` itself, so
        # it must match the 416x416 network resolution. Inputs smaller than
        # the crop window would otherwise fail with a shape mismatch.
        img = cv2.resize(img, (side, side))
    blob = cv2.dnn.blobFromImage(img, 1/255.0, (side, side), swapRB=True, crop=False)
    # Perform inference
    input_name = session.get_inputs()[0].name
    output = session.run(None, {input_name: blob})
    # Drop the batch axis and move channels last.
    output_img = output[0].squeeze(0).transpose(1, 2, 0)
    # Scaling by 122 maps raw values 0/1/2 onto the 0/122/244 keys used by the
    # color table - presumably the model emits class indices; TODO confirm.
    output_img = (output_img * 122).clip(0, 255).astype(np.uint8)
    output_mask = output_img.max(axis=2)
    output_mask_color = np.zeros((side, side, 3), dtype=np.uint8)
    # Assign specific colors to the classes in the mask
    for class_idx in np.unique(output_mask):
        if class_idx in colors:
            output_mask_color[output_mask == class_idx] = colors[class_idx]
    # Mask for the transparent class
    transparent_mask = (output_mask == 122)
    # Repeat the mask across the three color channels
    transparent_mask = np.stack([transparent_mask]*3, axis=-1)
    # Where the mask is True, copy the input pixels so the blend leaves them unchanged
    output_mask_color[transparent_mask] = img[transparent_mask]
    # Make the colorful mask semi-transparent
    overlay = cv2.addWeighted(img, 0.6, output_mask_color, 0.4, 0)
    return overlay
# cuda is True when PyTorch reports a usable NVIDIA GPU
cuda = torch.cuda.is_available()
if cuda:
    print("We have a GPU!")
# Execution providers are listed in priority order. Keeping the CPU provider
# after CUDA gives ONNX Runtime something to fall back to when PyTorch sees a
# GPU but the installed onnxruntime build cannot use the CUDA provider.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
session = ort.InferenceSession('end2end.onnx', providers=providers)
# Define colors (in OpenCV BGR order) for mask values 0, 122 and 244
colors = {0: (0, 0, 255), 122: (0, 0, 0), 244: (0, 255, 255)} # Red, Black, Yellow
def load_image(uploaded_file):
    """Decode an uploaded file into a three-channel BGR array for OpenCV.

    Args:
        uploaded_file: A path or file-like object accepted by
            ``PIL.Image.open`` (here, the object returned by the Streamlit
            file uploader).

    Returns:
        The decoded image as a height x width x 3 array in BGR channel
        order, or ``None`` when the file cannot be decoded. In the failure
        case the error is shown to the user via ``st.write``.
    """
    try:
        # Normalize to RGB first so grayscale, palette and RGBA files all
        # yield three channels; COLOR_RGB2BGR rejects 2-D (single-channel)
        # arrays, which previously made valid grayscale uploads fail.
        image = Image.open(uploaded_file).convert("RGB")
        return cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    except Exception as e:
        # Deliberately broad: this is the UI boundary, and any decode failure
        # should be reported to the user rather than crash the app.
        st.write("Could not load image: ", e)
        return None
# --- Streamlit page: upload an image, run the model on demand, show results ---
st.title("OpenLander ONNX app")
st.write("Upload an image to process with the ONNX OpenLander model!")
st.write("Bear in mind that this model is **much less refined** than the embedded models at the moment.")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"])
if uploaded_file is not None:
    img = load_image(uploaded_file)
    # load_image returns None (after reporting the error on the page) when the
    # upload cannot be decoded; skip the rest instead of failing on img.shape.
    if img is not None:
        if img.shape[2] == 4:
            img = img[:, :, :3] # Drop the alpha channel if it exists
        img_processed = None
        if st.button('Process'):
            with st.spinner('Processing...'):
                start = time.time()
                if mosaic:
                    # Tile the image, run the model per tile, then reassemble.
                    tiles, rows, cols, padding_height, padding_width = mosaic_crop(img, 416)
                    processed_tiles = [process_image(session, tile, colors, mosaic=True) for tile in tiles]
                    overlay = stitch_tiles(processed_tiles, rows, cols, 416)
                    # Crop the padding back out
                    overlay = overlay[:overlay.shape[0]-padding_height, :overlay.shape[1]-padding_width]
                    img_processed = overlay
                else:
                    img_processed = process_image(session, img, colors)
                end = time.time()
                st.write(f"Processing time: {end - start} seconds")
        # OpenCV arrays are BGR; st.image is given RGB.
        st.image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), caption='Uploaded Image.', use_column_width=True)
        if img_processed is not None:
            st.image(cv2.cvtColor(img_processed, cv2.COLOR_BGR2RGB), caption='Processed Image.', use_column_width=True)
            st.write("Red => obstacle ||| Yellow => Human obstacle ||| no color => clear for landing or delivery ")