Avoid concatenating query words?

#6
opened by geooff

Hello fellow HuggingFacers

When calling the model, I notice that query words can be joined together, leading to nonsense outputs.

With an input like:
cat. dog. fish. red..

Along with the desired detections that directly match my queries, I also get undesired ones like:
cat fish. red dog.

This automatic crossing of queries makes the model unusable for higher-cardinality object detection. The issue is also discussed here: https://github.com/IDEA-Research/GroundingDINO/issues/85

I'm calling the model as follows:

from typing import Optional

import torch
from PIL import Image
from transformers import AutoProcessor, GroundingDinoForObjectDetection, logging

# DinoPrompts (prompt catalogue) and DinoOutput (result container) are
# project-specific helpers and are not shown here.


class DinoDetector:
    """Handles object detection using the Grounding DINO model."""

    def __init__(
        self,
        model_id: str = "IDEA-Research/grounding-dino-tiny",
        device: Optional[str] = None,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.model = GroundingDinoForObjectDetection.from_pretrained(model_id).to(
            device
        )
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.logger = logging.get_logger("transformers")
        logging.set_verbosity_info()

    def detect(self, image: Image.Image, threshold: float = 0.3) -> DinoOutput:
        """
        Detect clothing items in an image using all configured prompts.

        Args:
            image: PIL Image to process
            threshold: Confidence threshold for detections

        Returns:
            DinoOutput containing detections with both prompts and labels
        """
        # Get all prompts and their labels
        label_prompts = DinoPrompts.get_all_prompts()

        # Create mapping of prompts to labels
        prompt_to_label = {prompt: label for label, prompt in label_prompts}
        prompts = " ".join(list(prompt_to_label.keys()))

        inputs = self.processor(
            images=image,
            text=prompts,  # all clothing prompts joined into one text query
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs)

        # Post-process raw outputs into boxes and labels
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=threshold,
            text_threshold=threshold,
            target_sizes=[image.size[::-1]],
        )

        results = DinoOutput.from_model_output(results, prompt_to_label)

        if not results.detections:
            self.logger.info("No detections above threshold")
        else:
            for det in results.detections:
                self.logger.info("Detection: %s", det)

        return results
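
For reference, a minimal usage sketch for the class above (the image path is a placeholder; DinoPrompts and DinoOutput are the project-specific helpers referenced in the code):

from PIL import Image

detector = DinoDetector()
image = Image.open("example.jpg")  # placeholder image path
output = detector.detect(image, threshold=0.3)
for det in output.detections:
    print(det)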

When calling the model, I notice that query words can be joined together, leading to nonsense outputs.

Hey, could you give me a sample image and prompt for which you got a nonsense result?

This automatic crossing of queries makes the model unusable for higher-cardinality object detection.

This is a known limitation of GroundingDINO: if you have high cardinality, you have to increase the number of passes through the model, using different prompts that together contain all your classes. There's a model built to fix this, called APE, which I had plans to add to the transformers library, but unfortunately I don't have time for this right now :/
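
A rough sketch of that multi-pass workaround, reusing the same processor/model objects and the same post-processing call shape as the snippet above (the helper name and chunk size are illustrative):

import torch

def detect_in_chunks(processor, model, image, labels, device, chunk_size=4, threshold=0.3):
    """Run one forward pass per small group of classes so unrelated queries
    never share a text prompt (illustrative helper, not a library API)."""
    all_results = []
    for i in range(0, len(labels), chunk_size):
        chunk = labels[i : i + chunk_size]
        text = ". ".join(chunk) + "."  # period-separated queries, as in the thread
        inputs = processor(images=image, text=text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=threshold,
            text_threshold=threshold,
            target_sizes=[image.size[::-1]],
        )
        all_results.extend(results)  # one result dict per pass for this image
    return all_results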

Thanks for your response. I actually have a solution: the paper's repo has an optional remove_combined param in its predict method.

I've yoinked their remove_combined logic and written my own implementation of post_process_grounded_object_detection. I'll open a PR if I get some time.

    # Note: written as a method on the detector class above (hence `self.processor`);
    # it assumes `torch`, `typing.List`, and
    # `transformers.image_transforms.center_to_corners_format` are imported at module level.
    def post_process_grounded_object_detection(
        self,
        outputs,
        input_ids=None,
        box_threshold=0.25,
        text_threshold=0.25,
        target_sizes=None,
        text_labels=None,
    ):
        """
        Converts the raw output of [`GroundingDinoForObjectDetection`] into final bounding boxes in (top_left_x, top_left_y,
        bottom_right_x, bottom_right_y) format and get the associated text label.

        Args:
            outputs ([`GroundingDinoObjectDetectionOutput`]):
                Raw outputs of the model.
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                The token ids of the input text. If not provided will be taken from the model output.
            box_threshold (`float`, *optional*, defaults to 0.25):
                Threshold to keep object detection predictions based on confidence score.
            text_threshold (`float`, *optional*, defaults to 0.25):
                Score threshold to keep text detection predictions.
            target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
                Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
                `(height, width)` of each image in the batch. If unset, predictions will not be resized.
            text_labels (`List[List[str]]`, *optional*):
                List of candidate labels to be detected on each image. At the moment it's *NOT used*, but required
                to be in signature for the zero-shot object detection pipeline. Text labels are instead extracted
                from the `input_ids` tensor provided in `outputs`.

        Returns:
            `List[Dict]`: A list of dictionaries, each dictionary containing the
                - **scores**: tensor of confidence scores for detected objects
                - **boxes**: tensor of bounding boxes in [x0, y0, x1, y1] format
                - **labels**: list of text labels for each detected object (will be replaced with integer ids in v4.51.0)
                - **text_labels**: list of text labels for detected objects
        """
        batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes
        input_ids = input_ids if input_ids is not None else outputs.input_ids

        if target_sizes is not None and len(target_sizes) != len(batch_logits):
            raise ValueError(
                "Make sure that you pass in as many target sizes as the batch dimension of the logits"
            )

        batch_probs = torch.sigmoid(batch_logits)  # (batch_size, num_queries, 256)
        batch_scores = torch.max(batch_probs, dim=-1)[0]  # (batch_size, num_queries)

        # Convert to [x0, y0, x1, y1] format
        batch_boxes = center_to_corners_format(batch_boxes)

        # Convert from relative [0, 1] to absolute [0, height] coordinates
        if target_sizes is not None:
            if isinstance(target_sizes, List):
                img_h = torch.Tensor([i[0] for i in target_sizes])
                img_w = torch.Tensor([i[1] for i in target_sizes])
            else:
                img_h, img_w = target_sizes.unbind(1)

            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(
                batch_boxes.device
            )
            batch_boxes = batch_boxes * scale_fct[:, None, :]

        results = []
        for idx, (scores, boxes, probs) in enumerate(
            zip(batch_scores, batch_boxes, batch_probs)
        ):
            keep = scores > box_threshold
            scores = scores[keep]
            boxes = boxes[keep]

            # Extract text labels while respecting separator boundaries, so a
            # decoded phrase can never span two queries (the `remove_combined` idea)
            prob = probs[keep]
            objects_text_labels = []

            # Separator token positions (CLS, SEP, period) for this sample
            sep_positions = [
                i for i, token_id in enumerate(input_ids[idx]) if token_id in [101, 102, 1012]
            ]

            for single_prob in prob:
                # Position of the most strongly activated text token for this box
                max_position = single_prob.argmax().item()

                # Positions where the token probability exceeds the text threshold
                active_positions = (single_prob > text_threshold).nonzero().squeeze(1)

                # Decode only the segment (between two consecutive separators) that
                # contains the argmax token, so exactly one label is produced per
                # box and it never crosses a "." boundary
                label = ""
                for i in range(len(sep_positions) - 1):
                    start, end = sep_positions[i], sep_positions[i + 1]
                    if not (start < max_position < end):
                        continue

                    # Keep only above-threshold tokens inside this segment,
                    # excluding the separators themselves
                    mask = (active_positions > start) & (active_positions < end)
                    token_positions = active_positions[mask]
                    if token_positions.numel() > 0:
                        segment_ids = input_ids[idx][token_positions]
                        decoded_phrase = self.processor.decode(segment_ids)
                        # Clean up any remaining special tokens
                        label = (
                            decoded_phrase.replace("[CLS]", "")
                            .replace("[SEP]", "")
                            .strip()
                        )
                    break

                objects_text_labels.append(label)

            result = {
                "scores": scores,
                "boxes": boxes,
                "text_labels": objects_text_labels,
                "labels": objects_text_labels,  # Keeping both for compatibility
            }
            results.append(result)

        return results
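
If this lives as a method on the detector class from the first snippet, detect() can then call it in place of the processor's built-in post-processing, with the same arguments:

        results = self.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=threshold,
            text_threshold=threshold,
            target_sizes=[image.size[::-1]],
        )

The intent is that a box whose strongest text activation falls in the "cat" segment can only ever be labeled "cat", never "cat fish".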
