Avoid concatenating query words?
Hello fellow HuggingFacers
When calling the model I notice that query words can be joined together, leading to nonsense outputs.
With an input like:
cat. dog. fish. red.
Along with the desired detections that directly match my queries, I also get undesired detections like: cat fish. red dog.
This automatic crossing of queries makes the model unusable for higher-cardinality object detection. This issue is also discussed here: https://github.com/IDEA-Research/GroundingDINO/issues/85
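For reference, the query separators are clearly present after tokenization (a quick check with the tiny checkpoint's tokenizer), so the merging seems to happen in post-processing rather than in tokenization:

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-tiny")
ids = processor.tokenizer("cat. dog. fish. red.")["input_ids"]
print(processor.tokenizer.convert_ids_to_tokens(ids))
# expected: ['[CLS]', 'cat', '.', 'dog', '.', 'fish', '.', 'red', '.', '[SEP]']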
I'm calling the model as follows:
from typing import Optional

import torch
from PIL import Image
from transformers import AutoProcessor, GroundingDinoForObjectDetection
from transformers.utils import logging

# DinoPrompts and DinoOutput are project-specific helpers (not shown here)


class DinoDetector:
    """Handles object detection using the Grounding DINO model."""

    def __init__(
        self,
        model_id: str = "IDEA-Research/grounding-dino-tiny",
        device: Optional[str] = None,
    ):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model = GroundingDinoForObjectDetection.from_pretrained(model_id).to(device)
        self.processor = AutoProcessor.from_pretrained(model_id)
        self.logger = logging.get_logger("transformers")
        logging.set_verbosity_info()

    def detect(self, image: Image.Image, threshold: float = 0.3) -> DinoOutput:
        """
        Detect clothing items in an image using all configured prompts.

        Args:
            image: PIL Image to process
            threshold: Confidence threshold for detections

        Returns:
            DinoOutput containing detections with both prompts and labels
        """
        # Get all prompts and their labels
        label_prompts = DinoPrompts.get_all_prompts()
        # Create mapping of prompts to labels
        prompt_to_label = {prompt: label for label, prompt in label_prompts}
        prompts = " ".join(prompt_to_label.keys())  # Use all clothing prompts

        inputs = self.processor(images=image, text=prompts, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # Run detection
        results = self.processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=threshold,
            text_threshold=threshold,
            target_sizes=[image.size[::-1]],
        )
        results = DinoOutput.from_model_output(results, prompt_to_label)

        if not results.detections:
            self.logger.info("No detections above threshold")
        else:
            for det in results.detections:
                self.logger.info("Detection: %s", det)
        return results
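A typical call looks like this (the image path is just a placeholder):

from PIL import Image

detector = DinoDetector()
image = Image.open("example.jpg").convert("RGB")
output = detector.detect(image, threshold=0.3)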
When calling the model I notice that query words can be joined together, leading to nonsense outputs.
Hey, could you share a sample image and prompt for which you got a nonsense result?
This automatic crossing of queries makes the model unusable for higher-cardinality object detection.
This is a known limitation of GroundingDINO.
If you have high cardinality, you have to increase the number of passes through the model, with each prompt containing a subset of your classes so that all of them are covered; a rough sketch is below. There's a model called APE that was built to fix this,
which I had plans to add to the transformers library, but unfortunately I don't have time for this right now :/
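Roughly something like this hypothetical detect_in_chunks helper (the chunk size is arbitrary, and model/processor are the same objects you already create in your snippet):

import torch

def detect_in_chunks(model, processor, image, class_names, chunk_size=5, threshold=0.3, device="cpu"):
    """Run one forward pass per chunk of classes and collect the per-chunk results."""
    all_results = []
    for i in range(0, len(class_names), chunk_size):
        chunk = class_names[i : i + chunk_size]
        text = ". ".join(chunk) + "."  # Grounding DINO expects period-separated queries
        inputs = processor(images=image, text=text, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=threshold,
            text_threshold=threshold,
            target_sizes=[image.size[::-1]],
        )[0]
        all_results.append(results)
    return all_results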
Thanks for your response. I actually have a solution: the paper's repo has an optional remove_combined param in its predict method.
I've yoinked their remove_combined logic and written my own implementation of post_process_grounded_object_detection. I'll open a PR if I get some time.
import torch
from transformers.image_transforms import center_to_corners_format


# This lives as a method on my detector class, next to self.processor (a GroundingDinoProcessor).
def post_process_grounded_object_detection(
    self,
    outputs,
    input_ids=None,
    box_threshold=0.25,
    text_threshold=0.25,
    target_sizes=None,
    text_labels=None,
):
    """
    Converts the raw output of [`GroundingDinoForObjectDetection`] into final bounding boxes in
    (top_left_x, top_left_y, bottom_right_x, bottom_right_y) format and gets the associated text labels.

    Args:
        outputs ([`GroundingDinoObjectDetectionOutput`]):
            Raw outputs of the model.
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            The token ids of the input text. If not provided, they will be taken from the model output.
        box_threshold (`float`, *optional*, defaults to 0.25):
            Threshold to keep object detection predictions based on confidence score.
        text_threshold (`float`, *optional*, defaults to 0.25):
            Score threshold to keep text detection predictions.
        target_sizes (`torch.Tensor` or `List[Tuple[int, int]]`, *optional*):
            Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size
            `(height, width)` of each image in the batch. If unset, predictions will not be resized.
        text_labels (`List[List[str]]`, *optional*):
            List of candidate labels to be detected on each image. At the moment it is *NOT used*, but required
            to be in the signature for the zero-shot object detection pipeline. Text labels are instead extracted
            from the `input_ids` tensor provided in `outputs`.

    Returns:
        `List[Dict]`: A list of dictionaries, each dictionary containing the
        - **scores**: tensor of confidence scores for detected objects
        - **boxes**: tensor of bounding boxes in [x0, y0, x1, y1] format
        - **labels**: list of text labels for each detected object (will be replaced with integer ids in v4.51.0)
        - **text_labels**: list of text labels for detected objects
    """
    batch_logits, batch_boxes = outputs.logits, outputs.pred_boxes
    input_ids = input_ids if input_ids is not None else outputs.input_ids

    if target_sizes is not None and len(target_sizes) != len(batch_logits):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")

    batch_probs = torch.sigmoid(batch_logits)  # (batch_size, num_queries, 256)
    batch_scores = torch.max(batch_probs, dim=-1)[0]  # (batch_size, num_queries)

    # Convert to [x0, y0, x1, y1] format
    batch_boxes = center_to_corners_format(batch_boxes)

    # Convert from relative [0, 1] to absolute [0, height] coordinates
    if target_sizes is not None:
        if isinstance(target_sizes, (list, tuple)):
            img_h = torch.Tensor([i[0] for i in target_sizes])
            img_w = torch.Tensor([i[1] for i in target_sizes])
        else:
            img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(batch_boxes.device)
        batch_boxes = batch_boxes * scale_fct[:, None, :]

    results = []
    for idx, (scores, boxes, probs) in enumerate(zip(batch_scores, batch_boxes, batch_probs)):
        keep = scores > box_threshold
        scores = scores[keep]
        boxes = boxes[keep]
        prob = probs[keep]

        # Find separator token positions ([CLS]=101, [SEP]=102, "."=1012); these mark the query boundaries
        sep_positions = [
            i for i, token_id in enumerate(input_ids[idx]) if token_id in [101, 102, 1012]
        ]

        # Extract text labels respecting the query boundaries
        objects_text_labels = []
        for single_prob in prob:
            # Find token positions where the probability exceeds the text threshold
            active_positions = (single_prob > text_threshold).nonzero().squeeze(1)
            # Decode only the segments between separators that contain activations,
            # so tokens from different queries are never merged into one phrase
            box_phrases = []
            for i in range(len(sep_positions) - 1):
                start, end = sep_positions[i], sep_positions[i + 1]
                # Check if we have activations in this segment
                mask = (active_positions >= start) & (active_positions < end)
                if mask.any():
                    # Decode the ids for this segment only
                    segment_ids = input_ids[idx][start:end]
                    decoded_phrase = self.processor.decode(segment_ids)
                    # Clean up the decoded phrase (separator tokens and stray periods)
                    cleaned_phrase = decoded_phrase.replace("[CLS]", "").replace("[SEP]", "").strip(" .")
                    if cleaned_phrase:
                        box_phrases.append(cleaned_phrase)
            # In practice a kept box activates a single segment; join just in case so that
            # there is exactly one label per box and labels stay aligned with scores and boxes
            objects_text_labels.append(" ".join(box_phrases))

        result = {
            "scores": scores,
            "boxes": boxes,
            "text_labels": objects_text_labels,
            "labels": objects_text_labels,  # Keeping both for compatibility
        }
        results.append(result)

    return results
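In my detect() method I then call this custom version instead of the processor's, roughly like:

results = self.post_process_grounded_object_detection(
    outputs,
    inputs.input_ids,
    box_threshold=threshold,
    text_threshold=threshold,
    target_sizes=[image.size[::-1]],
)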