File size: 4,793 Bytes
8166792 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
from ultralytics import YOLO
import cv2
import gradio as gr
import numpy as np
import os
import torch
from image_segmenter import ImageSegmenter
from monocular_depth_estimator import MonocularDepthEstimator
# params
# Module-level state shared by the handler functions defined below.

# Cancellation flag; starts out False.
CANCEL_PROCESSING = False
# Single ImageSegmenter instance used for every segmentation request.
# NOTE(review): model_type='n' presumably selects the smallest model variant;
# confirm in image_segmenter.
img_seg = ImageSegmenter(model_type='n')
# Single MonocularDepthEstimator instance used for every depth request.
# NOTE(review): side_by_side=False presumably means only the depth rendering is
# returned rather than an input/depth composite; confirm in
# monocular_depth_estimator.
depth_estimator = MonocularDepthEstimator(side_by_side=False)
def process_image(image):
    """Run segmentation and depth estimation on a single image.

    Args:
        image: Input image; forwarded unchanged to both models.

    Returns:
        A 2-tuple ``(segmentation, depth)``: the result of
        ``img_seg.predict(image)`` followed by the result of
        ``depth_estimator.make_prediction(image)``.
    """
    # Segmentation runs first, then depth, matching the order of the outputs.
    segmentation = img_seg.predict(image)
    depth = depth_estimator.make_prediction(image)
    return segmentation, depth
def process_video(vid_path=None):
    """Stream per-frame segmentation and depth predictions for a video.

    This is a generator. For every frame decoded from ``vid_path`` it yields
    a 2-tuple ``(segmentation, depth)`` where ``segmentation`` is
    ``img_seg.predict(frame)`` converted with ``cv2.COLOR_BGR2RGB`` and
    ``depth`` is ``depth_estimator.make_prediction(frame)``.

    Iteration stops when the capture cannot be opened, when a frame can no
    longer be read (end of the video), or when the module-level
    ``CANCEL_PROCESSING`` flag becomes truthy. The capture is released in
    every case, including when the consumer closes the generator early.

    Args:
        vid_path: Path of the video to process, handed to
            ``cv2.VideoCapture``. ``None`` (the default) yields nothing.
    """
    global CANCEL_PROCESSING

    if vid_path is None:
        # Nothing was supplied, so there is nothing to decode.
        return None

    # A new run starts un-cancelled; otherwise one cancellation would stop
    # every later run immediately.
    # NOTE(review): the flag is process-wide, so concurrent runs share it.
    CANCEL_PROCESSING = False

    vid_cap = cv2.VideoCapture(vid_path)
    try:
        while vid_cap.isOpened() and not CANCEL_PROCESSING:
            ret, frame = vid_cap.read()
            if not ret:
                # read() reports failure when no frame could be grabbed
                # (e.g. end of file). Nothing closes the capture at that
                # point, so without this break the loop would spin forever.
                break
            print("making predictions ....")
            yield (
                cv2.cvtColor(img_seg.predict(frame), cv2.COLOR_BGR2RGB),
                depth_estimator.make_prediction(frame),
            )
    finally:
        # Free the decoder/file handle on every exit path.
        vid_cap.release()
    return None
def update_segmentation_options(options):
    """Sync the segmenter's display toggles with the selected option labels.

    Each flag on ``img_seg`` is set to ``True`` when its label is present in
    ``options`` and to ``False`` otherwise.

    Args:
        options: Container of selected option labels; only membership
            (``in``) is tested.
    """
    # (label shown in the UI, attribute on img_seg) -- applied in this order.
    label_to_flag = (
        ('Show Boundary Box', 'is_show_bounding_boxes'),
        ('Show Segmentation Region', 'is_show_segmentation'),
        ('Show Segmentation Boundary', 'is_show_segmentation_boundary'),
    )
    for label, flag_name in label_to_flag:
        setattr(img_seg, flag_name, label in options)
def update_confidence_threshold(thres_val):
    """Store a percentage-style threshold on the segmenter as a fraction.

    Args:
        thres_val: Threshold value; it is divided by 100 before being
            assigned to ``img_seg.confidence_threshold``.
    """
    fraction = thres_val / 100
    img_seg.confidence_threshold = fraction
def cancel():
    """Request cancellation by setting the module-level ``CANCEL_PROCESSING``.

    The ``global`` declaration is essential: without it the assignment binds
    a function-local name and the module-level flag is never changed.
    """
    global CANCEL_PROCESSING
    CANCEL_PROCESSING = True
if __name__ == "__main__":
    # Builds the Gradio UI (an "Image" tab and a "Video" tab), wires the
    # widgets to the handler functions defined above, and launches the app.
    # NOTE(review): the layout nesting below was reconstructed from the
    # statement order; confirm it against the intended UI.

    # Checkbox labels shared by both tabs. update_segmentation_options matches
    # on these exact strings, so the two must be kept in sync.
    option_labels = (
        "Show Boundary Box",
        "Show Segmentation Region",
        "Show Segmentation Boundary",
    )
    # Identical slider configuration for the image and the video tab.
    threshold_slider = dict(
        minimum=1,
        maximum=100,
        value=60,
        label="Confidence Threshold",
        info="Choose the threshold above which objects should be detected",
    )
    base_dir = os.path.dirname(__file__)

    # gradio gui app
    with gr.Blocks() as my_app:
        # title
        # The markdown text is kept flush-left so that it renders as a heading
        # whether or not the component strips common indentation.
        gr.Markdown(
            """
# Object segmentation and depth estimation
Input an image or Video
""")

        # tabs
        with gr.Tab("Image"):
            with gr.Row():
                with gr.Column(scale=1):
                    img_input = gr.Image()
                    options_checkbox_img = gr.CheckboxGroup(list(option_labels), label="Options")
                    conf_thres_img = gr.Slider(**threshold_slider)
                    submit_btn_img = gr.Button(value="Predict")
                with gr.Column(scale=2):
                    with gr.Row():
                        segmentation_img_output = gr.Image(height=300, label="Segmentation")
                        depth_img_output = gr.Image(height=300, label="Depth Estimation")
            gr.Markdown("## Sample Images")
            gr.Examples(
                examples=[os.path.join(base_dir, "assets/images/bus.jpg")],
                inputs=img_input,
                outputs=[segmentation_img_output, depth_img_output],
                fn=process_image,
                cache_examples=True,
            )

        with gr.Tab("Video"):
            with gr.Row():
                with gr.Column(scale=1):
                    vid_input = gr.Video()
                    options_checkbox_vid = gr.CheckboxGroup(list(option_labels), label="Options")
                    conf_thres_vid = gr.Slider(**threshold_slider)
                    with gr.Row():
                        cancel_btn = gr.Button(value="Cancel")
                        submit_btn_vid = gr.Button(value="Predict")
                with gr.Column(scale=2):
                    with gr.Row():
                        segmentation_vid_output = gr.Image(height=400, label="Segmentation")
                        depth_vid_output = gr.Image(height=400, label="Depth Estimation")
            gr.Markdown("## Sample Videos")
            # No fn/outputs here: selecting the sample only fills the input.
            gr.Examples(
                examples=[os.path.join(base_dir, "assets/videos/input_video.mp4")],
                inputs=vid_input,
            )

        # image tab logic
        submit_btn_img.click(process_image, inputs=img_input, outputs=[segmentation_img_output, depth_img_output])
        options_checkbox_img.change(update_segmentation_options, options_checkbox_img, [])
        conf_thres_img.change(update_confidence_threshold, conf_thres_img, [])

        # video tab logic
        video_event = submit_btn_vid.click(process_video, inputs=vid_input, outputs=[segmentation_vid_output, depth_vid_output])
        # `cancels` asks Gradio to stop the running video event itself, so the
        # Cancel button ends the stream instead of only calling cancel().
        cancel_btn.click(cancel, inputs=[], outputs=[], cancels=[video_event])
        options_checkbox_vid.change(update_segmentation_options, options_checkbox_vid, [])
        conf_thres_vid.change(update_confidence_threshold, conf_thres_vid, [])

    # The queue must be enabled for generator handlers and for `cancels`.
    my_app.queue(concurrency_count=5, max_size=20).launch()