File size: 4,934 Bytes
8e8e66f
 
 
77e0dca
 
8e8e66f
77e0dca
 
 
 
8e8e66f
77e0dca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e8e66f
77e0dca
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
import cv2
import numpy as np
from typing import Union, List
from pathlib import Path
from PIL import Image
import torch
from ultralytics.utils.plotting import save_one_box
from ultralytics.engine.results import Results
from ultralytics import YOLOWorld

# Function to resize images

def resize_images(images, scale_percent=50):
    """Return copies of *images* scaled to ``scale_percent`` of their size.

    Args:
        images: Iterable of image arrays whose first two axes are
            (height, width), as read from ``.shape``.
        scale_percent: Target size as a percentage of the original
            width and height. Must be greater than zero.

    Returns:
        A list with one resized image per input image, in input order.

    Raises:
        ValueError: If ``scale_percent`` is not positive.
    """
    if scale_percent <= 0:
        raise ValueError(
            f"scale_percent must be greater than 0, got {scale_percent!r}"
        )

    resized_images = []
    for img in images:
        src_height, src_width = img.shape[:2]
        # Truncation can yield 0 for tiny images or small percentages, and
        # cv2.resize rejects an empty target size, so keep at least 1 pixel.
        width = max(1, int(src_width * scale_percent / 100))
        height = max(1, int(src_height * scale_percent / 100))
        # cv2.resize expects the target size as (width, height).
        resized_images.append(
            cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
        )
    return resized_images

# Function to stitch images

def stitch_images(image_paths, scale_percent=50):
    """Load the images at *image_paths*, downscale them and stitch them.

    Args:
        image_paths: Sequence of image file paths.
        scale_percent: Percentage passed to ``resize_images`` before
            stitching.

    Returns:
        The stitched image on success, or ``None`` when there is nothing
        to stitch, a file cannot be read, or the stitcher reports a
        non-OK status. The reason is printed in every failure case.
    """
    if not image_paths:
        print("Stitching failed: no input images were provided.")
        return None

    images = []
    for path in image_paths:
        # cv2.imread reports an unreadable file by returning None rather
        # than raising, so check explicitly before using the result.
        image = cv2.imread(str(path))
        if image is None:
            print(f"Stitching failed: could not read image: {path}")
            return None
        images.append(image)

    resized_images = resize_images(images, scale_percent)
    stitcher = cv2.Stitcher_create()
    status, stitched_image = stitcher.stitch(resized_images)

    if status != cv2.Stitcher_OK:
        print(f"Stitching failed with status code: {status}")
        return None

    print("Stitching successful!")
    return stitched_image

# YOLO detection function

def yolo_detect(
    image: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor],
    classes: List[str],
) -> np.ndarray:
    """Run YOLO-World on *image* and return the crop for the first result.

    A new ``YOLOWorld`` model is built from ``yolov8x-worldv2.pt`` on every
    call, configured with *classes* via ``set_classes`` and released again
    before returning.

    Args:
        image: Input forwarded unchanged to the model's ``predict``.
        classes: Class names handed to ``set_classes``.

    Returns:
        The value ``save_one_box`` produces for the first prediction result
        (presumably the cropped detection region; confirm against the
        Ultralytics documentation).

    Raises:
        IndexError: If ``predict`` returned no results.
    """
    model = YOLOWorld("yolov8x-worldv2.pt")
    model.set_classes(classes)
    predictions: List[Results] = model.predict(image)

    # One crop per result; save=False keeps save_one_box from writing a file.
    crops = [
        save_one_box(prediction.cpu().boxes.xyxy, im=prediction.orig_img, save=False)
        for prediction in predictions
    ]

    del model
    return crops[0]

# Function to predict and process image

def predict(image, offset_inches):
    """Detect the drawer in *image* and return a trimmed, square crop of it.

    The image is passed to ``yolo_detect`` with the single class ``"box"``;
    the resulting crop is shrunk to 80% around its centre by ``shrink_bbox``
    and then padded to a square by ``make_square``.

    Args:
        image: Image forwarded unchanged to ``yolo_detect``.
        offset_inches: Currently unused; kept so existing callers keep
            working.

    Returns:
        The square crop of the detected drawer.

    Raises:
        RuntimeError: If detection or post-processing fails for any reason.
            The original error is attached as ``__cause__``.
    """
    try:
        drawer_img = yolo_detect(image, ["box"])
        return make_square(shrink_bbox(drawer_img, 0.8))
    except Exception as err:
        # Catch Exception rather than everything so KeyboardInterrupt and
        # SystemExit still propagate, and chain the cause for debugging.
        raise RuntimeError(
            "Unable to DETECT DRAWER, please take another picture with different magnification level!"
        ) from err

# Function to shrink bounding box

def shrink_bbox(image: np.ndarray, shrink_factor: float):
    """Return a centred crop of *image* scaled by *shrink_factor*.

    Args:
        image: Array whose first two axes are (height, width).
        shrink_factor: Fraction of the width and height to keep. The crop
            window is clamped to the image bounds, so values above 1 keep
            the whole image.

    Returns:
        The slice of *image* covering the centred window.
    """
    rows, cols = image.shape[:2]
    mid_col, mid_row = cols // 2, rows // 2

    # Half-extents of the crop window, measured from the centre pixel.
    half_w = int(cols * shrink_factor) // 2
    half_h = int(rows * shrink_factor) // 2

    left = max(mid_col - half_w, 0)
    top = max(mid_row - half_h, 0)
    right = min(mid_col + half_w, cols)
    bottom = min(mid_row + half_h, rows)

    return image[top:bottom, left:right]

# Function to make image square

def make_square(img: np.ndarray):
    """Pad *img* to a square by repeating its edge pixels.

    The shorter of the two leading axes is padded up to the length of the
    longer one. Padding is split between both sides; when the amount is
    odd, the extra row or column goes to the bottom or right.

    Args:
        img: Array whose first two axes are (height, width). A third axis,
            if present, is left unpadded.

    Returns:
        The padded array, with equal height and width.
    """
    rows, cols = img.shape[:2]
    side = max(rows, cols)

    missing_rows = side - rows
    missing_cols = side - cols
    top = missing_rows // 2
    left = missing_cols // 2

    # (before, after) padding per axis; the remainder lands on the far side.
    pad_spec = [(top, missing_rows - top), (left, missing_cols - left)]
    if len(img.shape) == 3:
        # Never pad the third axis.
        pad_spec.append((0, 0))

    return np.pad(img, pad_spec, mode="edge")

# Main image processing function

def process_image(image_paths, scale_percent=50, offset_inches=1):
    """Stitch the given images and try to crop the result to the drawer.

    Args:
        image_paths: Image file paths forwarded to ``stitch_images``.
        scale_percent: Downscale percentage forwarded to ``stitch_images``.
        offset_inches: Forwarded to ``predict``.

    Returns:
        The output of ``predict`` when detection succeeds, the plain
        stitched image when ``predict`` raises, or ``None`` when stitching
        itself failed.
    """
    panorama = stitch_images(image_paths, scale_percent)
    if panorama is None:
        return None

    try:
        # NOTE(review): the detector is given an RGB conversion while the
        # fallback below returns the unconverted panorama; confirm that
        # callers expect the same channel order from both paths.
        return predict(cv2.cvtColor(panorama, cv2.COLOR_BGR2RGB), offset_inches)
    except Exception as exc:
        # Best effort: report the problem and fall back to the panorama.
        print(str(exc))
        return panorama

# Gradio interface function

def gradio_stitch_and_detect(image_files):
    """Gradio callback: stitch the uploaded images and return the result.

    Args:
        image_files: Uploaded files as delivered by the upload component.
            Each entry may be a path string, a ``Path``, or an object
            exposing its path as ``.name``. ``None`` or an empty list
            means nothing was uploaded.

    Returns:
        A ``(pil_image, saved_path)`` tuple on success, where the image has
        also been written to ``stitched_image.jpg``; ``(None, None)`` when
        there is no input or processing produced no image.
    """
    # The callback can fire before anything has been uploaded; treat that
    # as "no result" instead of failing while iterating over None.
    if not image_files:
        return None, None

    image_paths = [
        str(file) if isinstance(file, (str, Path)) else file.name
        for file in image_files
    ]
    result_image = process_image(image_paths, scale_percent=50)
    if result_image is None:
        return None, None

    result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(result_image_rgb)
    # NOTE(review): the output name is fixed, so concurrent requests would
    # overwrite each other's file; confirm whether that matters here.
    pil_image.save("stitched_image.jpg", "JPEG")
    return pil_image, "stitched_image.jpg"

# Gradio interface
# Build the page: two header lines, a multi-file upload, a trigger button,
# an image preview and a file output for downloading the result.
with gr.Blocks() as interface:
    gr.Markdown("<h1 style='color: #2196F3; text-align: center;'>Image Stitcher 🧵</h1>")
    gr.Markdown("<h3 style='color: #2196F3; text-align: center;'>=== Upload the images you want to stitch ===</h3>")

    # Input: the uploaded files are handed to the callback on click.
    image_upload = gr.Files(type="filepath", label="Upload Images")
    stitch_button = gr.Button("Stitch", variant="primary")
    # Outputs: filled from the two values gradio_stitch_and_detect returns.
    stitched_image = gr.Image(type="pil", label="Stitched Image")
    download_button = gr.File(label="Download Stitched Image")

    # Clicking the button runs the callback with the uploads as its input.
    stitch_button.click(gradio_stitch_and_detect, inputs=image_upload, outputs=[stitched_image, download_button])

# Starts the app as soon as this module is executed.
interface.launch()