Image_Stitcher / app.py
ammariii08's picture
Update app.py
77e0dca verified
raw
history blame
4.93 kB
import gradio as gr
import cv2
import numpy as np
from typing import Union, List
from pathlib import Path
from PIL import Image
import torch
from ultralytics.utils.plotting import save_one_box
from ultralytics.engine.results import Results
from ultralytics import YOLOWorld
# Function to resize images
def resize_images(images, scale_percent=50):
    """Return copies of *images* scaled by *scale_percent* percent per side.

    Args:
        images: Iterable of images accepted by ``cv2.resize`` whose first two
            axes are height and width.
        scale_percent: Target size as a percentage of the original; 50 halves
            each side.

    Returns:
        A list of resized images, in the same order as *images*.
    """
    def _scaled(length):
        # Never go below one pixel: int() truncation can yield 0 for tiny
        # images or small scales, and cv2.resize rejects a zero-sized dsize.
        return max(1, int(length * scale_percent / 100))

    # NOTE(review): INTER_AREA suits downscaling; for scale_percent > 100 a
    # different interpolation may look better.
    return [
        cv2.resize(
            img,
            (_scaled(img.shape[1]), _scaled(img.shape[0])),
            interpolation=cv2.INTER_AREA,
        )
        for img in images
    ]
# Function to stitch images
def stitch_images(image_paths, scale_percent=50):
    """Read, downscale and stitch the images at *image_paths* into one image.

    Args:
        image_paths: Sequence of file paths passed to ``cv2.imread``.
        scale_percent: Resize percentage applied to every image before
            stitching (see ``resize_images``).

    Returns:
        The stitched image produced by ``cv2.Stitcher`` (in the channel order
        of ``cv2.imread``, i.e. BGR), or ``None`` if a file could not be read
        or stitching did not succeed.
    """
    images = []
    for path in image_paths:
        img = cv2.imread(path)
        # cv2.imread signals an unreadable / non-image file by returning None
        # instead of raising; stop here rather than crash later in resizing.
        if img is None:
            print(f"Could not read image: {path}")
            return None
        images.append(img)

    resized_images = resize_images(images, scale_percent)
    stitcher = cv2.Stitcher_create()
    status, stitched_image = stitcher.stitch(resized_images)
    if status != cv2.Stitcher_OK:
        print(f"Stitching failed with status code: {status}")
        return None
    print("Stitching successful!")
    return stitched_image
# YOLO detection function
def yolo_detect(
    image: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor],
    classes: List[str],
) -> np.ndarray:
    """Run YOLO-World on *image* for *classes* and return the detected crop.

    Args:
        image: Any source accepted by ``YOLOWorld.predict``.
        classes: Class names set on the detector via ``set_classes``.

    Returns:
        The crop produced by ``save_one_box(..., save=False)`` from the boxes of
        the first result. NOTE(review): ``save_one_box`` appears to crop only the
        first box row, and with its default ``BGR=False`` it reverses the
        channel order relative to ``orig_img``; confirm against the installed
        ultralytics version.

    Raises:
        ValueError: If the first result contains no detections.
    """
    # Weights are loaded on every call and released afterwards, so no model
    # stays resident between requests (at the cost of reload time per call).
    detector = YOLOWorld("yolov8x-worldv2.pt")
    try:
        detector.set_classes(classes)
        results: List[Results] = detector.predict(image)
    finally:
        del detector

    # Only the first result is ever returned, so validate just that one and
    # fail with a clear message instead of an opaque IndexError.
    if not results or len(results[0].boxes) == 0:
        raise ValueError(f"No objects of classes {classes} detected")
    first = results[0].cpu()
    return save_one_box(first.boxes.xyxy, im=first.orig_img, save=False)
# Function to predict and process image
def predict(image, offset_inches):
    """Detect the drawer in *image* and return a squared, centre-cropped view.

    Args:
        image: Image forwarded to ``yolo_detect`` (``process_image`` passes an
            RGB numpy array).
        offset_inches: Currently unused; kept so existing callers still work.

    Returns:
        The detected crop shrunk to its central 80% and edge-padded to a square.

    Raises:
        Exception: If detection or post-processing fails; the underlying error
            is attached as ``__cause__``.
    """
    try:
        drawer_img = yolo_detect(image, ["box"])
        shrunk_img = make_square(shrink_bbox(drawer_img, 0.8))
    except Exception as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit still
        # propagate; chaining keeps the real failure visible for debugging.
        raise Exception(
            "Unable to DETECT DRAWER, please take another picture with different magnification level!"
        ) from err
    return shrunk_img
# Function to shrink bounding box
def shrink_bbox(image: np.ndarray, shrink_factor: float) -> np.ndarray:
    """Return the centred crop of *image* keeping *shrink_factor* of each side.

    Args:
        image: Array whose first two axes are height and width; trailing axes
            (e.g. channels) are kept whole.
        shrink_factor: Fraction of each side to keep, e.g. 0.8 keeps the middle
            80%. Values above 1 are clamped to the full image.

    Returns:
        A view into *image* of size ``int(h * f) x int(w * f)``, clamped to at
        least one pixel (for non-empty sides) and at most the full side.
    """
    height, width = image.shape[:2]
    # Clamp to [1, side] and offset from the start so the crop is exactly
    # new_height x new_width, even for odd sizes or a 1-pixel target.
    new_width = min(width, max(1, int(width * shrink_factor)))
    new_height = min(height, max(1, int(height * shrink_factor)))
    x1 = (width - new_width) // 2
    y1 = (height - new_height) // 2
    return image[y1:y1 + new_height, x1:x1 + new_width]
# Function to make image square
def make_square(img: np.ndarray) -> np.ndarray:
    """Pad *img* to a square by repeating its edge pixels.

    The shorter of the first two axes is padded on both sides. When the
    difference is odd, the extra row or column goes to the bottom or right.
    Trailing axes (e.g. colour channels) are not padded, so 2-D, 3-D and
    higher-rank arrays are all handled.

    Args:
        img: Array whose first two axes are height and width.

    Returns:
        A new array of shape ``(s, s, *img.shape[2:])`` where ``s = max(h, w)``.
    """
    height, width = img.shape[:2]
    side = max(height, width)
    top = (side - height) // 2
    left = (side - width) // 2
    pad_spec = [(top, side - height - top), (left, side - width - left)]
    # One (0, 0) entry per trailing axis so those axes are left unchanged.
    pad_spec += [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_spec, mode="edge")
# Main image processing function
def process_image(image_paths, scale_percent=50, offset_inches=1):
    """Stitch *image_paths* and, when possible, crop the result via ``predict``.

    Returns:
        The output of ``predict`` on success; the uncropped stitched image if
        colour conversion or prediction raises (the error is printed); or
        ``None`` if stitching failed.
    """
    panorama = stitch_images(image_paths, scale_percent)
    if panorama is None:
        return None
    try:
        # The stitched image is converted from BGR to RGB before prediction.
        return predict(cv2.cvtColor(panorama, cv2.COLOR_BGR2RGB), offset_inches)
    except Exception as exc:
        # Best effort: report the failure and fall back to the full stitched
        # image instead of failing the whole request.
        print(str(exc))
        return panorama
# Gradio interface function
def gradio_stitch_and_detect(image_files):
    """Gradio callback: stitch the uploaded images and save the result as JPEG.

    Args:
        image_files: Value of the ``gr.Files`` input, i.e. a list of uploads, or
            ``None`` when nothing was uploaded. Entries may be path strings or
            objects exposing the path as ``.name`` (varies by Gradio version).

    Returns:
        ``(pil_image, "stitched_image.jpg")`` on success, else ``(None, None)``.
    """
    # Nothing uploaded: avoid iterating None and just clear the outputs.
    if not image_files:
        return None, None
    image_paths = [getattr(f, "name", f) for f in image_files]
    result_image = process_image(image_paths, scale_percent=50)
    if result_image is None:
        return None, None
    # The result is treated as BGR (cv2 order) and converted to RGB for PIL.
    result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(result_image_rgb)
    # NOTE(review): fixed output path in the working directory is shared by
    # all sessions; concurrent requests could overwrite each other's file.
    pil_image.save("stitched_image.jpg", "JPEG")
    return pil_image, "stitched_image.jpg"
# Gradio interface: upload images, stitch them, then show and offer the
# result for download. The app is launched when this module is executed.
with gr.Blocks() as interface:
    gr.Markdown(
        "<h1 style='color: #2196F3; text-align: center;'>Image Stitcher 🧵</h1>"
    )
    gr.Markdown(
        "<h3 style='color: #2196F3; text-align: center;'>=== Upload the images you want to stitch ===</h3>"
    )
    image_upload = gr.Files(label="Upload Images", type="filepath")
    stitch_button = gr.Button("Stitch", variant="primary")
    stitched_image = gr.Image(label="Stitched Image", type="pil")
    download_button = gr.File(label="Download Stitched Image")
    # Callback returns (preview image, downloadable file path).
    stitch_button.click(
        fn=gradio_stitch_and_detect,
        inputs=[image_upload],
        outputs=[stitched_image, download_button],
    )
interface.launch()