Ashoka74 committed (verified)
Commit 9615931 · 1 Parent(s): efd2136

Update gradio_demo.py
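
This update removes the local SAM2 segmentation stage from `process_image`. When no prompt text is supplied, detection now runs DINO-X in prompt-free mode (the `<prompt_free>` token), the class list is built from whatever categories the detector returns, and segmentation masks are decoded from the RLE-encoded masks in the API response rather than predicted by a locally loaded SAM2 checkpoint.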

Files changed (1):
  1. gradio_demo.py (+69, -85)
gradio_demo.py CHANGED
@@ -854,101 +854,85 @@ def process_image(input_image, input_text):
     image_url = client.upload_file(tmpfile.name)
     os.remove(tmpfile.name)

-    # Run DINO-X detection
-    task = DinoxTask(
-        image_url=image_url,
-        prompts=[TextPrompt(text=input_text)]
-    )
-    client.run_task(task)
-    result = task.result
-    objects = result.objects
-
     # Process detection results
     input_boxes = []
     confidences = []
     class_names = []
     class_ids = []

-    for obj in objects:
-        input_boxes.append(obj.bbox)
-        confidences.append(obj.score)
-        cls_name = obj.category.lower().strip()
-        class_names.append(cls_name)
-        class_ids.append(class_name_to_id[cls_name])
-
-    input_boxes = np.array(input_boxes)
-    class_ids = np.array(class_ids)
-
-    # Initialize SAM2
-    torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
-    if torch.cuda.get_device_properties(0).major >= 8:
-        torch.backends.cuda.matmul.allow_tf32 = True
-        torch.backends.cudnn.allow_tf32 = True
-
-    sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
-    sam2_predictor = SAM2ImagePredictor(sam2_model)
-    sam2_predictor.set_image(input_image)
-
-    # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
-
-    # Get masks from SAM2
-    masks, scores, logits = sam2_predictor.predict(
-        point_coords=None,
-        point_labels=None,
-        box=input_boxes,
-        multimask_output=False,
-    )
-    if masks.ndim == 4:
-        masks = masks.squeeze(1)

-    # Create visualization
-    labels = [f"{class_name} {confidence:.2f}"
-              for class_name, confidence in zip(class_names, confidences)]

-    detections = sv.Detections(
-        xyxy=input_boxes,
-        mask=masks.astype(bool),
-        class_id=class_ids
-    )

-    box_annotator = sv.BoxAnnotator()
-    label_annotator = sv.LabelAnnotator()
-    mask_annotator = sv.MaskAnnotator()
-
-    annotated_frame = input_image.copy()
-    annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
-    annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
-    annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)

-    # Create transparent mask for first detected object
-    if len(detections) > 0:
-        # Get first mask
-        first_mask = detections.mask[0]
-
-        # Get original RGB image
-        img = input_image.copy()
-        H, W, C = img.shape
-
-        # Create RGBA image
-        alpha = np.zeros((H, W, 1), dtype=np.uint8)
-        alpha[first_mask] = 255
-        rgba = np.dstack((img, alpha)).astype(np.uint8)
-
-        # Crop to mask bounds to minimize image size
-        y_indices, x_indices = np.where(first_mask)
-        y_min, y_max = y_indices.min(), y_indices.max()
-        x_min, x_max = x_indices.min(), x_indices.max()
-
-        # Crop the RGBA image
-        cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
-
-        # Set extracted foreground for mask mover
-        mask_mover.set_extracted_fg(cropped_rgba)
-
-        return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
-
-    return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)


 block = gr.Blocks().queue()
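
Two details of the removed branch above are worth noting: it entered mixed precision with a bare `torch.autocast(...).__enter__()` that was never exited, and it rebuilt the SAM2 model on every call. A minimal sketch of the same box-prompted SAM2 step, with the model built once and the autocast scope handled by a `with` block (the config and checkpoint paths are illustrative placeholders, not necessarily the demo's values):

```python
import numpy as np
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Placeholder config/checkpoint paths; the demo supplies its own constants.
SAM2_MODEL_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"
SAM2_CHECKPOINT = "./checkpoints/sam2.1_hiera_large.pt"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model once at startup instead of on every request.
sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
sam2_predictor = SAM2ImagePredictor(sam2_model)

def masks_from_boxes(image: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """Predict one segmentation mask per (x1, y1, x2, y2) box."""
    # Scope mixed precision so the autocast context is exited cleanly.
    with torch.autocast(device_type=DEVICE, dtype=torch.bfloat16):
        sam2_predictor.set_image(image)
        masks, scores, logits = sam2_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=boxes,
            multimask_output=False,
        )
    if masks.ndim == 4:  # (N, 1, H, W) -> (N, H, W)
        masks = masks.squeeze(1)
    return masks
```

The new version below drops this stage entirely and takes segmentation masks straight from the detection API.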
 
     image_url = client.upload_file(tmpfile.name)
     os.remove(tmpfile.name)

     # Process detection results
     input_boxes = []
+    masks = []
     confidences = []
     class_names = []
     class_ids = []

+    if len(input_text) == 0:
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text="<prompt_free>")],
+            # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
+        )
+
+        client.run_task(task)
+        predictions = task.result.objects
+        classes = [pred.category for pred in predictions]
+        classes = list(set(classes))
+        class_name_to_id = {name: id for id, name in enumerate(classes)}
+        class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+        for idx, obj in enumerate(predictions):
+            input_boxes.append(obj.bbox)
+            # Convert the RLE-encoded mask to an np.array using the DDS API
+            masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size))
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_names.append(cls_name)
+            class_ids.append(class_name_to_id[cls_name])
+
+        boxes = np.array(input_boxes)
+        masks = np.array(masks)
+        class_ids = np.array(class_ids)
+        labels = [
+            f"{class_name} {confidence:.2f}"
+            for class_name, confidence
+            in zip(class_names, confidences)
+        ]
+        detections = sv.Detections(
+            xyxy=boxes,
+            mask=masks.astype(bool),
+            class_id=class_ids
+        )

+        box_annotator = sv.BoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        mask_annotator = sv.MaskAnnotator()

+        annotated_frame = input_image.copy()
+        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)

+        # Create transparent mask for first detected object
+        if len(detections) > 0:
+            # Get first mask
+            first_mask = detections.mask[0]
+
+            # Get original RGB image
+            img = input_image.copy()
+            H, W, C = img.shape
+
+            # Create RGBA image
+            alpha = np.zeros((H, W, 1), dtype=np.uint8)
+            alpha[first_mask] = 255
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+            # Crop to mask bounds to minimize image size
+            y_indices, x_indices = np.where(first_mask)
+            y_min, y_max = y_indices.min(), y_indices.max()
+            x_min, x_max = x_indices.min(), x_indices.max()
+
+            # Crop the RGBA image
+            cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
+
+            # Set extracted foreground for mask mover
+            mask_mover.set_extracted_fg(cropped_rgba)
+
+            return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)


 block = gr.Blocks().queue()
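
In the new branch, `obj.mask` comes back run-length encoded, and the code converts it to a binary array with the DDS SDK helpers `DetectionTask.string2rle` and `DetectionTask.rle2mask`. As a rough sketch of what such a decoder does, assuming an uncompressed COCO-style RLE (column-major pixel order, alternating background/foreground run lengths; not the SDK's exact implementation):

```python
import numpy as np

def rle_to_mask(counts: list[int], size: tuple[int, int]) -> np.ndarray:
    """Decode uncompressed COCO-style RLE into a boolean (H, W) mask.

    counts alternates run lengths of background and foreground pixels,
    starting with background; size is (height, width).
    """
    height, width = size
    flat = np.zeros(height * width, dtype=bool)
    position, value = 0, False
    for run_length in counts:
        flat[position:position + run_length] = value
        position += run_length
        value = not value
    # COCO RLE counts pixels in column-major order: reshape then transpose.
    return flat.reshape(width, height).T
```

Note also that `class_name_to_id` is now rebuilt on each call from the categories the detector actually returned, so the prompt-free path does not depend on a predefined label set. If the extracted cutout needs to be saved rather than passed along, the (h, w, 4) uint8 array maps directly onto a PIL RGBA image:

```python
from PIL import Image

# cropped_rgba is the (h, w, 4) uint8 array produced in process_image.
Image.fromarray(cropped_rgba, mode="RGBA").save("foreground.png")
```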