# Gradio app: stitch uploaded images with OpenCV, then detect and crop the
# drawer ("box") region using YOLO-World.
import gradio as gr
import cv2
import numpy as np
from typing import Union, List
from pathlib import Path
from PIL import Image
import torch
from ultralytics.utils.plotting import save_one_box
from ultralytics.engine.results import Results
from ultralytics import YOLOWorld
# Function to resize images
def resize_images(images, scale_percent=50):
    """Return each image scaled to *scale_percent* of its original size.

    Args:
        images: iterable of OpenCV-style numpy arrays (H x W [x C]).
        scale_percent: target size as a percentage of the original.

    Returns:
        List of resized arrays, in the same order as the input.
    """
    # cv2.resize takes (width, height) while shape is (height, width, ...).
    # INTER_AREA is the recommended interpolation for downscaling.
    return [
        cv2.resize(
            frame,
            (int(frame.shape[1] * scale_percent / 100),
             int(frame.shape[0] * scale_percent / 100)),
            interpolation=cv2.INTER_AREA,
        )
        for frame in images
    ]
# Function to stitch images
def stitch_images(image_paths, scale_percent=50):
    """Load, downscale, and stitch images into a single panorama.

    Args:
        image_paths: filesystem paths of the images to stitch, in order.
        scale_percent: downscale percentage applied before stitching.

    Returns:
        The stitched BGR image, or None on failure (unreadable input or
        stitcher error).
    """
    images = [cv2.imread(path) for path in image_paths]
    # cv2.imread signals failure by returning None instead of raising; the
    # original code passed None straight into resize_images, which crashed
    # with an opaque AttributeError. Fail fast with the same None contract.
    unreadable = [p for p, im in zip(image_paths, images) if im is None]
    if unreadable:
        print(f"Could not read image(s): {unreadable}")
        return None
    resized_images = resize_images(images, scale_percent)
    stitcher = cv2.Stitcher_create()
    status, stitched_image = stitcher.stitch(resized_images)
    if status == cv2.Stitcher_OK:
        print("Stitching successful!")
        return stitched_image
    print(f"Stitching failed with status code: {status}")
    return None
# YOLO detection function
def yolo_detect(
    image: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor],
    classes: List[str],
) -> np.ndarray:
    """Run open-vocabulary YOLO-World detection and return the first crop.

    Args:
        image: anything ultralytics `predict` accepts (path, PIL image,
            ndarray, tensor, or a batch thereof).
        classes: text labels for the open-vocabulary detector to look for.

    Returns:
        The cropped region produced by `save_one_box` for the first result.
    """
    # NOTE(review): the model weights are re-loaded from disk on every call —
    # consider caching the detector if this becomes a hot path.
    detector = YOLOWorld("yolov8x-worldv2.pt")
    detector.set_classes(classes)
    results: List[Results] = detector.predict(image)
    boxes = []
    for result in results:
        boxes.append(
            # save=False makes save_one_box return the crop instead of
            # writing it to disk.
            save_one_box(result.cpu().boxes.xyxy, im=result.orig_img, save=False)
        )
    del detector
    # Only the first result's crop is returned; raises IndexError when there
    # are no results at all (callers treat any exception as detection failure).
    return boxes[0]
# Function to predict and process image
def predict(image, offset_inches):
    """Detect the drawer ("box") in *image* and return a square, shrunk crop.

    Args:
        image: image array to run detection on.
        offset_inches: currently unused; kept for interface compatibility.

    Raises:
        Exception: when detection or post-processing fails, chained to the
            underlying error.
    """
    try:
        drawer_img = yolo_detect(image, ["box"])
        return make_square(shrink_bbox(drawer_img, 0.8))
    except Exception as exc:
        # The original bare `except:` also swallowed KeyboardInterrupt /
        # SystemExit and discarded the root cause; narrow the catch and
        # chain the original error for debuggability.
        raise Exception("Unable to DETECT DRAWER, please take another picture with different magnification level!") from exc
# Function to shrink bounding box
def shrink_bbox(image: np.ndarray, shrink_factor: float):
    """Crop *image* to a centered window scaled by *shrink_factor*.

    Args:
        image: source array, H x W (x C).
        shrink_factor: fraction of each dimension to keep (e.g. 0.8).

    Returns:
        The centered crop as a view into the original array.
    """
    h, w = image.shape[:2]
    cx, cy = w // 2, h // 2
    # Half-extents of the shrunk window around the center.
    half_w = int(w * shrink_factor) // 2
    half_h = int(h * shrink_factor) // 2
    # Clamp to the image bounds so slicing never goes out of range.
    left = max(cx - half_w, 0)
    top = max(cy - half_h, 0)
    right = min(cx + half_w, w)
    bottom = min(cy + half_h, h)
    return image[top:bottom, left:right]
# Function to make image square
def make_square(img: np.ndarray):
    """Pad *img* with edge replication to a square of side max(height, width).

    The padding is split evenly between the two sides of each axis; any odd
    remainder goes on the bottom/right. Builds the pad spec once for any
    number of trailing channel axes, replacing the duplicated 2-D/3-D
    branches (behavior is unchanged for 2-D and 3-D inputs).

    Args:
        img: source array, H x W with zero or more trailing channel axes.

    Returns:
        A new square array padded with `mode="edge"`.
    """
    height, width = img.shape[:2]
    side = max(height, width)
    pad_h = (side - height) // 2
    pad_w = (side - width) // 2
    # (before, after) per axis; `side - dim - pad` puts the odd extra pixel
    # on the bottom/right, matching the original behavior.
    pad_spec = [(pad_h, side - height - pad_h), (pad_w, side - width - pad_w)]
    pad_spec += [(0, 0)] * (img.ndim - 2)
    return np.pad(img, pad_spec, mode="edge")
# Main image processing function
def process_image(image_paths, scale_percent=50, offset_inches=1):
    """Stitch the input images and attempt to crop the detected drawer.

    Args:
        image_paths: paths of the images to stitch.
        scale_percent: downscale percentage forwarded to stitching.
        offset_inches: forwarded to predict (unused there at present).

    Returns:
        The drawer crop on success, the raw stitched image when detection
        fails, or None when stitching itself fails.
    """
    stitched_image = stitch_images(image_paths, scale_percent)
    if stitched_image is not None:
        try:
            # OpenCV stitches in BGR; the detector is fed RGB.
            stitched_image_rgb = cv2.cvtColor(stitched_image, cv2.COLOR_BGR2RGB)
            final_image = predict(stitched_image_rgb, offset_inches)
            # NOTE(review): this success path returns an RGB-derived crop while
            # the fallback below returns BGR; the Gradio caller converts
            # BGR->RGB unconditionally, so the success path may end up
            # channel-swapped — confirm downstream.
            return final_image
        except Exception as e:
            print(str(e))
    return stitched_image
# Gradio interface function
def gradio_stitch_and_detect(image_files):
    """Gradio callback: stitch uploads, return (PIL preview, download path).

    Args:
        image_files: uploaded file objects from gr.Files; `.name` is the
            local filesystem path of each temp file.

    Returns:
        (PIL.Image, path-to-saved-JPEG) on success, (None, None) on failure.
    """
    image_paths = [file.name for file in image_files]
    result_image = process_image(image_paths, scale_percent=50)
    if result_image is not None:
        # Convert OpenCV BGR to RGB for PIL. NOTE(review): process_image's
        # detection-success path may already be RGB — verify, or the preview
        # colors will be swapped.
        result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(result_image_rgb)
        # Saved to disk so the gr.File component can serve a download link.
        pil_image.save("stitched_image.jpg", "JPEG")
        return pil_image, "stitched_image.jpg"
    return None, None
# Gradio interface
with gr.Blocks() as interface:
    # Header banners.
    gr.Markdown("<h1 style='color: #2196F3; text-align: center;'>Image Stitcher 🧵</h1>")
    gr.Markdown("<h3 style='color: #2196F3; text-align: center;'>=== Upload the images you want to stitch ===</h3>")
    # Multi-file upload; type="filepath" exposes local paths to the callback.
    image_upload = gr.Files(type="filepath", label="Upload Images")
    stitch_button = gr.Button("Stitch", variant="primary")
    stitched_image = gr.Image(type="pil", label="Stitched Image")
    download_button = gr.File(label="Download Stitched Image")
    # Wire the button: uploaded files in, preview image + downloadable JPEG out.
    stitch_button.click(gradio_stitch_and_detect, inputs=image_upload, outputs=[stitched_image, download_button])
interface.launch()