File size: 5,457 Bytes
8afd9ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import cv2
import numpy as np
import os
import torch
import onnxruntime as ort
import time
from functools import wraps
import argparse
from PIL import Image
from io import BytesIO
import streamlit as st
# Command-line parsing, kept here (commented out) for anyone running this
# script outside Streamlit:
#parser = argparse.ArgumentParser()
#parser.add_argument("--mosaic", help="Enable mosaic processing mode", action="store_true")
#args = parser.parse_args()
#mosaic = args.mosaic # Set this based on your command line argument
# For Streamlit use, mosaic is simply fixed to True; the command-line argument
# above is left in place for anyone who wants to use it.
# When True, the UI code below splits the upload into 416x416 tiles and
# processes each tile; when False, process_image() center-crops to 416x416.
mosaic = True
def center_crop(img, new_height, new_width):
    """Return the central ``new_height`` x ``new_width`` window of ``img``.

    Args:
        img: Array whose first two axes are (height, width). Any trailing
            axes (for example colour channels) are kept unchanged.
        new_height: Requested crop height in pixels.
        new_width: Requested crop width in pixels.

    Returns:
        A slice (view) of ``img``. Along any axis where the image is smaller
        than the requested size, the full extent of that axis is returned.
    """
    height, width = img.shape[:2]
    # Clamp at 0: a negative start would be interpreted by NumPy as an offset
    # from the end of the axis and silently yield a tiny or empty crop.
    start_y = max(0, height // 2 - new_height // 2)
    start_x = max(0, width // 2 - new_width // 2)
    return img[start_y:start_y + new_height, start_x:start_x + new_width]
def mosaic_crop(img, size):
    """Zero-pad ``img`` to a multiple of ``size`` and cut it into square tiles.

    Padding is added only on the bottom and right edges, so the original
    pixels keep their coordinates in the padded image.

    Args:
        img: Array whose first two axes are (height, width). Trailing axes
            (for example colour channels) are preserved and never padded.
        size: Side length of each square tile, in pixels. Must be positive.

    Returns:
        A tuple ``(tiles, rows, cols, padding_height, padding_width)`` where
        ``tiles`` is a list of ``size`` x ``size`` views in row-major order
        (left to right, then top to bottom), ``rows`` and ``cols`` are the
        tile counts along each axis, and the two padding values are the
        number of zero rows / columns that were appended.

    Raises:
        ValueError: If ``size`` is not positive.
    """
    if size <= 0:
        raise ValueError(f"size must be a positive integer, got {size}")

    height, width = img.shape[:2]
    # Amount needed to reach the next multiple of `size` (0 if already aligned).
    padding_height = (size - height % size) % size
    padding_width = (size - width % size) % size

    # Pad only the two spatial axes; leave any channel axes untouched.
    pad_spec = [(0, padding_height), (0, padding_width)] + [(0, 0)] * (img.ndim - 2)
    padded_img = np.pad(img, pad_spec, mode="constant", constant_values=0)

    padded_height, padded_width = padded_img.shape[:2]
    tiles = [
        padded_img[top:top + size, left:left + size]
        for top in range(0, padded_height, size)
        for left in range(0, padded_width, size)
    ]
    return tiles, padded_height // size, padded_width // size, padding_height, padding_width
def stitch_tiles(tiles, rows, cols, size):
    """Reassemble a row-major list of tiles into a single image.

    Args:
        tiles: Sequence of arrays; tile ``(r, c)`` is at index ``r * cols + c``.
        rows: Number of tile rows.
        cols: Number of tile columns.
        size: Tile side length. Accepted for interface compatibility; the
            function does not read it.

    Returns:
        One array built by joining each row of tiles along axis 1 and then
        stacking the resulting strips along axis 0.
    """
    strips = []
    for row in range(rows):
        first = row * cols
        strip = np.concatenate([tiles[first + col] for col in range(cols)], axis=1)
        strips.append(strip)
    return np.concatenate(strips, axis=0)
def timing_decorator(func):
    """Wrap ``func`` so that each call prints how long it took.

    Args:
        func: The callable to time.

    Returns:
        A wrapper with the same signature and metadata as ``func`` that
        forwards all arguments, prints the elapsed time in seconds to stdout
        and returns ``func``'s result. If ``func`` raises, the exception
        propagates and nothing is printed.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed by system clock adjustments.
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        duration = time.perf_counter() - start_time
        print(f"Function '{func.__name__}' took {duration:.6f} seconds")
        return result
    return wrapper
@timing_decorator
def process_image(session, img, colors, mosaic=False):
    """Run the model on one image and return it blended with a class overlay.

    Args:
        session: Inference session; only ``get_inputs()`` and ``run()`` are
            used here.
        img: Image array. It is blended with a 416x416x3 uint8 mask, so after
            the optional crop it must have that same shape.
        colors: Mapping from mask value to a 3-component colour.
        mosaic: When False the image is first center-cropped to 416x416;
            when True it is used as given.

    Returns:
        The input blended 60/40 with the colour mask. Pixels whose mask value
        is 122 receive the original image pixel as their "mask" colour, so
        they come out essentially unchanged.
    """
    side = 416
    # 122 is both the scale applied to the raw model output and the mask
    # value that is treated as "leave the original pixel visible".
    passthrough_value = 122

    frame = img if mosaic else center_crop(img, side, side)

    # Scale pixels to [0, 1], resize to the network input size and swap the
    # first and third channels.
    net_input = cv2.dnn.blobFromImage(frame, 1/255.0, (side, side), swapRB=True, crop=False)

    input_name = session.get_inputs()[0].name
    prediction = session.run(None, {input_name: net_input})[0]

    # Drop the leading axis and move the channel axis last.
    # NOTE(review): presumably a (1, C, H, W) class/probability map - confirm
    # against the exported model.
    channels_last = prediction.squeeze(0).transpose(1, 2, 0)
    scaled = channels_last * passthrough_value
    scaled = scaled.clip(0, 255).astype(np.uint8)
    class_map = scaled.max(axis=2)

    # Paint every mask value that has a configured colour.
    tinted = np.zeros((side, side, 3), dtype=np.uint8)
    for value in np.unique(class_map):
        if value not in colors:
            continue
        tinted[class_map == value] = colors[value]

    # Where the mask equals the pass-through value, copy the source pixels
    # into the overlay so blending leaves them looking like the input.
    passthrough = class_map == passthrough_value
    passthrough = np.stack((passthrough, passthrough, passthrough), axis=-1)
    tinted[passthrough] = frame[passthrough]

    return cv2.addWeighted(frame, 0.6, tinted, 0.4, 0)
# GPU use is auto-detected: PyTorch is asked whether a CUDA device is usable.
cuda = torch.cuda.is_available()
if cuda:
    print("We have a GPU!")
# PyTorch seeing a GPU does not mean the installed onnxruntime build ships the
# CUDA provider, so only request providers onnxruntime actually reports, and
# always keep the CPU provider as the fallback.
preferred_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
available_providers = ort.get_available_providers()
providers = [p for p in preferred_providers if p in available_providers] or ['CPUExecutionProvider']
session = ort.InferenceSession('end2end.onnx', providers=providers)
# Colors for mask values 0, 122 and 244, written in BGR channel order (the
# images handled here are BGR until converted for display).
colors = {0: (0, 0, 255), 122: (0, 0, 0), 244: (0, 255, 255)}  # Red, Black, Yellow
def load_image(uploaded_file):
    """Decode an uploaded image into a 3-channel BGR array.

    Args:
        uploaded_file: A path or file-like object accepted by ``Image.open``.

    Returns:
        An ``(H, W, 3)`` BGR array, or ``None`` if the file could not be
        decoded (in which case the error is written to the Streamlit page).
    """
    try:
        image = Image.open(uploaded_file)
        # Normalise every PIL mode (grayscale, palette, RGBA, CMYK, ...) to
        # plain RGB first. Without this, 2-D grayscale/palette arrays make
        # the colour conversion below fail and 4-channel CMYK data would be
        # misread as RGBA.
        rgb = np.array(image.convert("RGB"))
        return cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    except Exception as e:
        # UI boundary: report the problem to the user instead of crashing.
        st.write("Could not load image: ", e)
        return None
# ---- Streamlit page: upload an image, run the model, show the result ----
st.title("OpenLander ONNX app")
st.write("Upload an image to process with the ONNX OpenLander model!")
st.write("Bear in mind that this model is **much less refined** than the embedded models at the moment.")
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png"])

# load_image() returns None (after reporting the error on the page) when the
# upload cannot be decoded, so everything below is skipped in that case
# instead of failing on `img.shape`.
img = load_image(uploaded_file) if uploaded_file is not None else None

if img is not None:
    if img.ndim == 3 and img.shape[2] == 4:
        img = img[:, :, :3]  # Drop the alpha channel if it exists
    img_processed = None

    if st.button('Process'):
        with st.spinner('Processing...'):
            start = time.perf_counter()
            if mosaic:
                # Tile the (zero-padded) image, run the model per tile, then
                # reassemble the tiles in their original positions.
                tiles, rows, cols, padding_height, padding_width = mosaic_crop(img, 416)
                processed_tiles = [process_image(session, tile, colors, mosaic=True) for tile in tiles]
                overlay = stitch_tiles(processed_tiles, rows, cols, 416)
                # Crop the padding back out
                overlay = overlay[:overlay.shape[0]-padding_height, :overlay.shape[1]-padding_width]
                img_processed = overlay
            else:
                img_processed = process_image(session, img, colors)
            end = time.perf_counter()
            st.write(f"Processing time: {end - start} seconds")

    # Images are kept in BGR internally; convert to RGB for display.
    st.image(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), caption='Uploaded Image.', use_column_width=True)

    if img_processed is not None:
        st.image(cv2.cvtColor(img_processed, cv2.COLOR_BGR2RGB), caption='Processed Image.', use_column_width=True)
        st.write("Red => obstacle ||| Yellow => Human obstacle ||| no color => clear for landing or delivery ")
|