ammariii08 commited on
Commit
77e0dca
·
verified ·
1 Parent(s): c46b861

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +147 -29
app.py CHANGED
@@ -1,35 +1,153 @@
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
- from stitching import Stitcher
 
5
  from PIL import Image
 
 
 
 
6
 
7
- # Function to stitch uploaded images
8
- def stitch_images(images):
9
- stitcher = Stitcher()
10
- panorama = stitcher.stitch(images) # Stitch the images
11
-
12
- # Convert the result to a PIL image
13
- pil_image = Image.fromarray(cv2.cvtColor(panorama, cv2.COLOR_BGR2RGB))
14
-
15
- # Save the image in JPG format
16
- pil_image.save("stitched_image.jpg", "JPEG")
17
-
18
- # Return the stitched image
19
- return pil_image, "stitched_image.jpg" # Return the image and the file path
20
-
21
- # Create Gradio interface with button
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  with gr.Blocks() as interface:
23
- gr.Markdown("<h1 style='color: #2196F3; text-align: center;'>Image Stitcher 🧵</h1>") # Heading with a new color
24
- gr.Markdown("<h3 style='color: #2196F3; text-align: center;'>=== Upload the images you want to stitch ===</h3>") # Subheading
25
-
26
- image_upload = gr.Files(type="filepath", label="Upload Images") # File input for images
27
- stitch_button = gr.Button("Stitch", variant="primary") # Button with color
28
- stitched_image = gr.Image(type="pil", label="Stitched Image") # Display stitched image
29
- download_button = gr.File(label="Download Stitched Image") # Button to download the image
30
-
31
- # Define button interaction
32
- stitch_button.click(stitch_images, inputs=image_upload, outputs=[stitched_image, download_button])
33
-
34
- # Launch the interface
35
- interface.launch()
 
1
  import gradio as gr
2
  import cv2
3
  import numpy as np
4
+ from typing import Union, List
5
+ from pathlib import Path
6
  from PIL import Image
7
+ import torch
8
+ from ultralytics.utils.plotting import save_one_box
9
+ from ultralytics.engine.results import Results
10
+ from ultralytics import YOLOWorld
11
 
12
+ # Function to resize images
13
+
14
+ def resize_images(images, scale_percent=50):
15
+ resized_images = []
16
+ for img in images:
17
+ width = int(img.shape[1] * scale_percent / 100)
18
+ height = int(img.shape[0] * scale_percent / 100)
19
+ dim = (width, height)
20
+ resized = cv2.resize(img, dim, interpolation=cv2.INTER_AREA)
21
+ resized_images.append(resized)
22
+ return resized_images
23
+
24
+ # Function to stitch images
25
+
26
+ def stitch_images(image_paths, scale_percent=50):
27
+ images = [cv2.imread(path) for path in image_paths]
28
+ resized_images = resize_images(images, scale_percent)
29
+ stitcher = cv2.Stitcher_create()
30
+ status, stitched_image = stitcher.stitch(resized_images)
31
+
32
+ if status == cv2.Stitcher_OK:
33
+ print("Stitching successful!")
34
+ return stitched_image
35
+ else:
36
+ print(f"Stitching failed with status code: {status}")
37
+ return None
38
+
39
+ # YOLO detection function
40
+
41
+ def yolo_detect(
42
+ image: Union[str, Path, int, Image.Image, list, tuple, np.ndarray, torch.Tensor],
43
+ classes: List[str],
44
+ ) -> np.ndarray:
45
+ detector = YOLOWorld("yolov8x-worldv2.pt")
46
+ detector.set_classes(classes)
47
+ results: List[Results] = detector.predict(image)
48
+
49
+ boxes = []
50
+ for result in results:
51
+ boxes.append(
52
+ save_one_box(result.cpu().boxes.xyxy, im=result.orig_img, save=False)
53
+ )
54
+
55
+ del detector
56
+ return boxes[0]
57
+
58
+ # Function to predict and process image
59
+
60
+ def predict(image, offset_inches):
61
+ try:
62
+ drawer_img = yolo_detect(image, ["box"])
63
+ shrunked_img = make_square(shrink_bbox(drawer_img, 0.8))
64
+ return shrunked_img
65
+ except:
66
+ raise Exception("Unable to DETECT DRAWER, please take another picture with different magnification level!")
67
+
68
+ # Function to shrink bounding box
69
+
70
+ def shrink_bbox(image: np.ndarray, shrink_factor: float):
71
+ height, width = image.shape[:2]
72
+ center_x, center_y = width // 2, height // 2
73
+
74
+ new_width = int(width * shrink_factor)
75
+ new_height = int(height * shrink_factor)
76
+
77
+ x1 = max(center_x - new_width // 2, 0)
78
+ y1 = max(center_y - new_height // 2, 0)
79
+ x2 = min(center_x + new_width // 2, width)
80
+ y2 = min(center_y + new_height // 2, height)
81
+
82
+ cropped_image = image[y1:y2, x1:x2]
83
+ return cropped_image
84
+
85
+ # Function to make image square
86
+
87
+ def make_square(img: np.ndarray):
88
+ height, width = img.shape[:2]
89
+ max_dim = max(height, width)
90
+ pad_height = (max_dim - height) // 2
91
+ pad_width = (max_dim - width) // 2
92
+ pad_height_extra = max_dim - height - 2 * pad_height
93
+ pad_width_extra = max_dim - width - 2 * pad_width
94
+
95
+ if len(img.shape) == 3:
96
+ padded = np.pad(
97
+ img,
98
+ ((pad_height, pad_height + pad_height_extra),
99
+ (pad_width, pad_width + pad_width_extra),
100
+ (0, 0)),
101
+ mode="edge"
102
+ )
103
+ else:
104
+ padded = np.pad(
105
+ img,
106
+ ((pad_height, pad_height + pad_height_extra),
107
+ (pad_width, pad_width + pad_width_extra)),
108
+ mode="edge"
109
+ )
110
+
111
+ return padded
112
+
113
+ # Main image processing function
114
+
115
+ def process_image(image_paths, scale_percent=50, offset_inches=1):
116
+ stitched_image = stitch_images(image_paths, scale_percent)
117
+
118
+ if stitched_image is not None:
119
+ try:
120
+ stitched_image_rgb = cv2.cvtColor(stitched_image, cv2.COLOR_BGR2RGB)
121
+ final_image = predict(stitched_image_rgb, offset_inches)
122
+ return final_image
123
+ except Exception as e:
124
+ print(str(e))
125
+ return stitched_image
126
+
127
+ # Gradio interface function
128
+
129
+ def gradio_stitch_and_detect(image_files):
130
+ image_paths = [file.name for file in image_files]
131
+ result_image = process_image(image_paths, scale_percent=50)
132
+
133
+ if result_image is not None:
134
+ result_image_rgb = cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB)
135
+ pil_image = Image.fromarray(result_image_rgb)
136
+ pil_image.save("stitched_image.jpg", "JPEG")
137
+ return pil_image, "stitched_image.jpg"
138
+
139
+ return None, None
140
+
141
+ # Gradio interface
142
  with gr.Blocks() as interface:
143
+ gr.Markdown("<h1 style='color: #2196F3; text-align: center;'>Image Stitcher 🧵</h1>")
144
+ gr.Markdown("<h3 style='color: #2196F3; text-align: center;'>=== Upload the images you want to stitch ===</h3>")
145
+
146
+ image_upload = gr.Files(type="filepath", label="Upload Images")
147
+ stitch_button = gr.Button("Stitch", variant="primary")
148
+ stitched_image = gr.Image(type="pil", label="Stitched Image")
149
+ download_button = gr.File(label="Download Stitched Image")
150
+
151
+ stitch_button.click(gradio_stitch_and_detect, inputs=image_upload, outputs=[stitched_image, download_button])
152
+
153
+ interface.launch()