# modified from https://github.com/IDEA-Research/GroundingDINO/blob/main/demo/test_ap_on_coco.py
import argparse
import json
import os
import time

import cv2
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader

import groundingdino.datasets.transforms as T
from groundingdino.datasets.cocogrounding_eval import CocoGroundingEvaluator
from groundingdino.models import build_model
from groundingdino.util import box_ops, get_tokenlizer
from groundingdino.util.misc import clean_state_dict, collate_fn
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.vl_utils import build_captions_and_token_span, create_positive_map_from_span

# segment anything
from segment_anything import (
    build_sam,
    build_sam_hq,
    SamPredictor,
)
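
# Example invocation (paths are illustrative; the dataset name is read from the
# third-to-last component of --anno_path, matching the save logic below):
#   python test_seginw.py -c <groundingdino_config.py> -p <groundingdino_ckpt.pth> \
#       --anno_path data/seginw/<Dataset>/valid/annotations.json \
#       --image_dir data/seginw/<Dataset>/valid/ --save_json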

def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    # load on CPU first and let the caller move the model to the target device;
    # strict=False tolerates keys that differ between checkpoint and model
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    model.eval()
    return model
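
# Example (file names are illustrative):
#   model = load_model("GroundingDINO_SwinT_OGC.py", "groundingdino_swint_ogc.pth")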

class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, img_folder, ann_file, transforms):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms

    def __getitem__(self, idx):
        img, target = super().__getitem__(idx)  # target: list of annotation dicts
        w, h = img.size
        boxes = [obj["bbox"] for obj in target]
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]  # xywh -> xyxy
        boxes[:, 0::2].clamp_(min=0, max=w)
        boxes[:, 1::2].clamp_(min=0, max=h)
        # filter out invalid (zero-area) boxes
        keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
        boxes = boxes[keep]

        target_new = {}
        image_id = self.ids[idx]
        target_new["image_id"] = image_id
        target_new["boxes"] = boxes
        target_new["orig_size"] = torch.as_tensor([int(h), int(w)])
        target_new["file_path"] = self.coco.imgs[image_id]["file_name"]

        target = target_new
        if self._transforms is not None:
            img, target = self._transforms(img, target)
        return img, target
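
# Worked example of the box handling above: a COCO bbox [x, y, w, h] = [10, 20, 30, 40]
# becomes [10, 20, 40, 60] in xyxy form, is clamped to the image bounds, and is kept
# only if it still has positive width and height.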

class PostProcessSeginw(nn.Module):
    """This module converts the model's output into the format expected by the COCO api."""

    def __init__(self, num_select=300, coco_api=None, tokenlizer=None) -> None:
        super().__init__()
        self.num_select = num_select

        assert coco_api is not None
        category_dict = coco_api.dataset["categories"]
        cat_list = [item["name"] for item in category_dict]
        captions, cat2tokenspan = build_captions_and_token_span(cat_list, False)
        tokenspanlist = [cat2tokenspan[cat] for cat in cat_list]
        positive_map = create_positive_map_from_span(
            tokenlizer(captions), tokenspanlist)  # (num_categories, 256), normalized
        self.positive_map = positive_map
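
    # positive_map has one row per category: row i is nonzero exactly at the caption
    # token positions covered by category i's name (256 is the max token length), so
    # multiplying token-level probabilities by positive_map.T in forward() pools them
    # into per-category scores.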
    @torch.no_grad()
    def forward(self, outputs, target_sizes, not_to_xyxy=False):
        """Perform the computation.

        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of
                each image of the batch.
                For evaluation, this must be the original image size (before any data
                augmentation); for visualization, this should be the image size after
                data augmentation, but before padding.
        """
        num_select = self.num_select
        out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]

        # pool token-level logits into per-category scores
        prob_to_token = out_logits.sigmoid()  # (bs, num_queries, 256)
        pos_maps = self.positive_map.to(prob_to_token.device)
        # (bs, num_queries, 256) @ (num_categories, 256).T -> (bs, num_queries, num_categories)
        prob_to_label = prob_to_token @ pos_maps.T

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        # top-k over the flattened (query, category) grid
        prob = prob_to_label
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1), num_select, dim=1)
        scores = topk_values
        topk_boxes = topk_indexes // prob.shape[2]
        labels = topk_indexes % prob.shape[2]

        if not_to_xyxy:
            boxes = out_bbox
        else:
            boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        boxes = torch.gather(
            boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        boxes = boxes * scale_fct[:, None, :]

        results = [{"scores": s, "labels": l, "boxes": b}
                   for s, l, b in zip(scores, labels, boxes)]
        return results
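
# Worked example for the top-k decoding above: with 100 queries and 5 categories,
# prob.view(bs, -1) has 500 entries per image; a flat top-k index of 23 decodes to
# query 23 // 5 = 4 and label 23 % 5 = 3.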

def main(args):
    # config
    cfg = SLConfig.fromfile(args.config_file)

    # build model
    model = load_model(args.config_file, args.checkpoint_path)
    model = model.to(args.device)
    model = model.eval()

    # build dataloader
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
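    # Note: RandomResize([800], max_size=1333) is deterministic here (a single scale),
    # and the Normalize constants are the standard ImageNet mean/std.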
    dataset = CocoDetection(
        args.image_dir, args.anno_path, transforms=transform)
    data_loader = DataLoader(
        dataset, batch_size=1, shuffle=False, num_workers=args.num_workers,
        collate_fn=collate_fn)

    # build post processor
    tokenlizer = get_tokenlizer.get_tokenlizer(cfg.text_encoder_type)
    postprocessor = PostProcessSeginw(
        num_select=args.num_select, coco_api=dataset.coco, tokenlizer=tokenlizer)

    # build evaluator
    evaluator = CocoGroundingEvaluator(
        dataset.coco, iou_types=("bbox", "segm"), useCats=True)

    # build captions
    category_dict = dataset.coco.dataset["categories"]
    cat_list = [item["name"] for item in category_dict]
    caption = " . ".join(cat_list) + " ."
    print("Input text prompt:", caption)

    # SAM
    if args.use_sam_hq:
        predictor = SamPredictor(build_sam_hq(checkpoint=args.sam_hq_checkpoint).to(args.device))
    else:
        predictor = SamPredictor(build_sam(checkpoint=args.sam_checkpoint).to(args.device))
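
    # Standard SamPredictor flow in the loop below: set_image() embeds each image
    # once, then predict_torch() prompts SAM with the (transformed) detector boxes.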
    json_file = []

    # run inference
    start = time.time()
    for i, (images, targets) in enumerate(data_loader):
        # get images and captions
        images = images.tensors.to(args.device)
        bs = images.shape[0]
        assert bs == 1
        input_captions = [caption] * bs

        # feed to the model
        with torch.no_grad():
            outputs = model(images, captions=input_captions)

        orig_target_sizes = torch.stack(
            [t["orig_size"] for t in targets], dim=0).to(images.device)
        results = postprocessor(outputs, orig_target_sizes)

        # prompt SAM with the detected boxes
        sam_image = cv2.imread(os.path.join(args.image_dir, targets[0]["file_path"]))
        sam_image = cv2.cvtColor(sam_image, cv2.COLOR_BGR2RGB)
        predictor.set_image(sam_image)

        input_boxes = results[0]["boxes"].cpu()
        transformed_boxes = predictor.transform.apply_boxes_torch(
            input_boxes, sam_image.shape[:2]).to(args.device)
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )
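        # with multimask_output=False, masks has shape (num_boxes, 1, H, W), dtype bool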
        results[0]["masks"] = masks.cpu().numpy()

        cocogrounding_res = {
            target["image_id"]: output for target, output in zip(targets, results)}
        save_items = evaluator.update(cocogrounding_res)

        if args.save_json:
            new_items = list()
            for item in save_items:
                new_item = dict()
                new_item["image_id"] = item["image_id"]
                new_item["category_id"] = item["category_id"]
                new_item["segmentation"] = item["segmentation"]
                new_item["score"] = item["score"]
                new_items.append(new_item)
            json_file.extend(new_items)

        if (i + 1) % 30 == 0:
            used_time = time.time() - start
            eta = len(data_loader) / (i + 1) * used_time - used_time
            print(
                f"processed {i + 1}/{len(data_loader)} images. "
                f"time: {used_time:.2f}s, ETA: {eta:.2f}s")

    evaluator.synchronize_between_processes()
    evaluator.accumulate()
    evaluator.summarize()
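
    # coco_eval["segm"].stats is the standard 12-number COCO summary vector;
    # stats[0] is the segmentation AP averaged over IoU 0.50:0.95.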
print("Final results:", evaluator.coco_eval["segm"].stats.tolist()) | |
if args.save_json: | |
if args.use_sam_hq: | |
os.makedirs('seginw_output/sam_hq/', exist_ok=True) | |
save_path = 'seginw_output/sam_hq/seginw-'+args.anno_path.split('/')[-3]+'_val.json' | |
else: | |
os.makedirs('seginw_output/sam/', exist_ok=True) | |
save_path = 'seginw_output/sam/seginw-'+args.anno_path.split('/')[-3]+'_val.json' | |
with open(save_path,'w') as f: | |
json.dump(json_file,f) | |
print(save_path) | |

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        "Grounding DINO + SAM eval on SegInW", add_help=True)
    # load model
    parser.add_argument("--config_file", "-c", type=str,
                        required=True, help="path to config file")
    parser.add_argument(
        "--checkpoint_path", "-p", type=str, required=True, help="path to checkpoint file"
    )
    parser.add_argument("--device", type=str, default="cuda",
                        help="running device (default: cuda)")

    # post processing
    parser.add_argument("--num_select", type=int, default=100,
                        help="number of top-k predictions to keep")

    # dataset info
    parser.add_argument("--anno_path", type=str,
                        required=True, help="path to the COCO-format annotation json")
    parser.add_argument("--image_dir", type=str,
                        required=True, help="path to the image directory")
    parser.add_argument("--num_workers", type=int, default=4,
                        help="number of workers for dataloader")

    # SAM
    parser.add_argument(
        "--sam_checkpoint", type=str, default="pretrained_checkpoint/sam_vit_h_4b8939.pth",
        help="path to sam checkpoint file"
    )
    parser.add_argument(
        "--sam_hq_checkpoint", type=str, default="pretrained_checkpoint/sam_hq_vit_h.pth",
        help="path to sam-hq checkpoint file"
    )
    parser.add_argument(
        "--use_sam_hq", action="store_true", help="use sam-hq for prediction"
    )

    # save json results
    parser.add_argument(
        "--save_json", action="store_true", help="save json results for evaluation"
    )

    args = parser.parse_args()
    main(args)