File size: 2,897 Bytes
4479f79 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import torch
import cv2
from PIL import Image
from eval.grounded_sam.grounded_sam2_florence2_autolabel_pipeline import FlorenceSAM
class ObjectDetector:
    """Wrapper around ``FlorenceSAM`` that returns per-instance image crops.

    Both public methods run ``FlorenceSAM.od_grounding_and_segmentation``
    with ``label`` as the text prompt; they differ only in post-processing.
    """

    def __init__(self, device):
        """Create the underlying FlorenceSAM model.

        Args:
            device: Device spec (e.g. ``"cuda"``). Stored on the instance as a
                ``torch.device`` and passed unchanged to ``FlorenceSAM``.
        """
        self.device = torch.device(device)
        self.detector = FlorenceSAM(device)

    def _detect(self, gen_image, label):
        """Run OD grounding + segmentation and return the instance result dict."""
        # The first element of the detector's returned pair is not needed here.
        _, instance_result_dict = self.detector.od_grounding_and_segmentation(
            image=gen_image,
            text_input=label,
        )
        return instance_result_dict

    def get_instances(self, gen_image, label, min_size=64):
        """Return size-filtered instance crops for ``label`` as PIL images.

        A crop is dropped when its area is below ``min_size * min_size`` or
        its shorter side is below ``min_size // 4`` pixels.

        Args:
            gen_image: Image handed to the detector (the demo in this file
                passes a PIL RGB image — TODO confirm other accepted types).
            label: Text prompt used for grounding.
            min_size: Reference side length in pixels for the size filter.

        Returns:
            List of ``PIL.Image.Image`` crops, in the order the detector
            produced them.
        """
        min_area = min_size * min_size
        min_side = min_size // 4
        filtered_instances = []
        for img in self._detect(gen_image, label)["instance_images"]:
            # Array images are (rows, cols, ...): height comes first.
            height, width = img.shape[:2]
            if height * width < min_area or min(height, width) < min_side:
                continue
            # NOTE(review): assumes the detector returns BGR arrays — confirm
            # against FlorenceSAM, otherwise this swaps R and B.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            filtered_instances.append(Image.fromarray(img))
        return filtered_instances

    def get_multiple_instances(self, gen_image, label, min_size=64):
        """Return the detector's raw instance result dict for ``label``.

        No size filtering or color conversion is applied. ``min_size`` is
        accepted only for signature parity with ``get_instances`` and is
        currently unused.
        """
        return self._detect(gen_image, label)
if __name__ == "__main__":
    # Manual smoke test: detect instances of `label` in a sample image and
    # save the square-padded crops under `save_dir`.
    # online demo: https://dun.163.com/trial/face/compare
    from src.train.data.data_utils import split_grid, pad_to_square
    from eval.idip.dino import DINOScore

    detector = ObjectDetector("cuda")
    # NOTE: currently unused — scoring below is stubbed (score = 1).
    dino_model = DINOScore("cuda")

    gen_image = Image.open("assets/tests/20250320-151038.jpeg").convert("RGB")
    label = "two people"
    save_dir = "tmp"
    os.makedirs(save_dir, exist_ok=True)

    # To process each grid cell separately, iterate over split_grid(gen_image)
    # instead of [gen_image].
    for i, img in enumerate([gen_image]):
        # Size threshold scales with the image width (PIL size is (w, h));
        # keep at most the first 3 crops.
        found_ips = detector.get_instances(img, label, min_size=img.size[0] // 20)[:3]
        found_ips = [pad_to_square(x) for x in found_ips]
        for j, ip in enumerate(found_ips):
            # TODO: real scoring, e.g. dino_model(reference_image, ip); this
            # needs a reference image, which the demo does not load yet.
            score = 1
            # Crops were already square-padded above; save them directly.
            ip.save(os.path.join(save_dir, f"{label}_{i}_{j}_{score}.png"))
|