Update gradio_demo.py

gradio_demo.py CHANGED (+69 -85)

@@ -854,101 +854,85 @@ def process_image(input_image, input_text):
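The hunk replaces the old unconditional, text-prompted DINO-X call with a branch that runs when the user supplies no text: it queries DINO-X in prompt-free mode, decodes each returned RLE mask into a numpy array, draws boxes, labels, and masks with supervision, and hands a transparent RGBA crop of the first detection to the mask mover. The old post-processing block (old lines 872-922) was not captured on this page and is elided below as `-    ...`.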
     image_url = client.upload_file(tmpfile.name)
     os.remove(tmpfile.name)
 
-    # Run DINO-X detection
-    task = DinoxTask(
-        image_url=image_url,
-        prompts=[TextPrompt(text=input_text)]
-    )
-    client.run_task(task)
-    result = task.result
-    objects = result.objects
-
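The added branch later in this hunk drops the intermediate `result` and `objects` variables and reads `task.result.objects` directly.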
     # Process detection results
     input_boxes = []
+    masks = []
     confidences = []
     class_names = []
     class_ids = []
 
-    ...
 
-    # Create transparent mask for first detected object
-    if len(detections) > 0:
-        # Get first mask
-        first_mask = detections.mask[0]
-
-        # Get original RGB image
-        img = input_image.copy()
-        H, W, C = img.shape
-
-        # Create RGBA image
-        alpha = np.zeros((H, W, 1), dtype=np.uint8)
-        alpha[first_mask] = 255
-        rgba = np.dstack((img, alpha)).astype(np.uint8)
-
-        # Crop to mask bounds to minimize image size
-        y_indices, x_indices = np.where(first_mask)
-        y_min, y_max = y_indices.min(), y_indices.max()
-        x_min, x_max = x_indices.min(), x_indices.max()
-
-        # Crop the RGBA image
-        cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
-
-        # Set extracted foreground for mask mover
-        mask_mover.set_extracted_fg(cropped_rgba)
-
-        return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
-
-    return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
+    if len(input_text) == 0:
+        task = DinoxTask(
+            image_url=image_url,
+            prompts=[TextPrompt(text="<prompt_free>")],
+            # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
+        )
+
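`<prompt_free>` is the DDS API's special prompt for prompt-free detection, where DINO-X proposes categories on its own instead of matching user text. For comparison, the text-prompted form of the same task, as in the removed lines above (presumably what the non-empty-text path still does, though that part of the hunk was not captured):

    # Text-prompted variant, mirroring the removed code (old lines 858-862).
    task = DinoxTask(
        image_url=image_url,
        prompts=[TextPrompt(text=input_text)],
    )
    client.run_task(task)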
+        client.run_task(task)
+        predictions = task.result.objects
+        classes = [pred.category for pred in predictions]
+        classes = list(set(classes))
+        class_name_to_id = {name: id for id, name in enumerate(classes)}
+        class_id_to_name = {id: name for name, id in class_name_to_id.items()}
+
+        for idx, obj in enumerate(predictions):
+            input_boxes.append(obj.bbox)
+            masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size))  # convert mask to np.array using DDS API
+            confidences.append(obj.score)
+            cls_name = obj.category.lower().strip()
+            class_names.append(cls_name)
+            class_ids.append(class_name_to_id[cls_name])
+
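`DetectionTask.string2rle` and `DetectionTask.rle2mask` come from the DDS SDK and turn the API's compressed RLE string into run-length counts and then into a binary mask. As a minimal sketch of that second step, assuming the counts are COCO-style (column-major, alternating background/foreground runs) and already decoded to integers; `rle_to_mask` is an illustrative name, not part of the SDK:

    import numpy as np

    def rle_to_mask(counts, size):
        # Decode uncompressed COCO-style RLE into a boolean (H, W) mask.
        h, w = size
        flat = np.zeros(h * w, dtype=bool)
        pos = 0
        for i, run in enumerate(counts):
            if i % 2 == 1:          # odd-indexed runs are foreground
                flat[pos:pos + run] = True
            pos += run
        return flat.reshape((h, w), order="F")  # counts run down columns

pycocotools' `mask.decode` handles the compressed form directly if the SDK helpers are unavailable.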
+        boxes = np.array(input_boxes)
+        masks = np.array(masks)
+        class_ids = np.array(class_ids)
+        labels = [
+            f"{class_name} {confidence:.2f}"
+            for class_name, confidence
+            in zip(class_names, confidences)
+        ]
+        detections = sv.Detections(
+            xyxy=boxes,
+            mask=masks.astype(bool),
+            class_id=class_ids
+        )
 
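`sv.Detections` expects `xyxy` as an `(N, 4)` array of corner coordinates (so `obj.bbox` is assumed to already be in xyxy order), `mask` as a boolean `(N, H, W)` stack (hence `astype(bool)`), and `class_id` as an `(N,)` integer array. Purely illustrative shape checks:

    assert boxes.shape == (len(predictions), 4)    # one xyxy row per object
    assert masks.shape[0] == len(predictions)      # (N, H, W) mask stack
    assert class_ids.shape == (len(predictions),)  # one class id per object

Note that `class_id_to_name` is built a few lines up but never used in this hunk.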
+        box_annotator = sv.BoxAnnotator()
+        label_annotator = sv.LabelAnnotator()
+        mask_annotator = sv.MaskAnnotator()
 
+        annotated_frame = input_image.copy()
+        annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
+        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
+        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
 
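The annotators draw in sequence on the same copy of the input: boxes first, then the confidence labels, then the semi-transparent mask overlay.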
+        # Create transparent mask for first detected object
+        if len(detections) > 0:
+            # Get first mask
+            first_mask = detections.mask[0]
+
+            # Get original RGB image
+            img = input_image.copy()
+            H, W, C = img.shape
+
+            # Create RGBA image
+            alpha = np.zeros((H, W, 1), dtype=np.uint8)
+            alpha[first_mask] = 255
+            rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+            # Crop to mask bounds to minimize image size
+            y_indices, x_indices = np.where(first_mask)
+            y_min, y_max = y_indices.min(), y_indices.max()
+            x_min, x_max = x_indices.min(), x_indices.max()
+
+            # Crop the RGBA image
+            cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
+
+            # Set extracted foreground for mask mover
+            mask_mover.set_extracted_fg(cropped_rgba)
+
+            return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
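One behavioral change worth flagging: the old code ended with a fallback `return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)`, which this version drops, so if the prompt-free branch detects nothing the function now falls through and implicitly returns a single `None` where the Gradio outputs expect four values. Restoring a fallback return after the `if len(detections) > 0:` block would likely be the minimal fix.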
 
 
 
 block = gr.Blocks().queue()
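The cutout logic in the added block (alpha channel from the boolean mask, `np.dstack` to RGBA, crop to the mask's bounding box) is self-contained enough to factor out. A sketch with an added guard for an all-empty mask, where `.min()` on the empty index arrays would otherwise raise; `extract_cutout` is an illustrative name, not from the demo:

    import numpy as np

    def extract_cutout(img, mask):
        # img: (H, W, 3) uint8 RGB; mask: (H, W) bool.
        # Returns an RGBA crop: opaque inside the mask, transparent elsewhere.
        alpha = np.where(mask, 255, 0).astype(np.uint8)[..., None]
        rgba = np.dstack((img, alpha))
        ys, xs = np.where(mask)
        if ys.size == 0:  # empty mask: nothing to crop to
            return rgba
        return rgba[ys.min():ys.max() + 1, xs.min():xs.max() + 1]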