dragonSwing committed
Commit 699405e · 1 Parent(s): eeef127

Revert app.py

Files changed (1):
  app.py (+180, -279)
app.py CHANGED
@@ -1,47 +1,96 @@
- import argparse
  import json
  import os
  import sys
  import tempfile
 
  import numpy as np
  import supervision as sv
- from groundingdino.util.inference import Model as DinoModel
- from imutils import paths
  from PIL import Image
- from segment_anything import sam_model_registry
  from segment_anything import SamAutomaticMaskGenerator
  from segment_anything import SamPredictor
  from supervision.detection.utils import xywh_to_xyxy
- from tqdm import tqdm
 
  sys.path.append("tag2text")
 
  from tag2text.models import tag2text
  from config import *
- from utils import detect, download_file_hf, segment, generate_tags, show_anns_sv
-
-
- def process(
-     tag2text_model,
-     grounding_dino_model,
-     sam_predictor,
-     sam_automask_generator,
-     image_path,
-     task,
-     prompt,
-     box_threshold,
-     text_threshold,
-     iou_threshold,
-     device,
-     output_dir=None,
-     save_mask=False,
- ):
-     detections = None
-     metadata = {"image": {}, "annotations": [], "assets": {}}
-
-     if save_mask:
-         metadata["assets"]["intermediate_mask"] = []
 
      try:
          # Load image
@@ -51,18 +100,17 @@ def process(
 
          # Extract image metadata
          filename = os.path.basename(image_path)
-         basename = os.path.splitext(filename)[0]
          h, w = image.shape[:2]
          metadata["image"]["file_name"] = filename
          metadata["image"]["width"] = w
          metadata["image"]["height"] = h
 
          # Generate tags
-         if task in ["auto", "detection"] and prompt == "":
              tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
              prompt = " . ".join(tags)
-             # print(f"Caption: {caption}")
-             # print(f"Tags: {tags}")
 
          # ToDo: Extract metadata
          metadata["image"]["caption"] = caption
@@ -70,6 +118,7 @@ def process(
 
          if prompt:
              metadata["prompt"] = prompt
 
          # Detect boxes
          if prompt != "":
@@ -82,21 +131,18 @@ def process(
                  iou_threshold=iou_threshold,
                  post_process=True,
              )
-
-             # Save detection image
-             if output_dir:
-                 # Draw boxes
-                 box_annotator = sv.BoxAnnotator()
-                 labels = [
-                     f"{phrases[i]} {detections.confidence[i]:0.2f}"
-                     for i in range(len(phrases))
-                 ]
-                 box_image = box_annotator.annotate(
-                     scene=image, detections=detections, labels=labels
-                 )
-                 box_image_path = os.path.join(output_dir, basename + "_detect.png")
-                 metadata["assets"]["detection"] = box_image_path
-                 Image.fromarray(box_image).save(box_image_path)
 
          # Segmentation
          if task in ["auto", "segment"]:
@@ -121,27 +167,18 @@ def process(
                  detections = sv.Detections(
                      xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
                  )
-
-             # Save annotated image
-             if output_dir:
-                 mask_annotator = sv.MaskAnnotator()
-                 mask_image, res = show_anns_sv(detections)
-                 annotated_image = mask_annotator.annotate(image, detections=detections)
-
-                 mask_image_path = os.path.join(output_dir, basename + "_mask.png")
-                 metadata["assets"]["mask"] = mask_image_path
-                 Image.fromarray(mask_image).save(mask_image_path)
-
-                 # Save annotation encoding from https://github.com/LUSSeg/ImageNet-S
-                 mask_enc_path = os.path.join(output_dir, basename + "_mask_enc.npy")
-                 np.save(mask_enc_path, res)
-                 metadata["assets"]["mask_enc"] = mask_enc_path
-
-                 annotated_image_path = os.path.join(
-                     output_dir, basename + "_annotate.png"
-                 )
-                 metadata["assets"]["annotate"] = annotated_image_path
-                 Image.fromarray(annotated_image).save(annotated_image_path)
 
          # ToDo: Extract metadata
          if detections:
@@ -164,222 +201,86 @@ def process(
              metadata["annotations"].append(annotation)
              i += 1
 
-             if output_dir and save_mask:
-                 mask_image_path = os.path.join(
-                     output_dir, f"{basename}_mask_{id}.png"
-                 )
-                 metadata["assets"]["intermediate_mask"].append(mask_image_path)
-                 Image.fromarray(mask * 255).save(mask_image_path)
-
-         if output_dir:
-             meta_file_path = os.path.join(output_dir, basename + "_meta.json")
-             with open(meta_file_path, "w") as fp:
-                 json.dump(metadata, fp)
-         else:
-             meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
-             meta_file_path = meta_file.name
-
-         return meta_file_path
-     except Exception as error:
-         raise ValueError(f"global exception: {error}")
-
-
- def main(args: argparse.Namespace) -> None:
-     device = args.device
-     prompt = args.prompt
-     task = args.task
-
-     tag2text_model = None
-     grounding_dino_model = None
-     sam_predictor = None
-     sam_automask_generator = None
-
-     box_threshold = args.box_threshold
-     text_threshold = args.text_threshold
-     iou_threshold = args.iou_threshold
-     save_mask = args.save_mask
-
-     # load model
-     if task in ["auto", "detection"] and prompt == "":
-         print("Loading Tag2Text model...")
-         tag2text_type = args.tag2text_type
-         tag2text_checkpoint = os.path.join(
-             abs_weight_dir, tag2text_dict[tag2text_type]["checkpoint_file"]
-         )
-         if not os.path.exists(tag2text_checkpoint):
-             print(f"Downloading weights for Tag2Text {tag2text_type} model")
-             os.system(
-                 f"wget {tag2text_dict[tag2text_type]['checkpoint_url']} -O {tag2text_checkpoint}"
-             )
-         tag2text_model = tag2text.tag2text_caption(
-             pretrained=tag2text_checkpoint,
-             image_size=384,
-             vit="swin_b",
-             delete_tag_index=delete_tag_index,
-         )
-         # threshold for tagging
-         # we reduce the threshold to obtain more tags
-         tag2text_model.threshold = 0.64
-         tag2text_model.to(device)
-         tag2text_model.eval()
-
-     if task in ["auto", "detection"] or prompt != "":
-         print("Loading Grounding Dino model...")
-         dino_type = args.dino_type
-         dino_checkpoint = os.path.join(
-             abs_weight_dir, dino_dict[dino_type]["checkpoint_file"]
-         )
-         dino_config_file = os.path.join(
-             abs_weight_dir, dino_dict[dino_type]["config_file"]
-         )
-         if not os.path.exists(dino_checkpoint):
-             print(f"Downloading weights for Grounding Dino {dino_type} model")
-             dino_repo_id = dino_dict[dino_type]["repo_id"]
-             download_file_hf(
-                 repo_id=dino_repo_id,
-                 filename=dino_dict[dino_type]["config_file"],
-                 cache_dir=weight_dir,
-             )
-             download_file_hf(
-                 repo_id=dino_repo_id,
-                 filename=dino_dict[dino_type]["checkpoint_file"],
-                 cache_dir=weight_dir,
-             )
-         grounding_dino_model = DinoModel(
-             model_config_path=dino_config_file,
-             model_checkpoint_path=dino_checkpoint,
-             device=device,
-         )
 
-     if task in ["auto", "segment"]:
-         print("Loading SAM...")
-         sam_type = args.sam_type
-         sam_checkpoint = os.path.join(
-             abs_weight_dir, sam_dict[sam_type]["checkpoint_file"]
-         )
-         if not os.path.exists(sam_checkpoint):
-             print(f"Downloading weights for SAM {sam_type}")
-             os.system(
-                 f"wget {sam_dict[sam_type]['checkpoint_url']} -O {sam_checkpoint}"
-             )
-         sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
-         sam.to(device=device)
-         sam_predictor = SamPredictor(sam)
-         sam_automask_generator = SamAutomaticMaskGenerator(sam)
-
-     if not os.path.exists(args.input):
-         raise ValueError("The input directory doesn't exist!")
-     elif not os.path.isdir(args.input):
-         image_paths = [args.input]
-     else:
-         image_paths = paths.list_images(args.input)
-
-     os.makedirs(args.output, exist_ok=True)
-
-     with tqdm(image_paths) as pbar:
-         for image_path in pbar:
-             pbar.set_postfix_str(f"Processing {image_path}")
-             process(
-                 tag2text_model=tag2text_model,
-                 grounding_dino_model=grounding_dino_model,
-                 sam_predictor=sam_predictor,
-                 sam_automask_generator=sam_automask_generator,
-                 image_path=image_path,
-                 task=task,
-                 prompt=prompt,
-                 box_threshold=box_threshold,
-                 text_threshold=text_threshold,
-                 iou_threshold=iou_threshold,
-                 device=device,
-                 output_dir=args.output,
-                 save_mask=save_mask,
-             )
 
 
- if __name__ == "__main__":
-     if not os.path.exists(abs_weight_dir):
-         os.makedirs(abs_weight_dir, exist_ok=True)
 
-     parser = argparse.ArgumentParser(
-         description=(
-             "Runs automatic detection and mask generation on an input image or directory of images"
-         )
-     )
 
-     parser.add_argument(
-         "--input",
-         "-i",
-         type=str,
-         required=True,
-         help="Path to either a single input image or folder of images.",
-     )
-
-     parser.add_argument(
-         "--output",
-         "-o",
-         type=str,
-         required=True,
-         help=(
-             "Path to the directory where masks will be output."
-         ),
-     )
-
-     parser.add_argument(
-         "--sam-type",
-         type=str,
-         default=default_sam,
-         choices=sam_dict.keys(),
-         help="The type of SAM model to use for segmentation.",
-     )
-
-     parser.add_argument(
-         "--tag2text-type",
-         type=str,
-         default=default_tag2text,
-         choices=tag2text_dict.keys(),
-         help="The type of Tag2Text model to use for tag and caption generation.",
-     )
-
-     parser.add_argument(
-         "--dino-type",
-         type=str,
-         default=default_dino,
-         choices=dino_dict.keys(),
-         help="The type of Grounding Dino model to use for promptable object detection.",
-     )
-
-     parser.add_argument(
-         "--task",
-         help="Task to run",
-         default="auto",
-         choices=["auto", "detect", "segment"],
-         type=str,
-     )
-     parser.add_argument(
-         "--prompt",
-         help="Detection prompt",
-         default="",
-         type=str,
-     )
-
-     parser.add_argument(
-         "--box-threshold", type=float, default=0.25, help="box threshold"
-     )
-     parser.add_argument(
-         "--text-threshold", type=float, default=0.2, help="text threshold"
-     )
-     parser.add_argument(
-         "--iou-threshold", type=float, default=0.5, help="iou threshold"
-     )
-
-     parser.add_argument(
-         "--save-mask",
-         action="store_true",
-         default=False,
-         help="If True, save all intermediate masks.",
-     )
-     parser.add_argument(
-         "--device", type=str, default="cuda", help="The device to run generation on."
-     )
-     args = parser.parse_args()
-     main(args)
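
For reference, a minimal sketch of driving the removed CLI version above directly from Python; the Namespace fields mirror the argparse flags and defaults defined above, while the input and output paths are only illustrative:

    from argparse import Namespace

    # Hypothetical invocation of the removed entry point; each field mirrors
    # a parser flag above, using its documented default where one exists.
    args = Namespace(
        input="examples/dog.png",   # a single image or a folder of images
        output="outputs",           # directory where masks and metadata are written
        task="auto",
        prompt="",
        sam_type=default_sam,
        tag2text_type=default_tag2text,
        dino_type=default_dino,
        box_threshold=0.25,
        text_threshold=0.2,
        iou_threshold=0.5,
        save_mask=False,
        device="cuda",
    )
    main(args)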
 
 
  import json
  import os
+ import subprocess
  import sys
  import tempfile
 
+ import gradio as gr
  import numpy as np
  import supervision as sv
+ import torch
  from PIL import Image
+ from segment_anything import build_sam
  from segment_anything import SamAutomaticMaskGenerator
  from segment_anything import SamPredictor
+ from supervision.detection.utils import mask_to_polygons
  from supervision.detection.utils import xywh_to_xyxy
+
+ if os.environ.get('IS_MY_DEBUG') is None:
+     result = subprocess.run(['pip', 'install', '-e', 'GroundingDINO'], check=True)
+     print(f'pip install GroundingDINO = {result}')
 
  sys.path.append("tag2text")
+ sys.path.append("GroundingDINO")
 
+ from groundingdino.util.inference import Model as DinoModel
  from tag2text.models import tag2text
  from config import *
+ from utils import download_file_hf, detect, segment, show_anns, generate_tags
+
+ if not os.path.exists(abs_weight_dir):
+     os.makedirs(abs_weight_dir, exist_ok=True)
+
+ sam_checkpoint = os.path.join(abs_weight_dir, sam_dict[default_sam]["checkpoint_file"])
+ if not os.path.exists(sam_checkpoint):
+     os.system(f"wget {sam_dict[default_sam]['checkpoint_url']} -O {sam_checkpoint}")
+
+ tag2text_checkpoint = os.path.join(
+     abs_weight_dir, tag2text_dict[default_tag2text]["checkpoint_file"]
+ )
+ if not os.path.exists(tag2text_checkpoint):
+     os.system(
+         f"wget {tag2text_dict[default_tag2text]['checkpoint_url']} -O {tag2text_checkpoint}"
+     )
 
+ dino_checkpoint = os.path.join(
+     abs_weight_dir, dino_dict[default_dino]["checkpoint_file"]
+ )
+ dino_config_file = os.path.join(abs_weight_dir, dino_dict[default_dino]["config_file"])
+ if not os.path.exists(dino_checkpoint):
+     dino_repo_id = dino_dict[default_dino]["repo_id"]
+     download_file_hf(
+         repo_id=dino_repo_id,
+         filename=dino_dict[default_dino]["config_file"],
+         cache_dir=weight_dir,
+     )
+     download_file_hf(
+         repo_id=dino_repo_id,
+         filename=dino_dict[default_dino]["checkpoint_file"],
+         cache_dir=weight_dir,
+     )
+
+ # load model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ tag2text_model = tag2text.tag2text_caption(
+     pretrained=tag2text_checkpoint,
+     image_size=384,
+     vit="swin_b",
+     delete_tag_index=delete_tag_index,
+ )
+ # threshold for tagging
+ # we reduce the threshold to obtain more tags
+ tag2text_model.threshold = 0.64
+ tag2text_model.to(device)
+ tag2text_model.eval()
+
+
+ sam = build_sam(checkpoint=sam_checkpoint)
+ sam.to(device=device)
+ sam_predictor = SamPredictor(sam)
+ sam_automask_generator = SamAutomaticMaskGenerator(sam)
+
+ grounding_dino_model = DinoModel(
+     model_config_path=dino_config_file,
+     model_checkpoint_path=dino_checkpoint,
+     device=device,
+ )
+
+
+ def process(image_path, task, prompt, box_threshold, text_threshold, iou_threshold):
+     global tag2text_model, sam_predictor, sam_automask_generator, grounding_dino_model, device
+     output_gallery = []
+     detections = None
+     metadata = {"image": {}, "annotations": []}
 
      try:
          # Load image
 
 
          # Extract image metadata
          filename = os.path.basename(image_path)
          h, w = image.shape[:2]
          metadata["image"]["file_name"] = filename
          metadata["image"]["width"] = w
          metadata["image"]["height"] = h
 
          # Generate tags
+         if task in ["auto", "detect"] and prompt == "":
              tags, caption = generate_tags(tag2text_model, image_pil, "None", device)
              prompt = " . ".join(tags)
+             print(f"Caption: {caption}")
+             print(f"Tags: {tags}")
 
          # ToDo: Extract metadata
          metadata["image"]["caption"] = caption
 
          if prompt:
              metadata["prompt"] = prompt
+             print(f"Prompt: {prompt}")
 
          # Detect boxes
          if prompt != "":
                  iou_threshold=iou_threshold,
                  post_process=True,
              )
+             print(phrases)
+
+             # Draw boxes
+             box_annotator = sv.BoxAnnotator()
+             labels = [
+                 f"{phrases[i]} {detections.confidence[i]:0.2f}"
+                 for i in range(len(phrases))
+             ]
+             image = box_annotator.annotate(
+                 scene=image, detections=detections, labels=labels
+             )
+             output_gallery.append(image)
 
          # Segmentation
          if task in ["auto", "segment"]:
                  detections = sv.Detections(
                      xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask
                  )
+             # opacity = 0.4
+             # mask_image, _ = show_anns_sam(masks)
+             # annotated_image = np.uint8(mask_image * opacity + image * (1 - opacity))
+
+             mask_annotator = sv.MaskAnnotator()
+             mask_image = np.zeros_like(image, dtype=np.uint8)
+             mask_image = mask_annotator.annotate(
+                 mask_image, detections=detections, opacity=1
+             )
+             annotated_image = mask_annotator.annotate(image, detections=detections)
+             output_gallery.append(mask_image)
+             output_gallery.append(annotated_image)
 
          # ToDo: Extract metadata
          if detections:
              metadata["annotations"].append(annotation)
              i += 1
 
+         meta_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json")
+         meta_file_path = meta_file.name
+         with open(meta_file_path, "w") as fp:
+             json.dump(metadata, fp)
 
+         return output_gallery, meta_file_path
+     except Exception as error:
+         raise gr.Error(f"global exception: {error}")
 
 
+ title = "Annotate Anything"
 
+ with gr.Blocks(css="style.css", title=title) as demo:
+     with gr.Row(elem_classes=["container"]):
+         with gr.Column(scale=1):
+             input_image = gr.Image(type="filepath", label="Input")
+             task = gr.Dropdown(
+                 ["detect", "segment", "auto"], value="auto", label="task_type"
+             )
+             text_prompt = gr.Textbox(
+                 label="Detection Prompt",
+                 info="To detect multiple objects, separate each name with '.', like this: cat . dog . chair",
+             )
+             with gr.Accordion("Advanced parameters", open=False):
+                 box_threshold = gr.Slider(
+                     minimum=0,
+                     maximum=1,
+                     value=0.3,
+                     step=0.05,
+                     label="Box threshold",
+                     info="Confidence threshold for the detected boxes",
+                 )
+                 text_threshold = gr.Slider(
+                     minimum=0,
+                     maximum=1,
+                     value=0.25,
+                     step=0.05,
+                     label="Text threshold",
+                     info="Confidence threshold for matching boxes to prompt phrases",
+                 )
+                 iou_threshold = gr.Slider(
+                     minimum=0,
+                     maximum=1,
+                     value=0.5,
+                     step=0.05,
+                     label="IOU threshold",
+                     info="IoU threshold for non-maximum suppression of overlapping boxes",
+                 )
+             run_button = gr.Button(label="Run")
+
+         with gr.Column(scale=2):
+             gallery = gr.Gallery(
+                 label="Generated images", show_label=False, elem_id="gallery"
+             ).style(preview=True, grid=2, object_fit="scale-down")
+             meta_file = gr.File(label="Metadata file")
+
+     with gr.Row(elem_classes=["container"]):
+         gr.Examples(
+             [
+                 ["examples/dog.png", "auto", ""],
+                 ["examples/eiffel.png", "auto", ""],
+                 ["examples/eiffel.png", "segment", ""],
+                 ["examples/girl.png", "auto", "girl . face"],
+                 ["examples/horse.png", "detect", "horse"],
+                 ["examples/horses.jpg", "auto", "horse"],
+                 ["examples/traffic.jpg", "auto", ""],
+             ],
+             [input_image, task, text_prompt],
          )
+         run_button.click(
+             fn=process,
+             inputs=[
+                 input_image,
+                 task,
+                 text_prompt,
+                 box_threshold,
+                 text_threshold,
+                 iou_threshold,
+             ],
+             outputs=[gallery, meta_file],
          )
 
+ demo.queue(concurrency_count=2).launch()
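
For quick verification, a minimal sketch of calling the restored process() outside the Gradio UI; the image path is taken from the demo's own examples list and the thresholds match the slider defaults, so only the direct-call pattern itself is assumed:

    # Hypothetical smoke test; process() returns the preview gallery and the
    # path of the temporary JSON metadata file, as defined above.
    gallery, meta_path = process(
        image_path="examples/dog.png",
        task="auto",
        prompt="",
        box_threshold=0.3,
        text_threshold=0.25,
        iou_threshold=0.5,
    )
    print(f"{len(gallery)} preview images; metadata at {meta_path}")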