engajify committed (verified)
Commit b3e7ca1 · 1 Parent(s): a344061

Upload 4 files

Files changed (4)
  1. README.md +4 -3
  2. app.py +225 -0
  3. gitattributes +49 -0
  4. requirements.txt +8 -0
README.md CHANGED
@@ -1,12 +1,13 @@
 ---
-title: Image Object
+title: Specific Object Recognition In The Wild
 emoji: ⚡
-colorFrom: purple
+colorFrom: red
 colorTo: pink
 sdk: gradio
-sdk_version: 4.31.5
+sdk_version: 3.42.0
 app_file: app.py
 pinned: false
+license: openrail
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,225 @@
import gradio as gr
import torch
import numpy as np
from transformers import OwlViTProcessor, OwlViTForObjectDetection
from PIL import Image, ImageDraw
import cv2
import torch.nn.functional as F
import tempfile
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from io import BytesIO
from SuperGluePretrainedNetwork.models.matching import Matching
from SuperGluePretrainedNetwork.models.utils import read_image

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load models
model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

matching = Matching({
    'superpoint': {'nms_radius': 4, 'keypoint_threshold': 0.005, 'max_keypoints': 1024},
    'superglue': {'weights': 'outdoor', 'sinkhorn_iterations': 20, 'match_threshold': 0.2}
}).eval().to(device)

# Utility functions
def save_array_to_temp_image(arr):
    """Swap the first and last color channels (BGR <-> RGB), save the array to a temporary PNG file, and return the file path."""
    rgb_arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
    img = Image.fromarray(rgb_arr)
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
    temp_file_name = temp_file.name
    temp_file.close()
    img.save(temp_file_name)
    return temp_file_name

def unified_matching_plot2(image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, text,
                           path=None, show_keypoints=False, fast_viz=False,
                           opencv_display=False, opencv_title='matches', small_text=[]):
    """Plot the two images side by side with their matched keypoints and return the figure as an RGB array."""
    # Resize both images to a common height so they can be placed side by side.
    height = min(image0.shape[0], image1.shape[0])
    image0_resized = cv2.resize(image0, (int(image0.shape[1] * height / image0.shape[0]), height))
    image1_resized = cv2.resize(image1, (int(image1.shape[1] * height / image1.shape[0]), height))

    # Matplotlib cannot plot GPU tensors, so move keypoints to CPU numpy arrays first.
    kpts0 = kpts0.detach().cpu().numpy() if torch.is_tensor(kpts0) else np.asarray(kpts0)
    kpts1 = kpts1.detach().cpu().numpy() if torch.is_tensor(kpts1) else np.asarray(kpts1)
    mkpts0 = mkpts0.detach().cpu().numpy() if torch.is_tensor(mkpts0) else np.asarray(mkpts0)
    mkpts1 = mkpts1.detach().cpu().numpy() if torch.is_tensor(mkpts1) else np.asarray(mkpts1)

    # Preview figure: each image with its detected keypoints (closed immediately; only the match figure is returned).
    preview_fig = plt.figure(figsize=(15, 15))
    plt.subplot(1, 2, 1)
    plt.imshow(image0_resized)
    plt.scatter(kpts0[:, 0], kpts0[:, 1], color='r', s=1)
    plt.axis('off')
    plt.subplot(1, 2, 2)
    plt.imshow(image1_resized)
    plt.scatter(kpts1[:, 0], kpts1[:, 1], color='r', s=1)
    plt.axis('off')
    plt.close(preview_fig)

    # Match figure: the two images stacked horizontally with lines between matched keypoints.
    fig, ax = plt.subplots(figsize=(20, 20))
    plt.plot([mkpts0[:, 0], mkpts1[:, 0] + image0_resized.shape[1]], [mkpts0[:, 1], mkpts1[:, 1]], 'r', lw=0.5)
    plt.scatter(mkpts0[:, 0], mkpts0[:, 1], s=2, marker='o', color='b')
    plt.scatter(mkpts1[:, 0] + image0_resized.shape[1], mkpts1[:, 1], s=2, marker='o', color='g')
    plt.imshow(np.hstack([image0_resized, image1_resized]), aspect='auto')
    plt.suptitle('\n'.join(text), fontsize=20, fontweight='bold')
    plt.tight_layout()

    # Render the figure to a PNG buffer and decode it back into an RGB array.
    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    plt.close(fig)

    return img

def stitch_images(images):
    """Stitches a list of images vertically."""
    if not images:
        return Image.new('RGB', (100, 100), color='gray')

    max_width = max([img.width for img in images])
    total_height = sum(img.height for img in images)

    composite = Image.new('RGB', (max_width, total_height))

    y_offset = 0
    for img in images:
        composite.paste(img, (0, y_offset))
        y_offset += img.height

    return composite

# Main functions
def detect_and_crop(target_image, query_image, threshold=0.5, nms_threshold=0.3):
    """Run OWL-ViT image-guided detection; return the cropped detections (BGR arrays) and the target image annotated with red boxes, or ([], None) if nothing is found."""
    target_sizes = torch.Tensor([target_image.size[::-1]])
    inputs = processor(images=target_image, query_images=query_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.image_guided_detection(**inputs)

    # PIL gives RGB; convert to BGR so the crops follow the OpenCV convention used elsewhere.
    img = cv2.cvtColor(np.array(target_image), cv2.COLOR_RGB2BGR)
    outputs.logits = outputs.logits.cpu()
    outputs.target_pred_boxes = outputs.target_pred_boxes.cpu()

    results = processor.post_process_image_guided_detection(outputs=outputs, threshold=threshold, nms_threshold=nms_threshold, target_sizes=target_sizes)
    boxes, scores = results[0]["boxes"], results[0]["scores"]

    if len(boxes) == 0:
        return [], None

    filtered_boxes = []
    for box in boxes:
        x1, y1, x2, y2 = [int(i) for i in box.tolist()]
        cropped_img = img[y1:y2, x1:x2]
        if cropped_img.size != 0:
            filtered_boxes.append(cropped_img)

    # Draw the detected boxes on the target image for display.
    draw = ImageDraw.Draw(target_image)
    for box in boxes:
        draw.rectangle(box.tolist(), outline="red", width=3)

    return filtered_boxes, target_image

def image_matching_no_pyramid(query_img, target_img, visualize=True):
    """Match SuperPoint keypoints between query and target with SuperGlue and count RANSAC homography inliers."""
    temp_query = save_array_to_temp_image(np.array(query_img))
    temp_target = save_array_to_temp_image(np.array(target_img))

    image1, inp1, scales1 = read_image(temp_target, device, [640*2], 0, True)
    image0, inp0, scales0 = read_image(temp_query, device, [640*2], 0, True)

    if image0 is None or image1 is None:
        return None

    with torch.no_grad():
        pred = matching({'image0': inp0, 'image1': inp1})
    pred = {k: v[0] for k, v in pred.items()}
    kpts0, kpts1 = pred['keypoints0'], pred['keypoints1']
    matches, conf = pred['matches0'], pred['matching_scores0']

    # Keep only keypoints that SuperGlue actually matched (match index > -1).
    valid = matches > -1
    mkpts0 = kpts0[valid]
    mkpts1 = kpts1[matches[valid]]
    mconf = conf[valid]
    color = cm.jet(mconf.detach().cpu().numpy())[:len(mkpts0)]

    valid_count = int(valid.sum().item())

    mkpts0_np = mkpts0.cpu().numpy()
    mkpts1_np = mkpts1.cpu().numpy()

    # Estimate a homography with RANSAC; fall back to zero inliers when there are too few matches.
    try:
        H, inliers = cv2.findHomography(mkpts0_np, mkpts1_np, cv2.RANSAC, 5.0)
        if inliers is None:
            inliers = 0
    except cv2.error:
        inliers = 0

    num_inliers = int(np.sum(inliers))

    if visualize:
        visualized_img = unified_matching_plot2(
            image0, image1, kpts0, kpts1, mkpts0, mkpts1, color, ['Matches'])
    else:
        visualized_img = None

    return {
        'valid': [valid_count],
        'inliers': [num_inliers],
        'visualized_image': [visualized_img]
    }

def check_object_in_image(query_image, target_image, threshold=50, scale_factor=[0.33, 0.66, 1]):
    """Detect candidate regions with OWL-ViT, then verify each crop against the query via SuperGlue inliers.

    Note: scale_factor is accepted from the UI but is not used by this single-scale implementation.
    """
    images_to_return = []
    cropped_images, bbox_image = detect_and_crop(target_image, query_image)

    # Match the query image against every cropped detection.
    temp_files = [save_array_to_temp_image(i) for i in cropped_images]
    crop_results = [image_matching_no_pyramid(query_image, Image.open(i), visualize=True) for i in temp_files]

    cropped_visuals = []
    cropped_inliers = []
    for result in crop_results:
        if result:
            for img in result['visualized_image']:
                cropped_visuals.append(Image.fromarray(img))
            for inliers_ in result['inliers']:
                cropped_inliers.append(inliers_)

    # Stitched match visualisations, kept for debugging (not part of the returned dict).
    images_to_return.append(stitch_images(cropped_visuals))

    # The object counts as present if any crop clears the inlier threshold.
    is_present = any(value >= threshold for value in cropped_inliers)

    return {
        'is_present': is_present,
        'image_with_boxes': bbox_image,
        'object_detection_inliers': [int(i) for i in cropped_inliers],
    }

def interface(poster_source, media_source, threshold, scale_factor):
    # One detection-and-verification pass over the target image is sufficient.
    result = check_object_in_image(poster_source, media_source, threshold, scale_factor)
    return result['is_present'], result['image_with_boxes']

iface = gr.Interface(
    fn=interface,
    inputs=[
        gr.Image(type="pil", label="Upload a Query Image (Poster)"),
        gr.Image(type="pil", label="Upload a Target Image (Media)"),
        gr.Slider(minimum=0, maximum=100, step=1, value=50, label="Threshold"),
        gr.CheckboxGroup(choices=["0.33", "0.66", "1.0"], value=["0.33", "0.66", "1.0"], label="Scale Factors"),
    ],
    outputs=[
        gr.Label(label="Object Presence"),
        gr.Image(type="pil", label="Detected Bounding Boxes"),
    ],
    title="Object Detection in Images",
    description="""
    This application allows you to check if an object in a query image (poster) is present in a target image (media).
    Steps:
    1. Upload a Query Image (Poster)
    2. Upload a Target Image (Media)
    3. Set Threshold
    4. Set Scale Factors
    5. View Results
    """
)

if __name__ == "__main__":
    iface.launch()
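
For reference, here is a minimal sketch of how the same pipeline could be exercised outside the Gradio UI. It assumes it runs in the same environment as app.py (models and matcher already loaded); the file names query.png and target.png are hypothetical placeholders.

# Hypothetical smoke test for the functions defined above; file names are placeholders.
from PIL import Image

query = Image.open("query.png").convert("RGB")    # the poster / object to look for
target = Image.open("target.png").convert("RGB")  # the media frame to search in

result = check_object_in_image(query, target, threshold=50)
print("Object present:", result["is_present"])
print("Inliers per detected box:", result["object_detection_inliers"])
if result["image_with_boxes"] is not None:
    result["image_with_boxes"].save("target_with_boxes.png")
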
gitattributes ADDED
@@ -0,0 +1,49 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
SuperGluePretrainedNetwork/assets/freiburg_matches.gif filter=lfs diff=lfs merge=lfs -text
Samples/Images/Frame[[:space:]]13.png filter=lfs diff=lfs merge=lfs -text
Samples/Images/Frame[[:space:]]14.png filter=lfs diff=lfs merge=lfs -text
Samples/Images/Frame[[:space:]]3.png filter=lfs diff=lfs merge=lfs -text
Samples/Images/Frame[[:space:]]5.png filter=lfs diff=lfs merge=lfs -text
Samples/Images/Frame[[:space:]]6.png filter=lfs diff=lfs merge=lfs -text
Samples/Images/sub-buzz-16648-1550412154-1.png filter=lfs diff=lfs merge=lfs -text
Samples/Images/Test3.png filter=lfs diff=lfs merge=lfs -text
Samples/Poster/ForScale.png filter=lfs diff=lfs merge=lfs -text
Samples/Poster/ForScaleG.png filter=lfs diff=lfs merge=lfs -text
Samples/Poster/ForScaleM.png filter=lfs diff=lfs merge=lfs -text
Samples/Poster/ForScaleMM.png filter=lfs diff=lfs merge=lfs -text
Samples/Poster/poster_event_small.png filter=lfs diff=lfs merge=lfs -text
Samples/Poster/poster_event.png filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,8 @@
gradio
transformers
torch
torchvision
numpy
Pillow
opencv-python
matplotlib