import os
import cv2
import torch
import argparse
import numpy as np
import supervision as sv
from PIL import Image
import gc
import sys

from eval.grounded_sam.florence2.modeling_florence2 import Florence2ForConditionalGeneration
from eval.grounded_sam.florence2.processing_florence2 import Florence2Processor
from eval.grounded_sam.sam2.build_sam import build_sam2
from eval.grounded_sam.sam2.sam2_image_predictor import SAM2ImagePredictor


class FlorenceSAM:
    # official usage: https://huggingface.co/microsoft/Florence-2-large/blob/main/sample_inference.ipynb
    TASK_PROMPT = {
        "original": "",  # no Florence-2 task token: the caller supplies the caption directly
        "caption": "<CAPTION>",
        "detailed_caption": "<DETAILED_CAPTION>",
        "more_detailed_caption": "<MORE_DETAILED_CAPTION>",
        "object_detection": "<OD>",
        "dense_region_caption": "<DENSE_REGION_CAPTION>",
        "region_proposal": "<REGION_PROPOSAL>",
        "phrase_grounding": "<CAPTION_TO_PHRASE_GROUNDING>",
        "referring_expression_segmentation": "<REFERRING_EXPRESSION_SEGMENTATION>",
        "region_to_segmentation": "<REGION_TO_SEGMENTATION>",
        "open_vocabulary_detection": "<OPEN_VOCABULARY_DETECTION>",
        "region_to_category": "<REGION_TO_CATEGORY>",
        "region_to_description": "<REGION_TO_DESCRIPTION>",
        "ocr": "<OCR>",
        "ocr_with_region": "<OCR_WITH_REGION>",
    }

    def __init__(self, device):
        """
        Init Florence-2 and SAM 2 models.
        """
        print(f"[{self}] init on device {device}")
        self.device = torch.device(device)

        # with torch.autocast(device_type="cuda", dtype=torch.float32).__enter__()
        # self.torch_dtype = torch.float32
        # self.torch_dtype = torch.float16
        self.torch_dtype = torch.bfloat16
        try:
            if torch.cuda.get_device_properties(0).major >= 8:
                # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
                torch.backends.cuda.matmul.allow_tf32 = True
                torch.backends.cudnn.allow_tf32 = True
                # self.torch_dtype = torch.bfloat16
            # else:
            #     self.torch_dtype = torch.float16
        except Exception:
            self.torch_dtype = torch.bfloat16

        FLORENCE2_MODEL_ID = os.getenv('FLORENCE2_MODEL_PATH')
        print(f'FLORENCE2_MODEL_ID is {FLORENCE2_MODEL_ID}')
        SAM2_CHECKPOINT = os.getenv('SAM2_MODEL_PATH')
        SAM2_CONFIG = "configs/sam2.1/sam2.1_hiera_l.yaml"

        self.florence2_model = Florence2ForConditionalGeneration.from_pretrained(
            FLORENCE2_MODEL_ID,
            trust_remote_code=True,
            local_files_only=True,
            torch_dtype=self.torch_dtype,
        ).eval().to(self.device)
        self.florence2_processor = Florence2Processor.from_pretrained(
            FLORENCE2_MODEL_ID,
            trust_remote_code=True,
            local_files_only=True,
        )

        sam2_model = build_sam2(SAM2_CONFIG, SAM2_CHECKPOINT, device=self.device)
        self.sam2_predictor = SAM2ImagePredictor(sam2_model)

    def __str__(self):
        return "FlorenceSAM"

    @torch.no_grad()
    def run_florence2(self, task_prompt, text_input, image):
        model = self.florence2_model
        processor = self.florence2_processor
        device = self.device
        assert model is not None, "You should pass the init florence-2 model here"
        assert processor is not None, "You should set florence-2 processor here"

        with torch.autocast(device_type="cuda", dtype=torch.float32):
            if text_input is None:
                prompt = task_prompt
            else:
                prompt = task_prompt + text_input

            inputs = processor(
                text=prompt,
                images=image,
                max_length=1024,
                truncation=True,
                return_tensors="pt",
            ).to(device, self.torch_dtype)
            # inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, self.torch_dtype)
            generated_ids = model.generate(
                input_ids=inputs["input_ids"].to(device),
                pixel_values=inputs["pixel_values"].to(device),
                # max_new_tokens=1024,
                max_new_tokens=768,
                early_stopping=False,
                do_sample=False,
                num_beams=3,
            )
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            # parse the raw generation into the task-specific structured result
            parsed_answer = processor.post_process_generation(
                generated_text,
                task=task_prompt,
                image_size=(image.width, image.height),
            )
        return parsed_answer

    def caption(self, image, caption_task_prompt='<CAPTION>'):
        assert caption_task_prompt in ["<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>"]
        caption_results = self.run_florence2(caption_task_prompt, None, image)
        text_input = caption_results[caption_task_prompt]
        caption = text_input
        return caption

    def segmentation(self, image, input_boxes, seg_model="sam"):
        if seg_model == "sam":
            with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float32):
                sam2_predictor = self.sam2_predictor
                sam2_predictor.set_image(np.array(image))
                masks, scores, logits = sam2_predictor.predict(
                    point_coords=None,
                    point_labels=None,
                    box=input_boxes,
                    multimask_output=False,
                )
            # squeeze the single-mask dimension so masks are (n, H, W) and scores are (n,)
            if masks.ndim == 4:
                masks = masks.squeeze(1)
            if scores.ndim == 2:
                scores = scores.squeeze(1)
        else:
            raise NotImplementedError()
        return masks, scores

    def post_process_results(self, image, caption, labels, detections, output_dir=None):
        result_dict = {
            "caption": caption,
            "instance_images": [],
            "instance_labels": [],
            "instance_bboxes": [],
            "instance_mask_scores": [],
        }
        if detections is None:
            return detections, result_dict

        if output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)

        cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)

        box_annotator = sv.BoxAnnotator()
        annotated_frame = box_annotator.annotate(scene=cv_image.copy(), detections=detections)
        label_annotator = sv.LabelAnnotator(text_position=sv.Position.CENTER)
        annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
        if output_dir is not None:
            cv2.imwrite(os.path.join(output_dir, "detections.jpg"), annotated_frame)

        mask_annotator = sv.MaskAnnotator()
        annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
        if output_dir is not None:
            cv2.imwrite(os.path.join(output_dir, "masks.jpg"), annotated_frame)

        for detection in detections:
            xyxy, mask, confidence, class_id, tracker_id, data = detection
            label = labels[class_id]
            cropped_img = sv.crop_image(image=cv_image, xyxy=xyxy)
            if output_dir is not None:
                cv2.imwrite(os.path.join(output_dir, f"cropped_image_{label}.jpg"), cropped_img)
            if mask is None:
                result_dict["instance_mask_scores"].append(0)
                result_dict["instance_images"].append(cropped_img)
            else:
                mask = np.repeat(mask[..., np.newaxis], 3, axis=-1)
                masked_img = np.where(mask, cv_image, 255)
                cropped_masked_img = sv.crop_image(image=masked_img, xyxy=xyxy)
                result_dict["instance_mask_scores"].append(confidence.item())
                result_dict["instance_images"].append(cropped_masked_img)
            result_dict["instance_labels"].append(label)
            result_dict["instance_bboxes"].append(xyxy)
            if output_dir is not None and mask is not None:
                # cropped_masked_img only exists when a mask was predicted
                cv2.imwrite(os.path.join(output_dir, f"masked_image_{label}.jpg"), cropped_masked_img)

        torch.cuda.empty_cache()
        gc.collect()
        return detections, result_dict

    def caption_phrase_grounding_and_segmentation(
        self,
        image,
        seg_model="sam",
        caption_task_prompt='<CAPTION>',
        original_caption=None,
        output_dir=None
    ):
        # "" corresponds to the "original" mode, where the caller supplies the caption
        assert caption_task_prompt in ["", "<CAPTION>", "<DETAILED_CAPTION>", "<MORE_DETAILED_CAPTION>"]
        assert seg_model in ["sam", "florence2"]

        # image caption
        if caption_task_prompt == "":
            assert original_caption is not None
            caption = original_caption
        else:
            caption_results = self.run_florence2(caption_task_prompt, None, image)
            text_input = caption_results[caption_task_prompt]
            caption = text_input

        # phrase grounding
        grounding_results = self.run_florence2('<CAPTION_TO_PHRASE_GROUNDING>', caption, image)['<CAPTION_TO_PHRASE_GROUNDING>']
        input_boxes = np.array(grounding_results["bboxes"])
        class_names = grounding_results["labels"]
        class_ids = np.array(list(range(len(class_names))))

        # segmentation
        masks, scores = self.segmentation(image, input_boxes, seg_model)
        labels = [f"{class_name}" for class_name in class_names]
        detections = sv.Detections(
            xyxy=input_boxes,
            mask=masks.astype(bool),
            class_id=class_ids,
            confidence=scores,
        )
        return self.post_process_results(image, caption, labels, detections, output_dir)

    def od_grounding_and_segmentation(
        self,
        image,
        text_input,
        seg_model="sam",
        output_dir=None
    ):
        assert seg_model in ["sam", "florence2"]

        # od grounding
        grounding_results = self.run_florence2('<OPEN_VOCABULARY_DETECTION>', text_input, image)['<OPEN_VOCABULARY_DETECTION>']
        if len(grounding_results["bboxes"]) == 0:
            detections = None
            labels = []
        else:
            input_boxes = np.array(grounding_results["bboxes"])
            class_names = grounding_results["bboxes_labels"]
            class_ids = np.array(list(range(len(class_names))))

            # segmentation
            masks, scores = self.segmentation(image, input_boxes, seg_model)
            labels = [f"{class_name}" for class_name in class_names]
            detections = sv.Detections(
                xyxy=input_boxes,
                mask=masks.astype(bool),
                class_id=class_ids,
                confidence=scores,
            )
        return self.post_process_results(image, text_input, labels, detections, output_dir)

    def od_grounding(
        self,
        image,
        text_input,
        output_dir=None
    ):
        # od grounding
        grounding_results = self.run_florence2('<OPEN_VOCABULARY_DETECTION>', text_input, image)['<OPEN_VOCABULARY_DETECTION>']
        if len(grounding_results["bboxes"]) == 0:
            detections = None
            labels = []
        else:
            input_boxes = np.array(grounding_results["bboxes"])
            class_names = grounding_results["bboxes_labels"]
            class_ids = np.array(list(range(len(class_names))))
            labels = [f"{class_name}" for class_name in class_names]
            detections = sv.Detections(
                xyxy=input_boxes,
                class_id=class_ids,
            )
        return self.post_process_results(image, text_input, labels, detections, output_dir)

    def phrase_grounding_and_segmentation(
        self,
        image,
        text_input,
        seg_model="sam",
        output_dir=None
    ):
        assert seg_model in ["sam", "florence2"]

        # phrase grounding
        grounding_results = self.run_florence2('<CAPTION_TO_PHRASE_GROUNDING>', text_input, image)['<CAPTION_TO_PHRASE_GROUNDING>']
        input_boxes = np.array(grounding_results["bboxes"])
        class_names = grounding_results["labels"]
        # print(f"[phrase_grounding_and_segmentation] input_label={text_input}, output_label={class_names}")
        class_ids = np.array(list(range(len(class_names))))

        # segmentation
        masks, scores = self.segmentation(image, input_boxes, seg_model)
        labels = [f"{class_name}" for class_name in class_names]
        detections = sv.Detections(
            xyxy=input_boxes,
            mask=masks.astype(bool),
            class_id=class_ids,
            confidence=scores,
        )
        return self.post_process_results(image, text_input, labels, detections, output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Grounded SAM 2 Florence-2 Demos", add_help=True)
    parser.add_argument("--image_path", type=str, default="./notebooks/images/cars.jpg",
                        required=True, help="path to image file")
    parser.add_argument("--caption_type", type=str, default="caption",
                        required=False, help="granularity of caption")
    args = parser.parse_args()

    # IMAGE_PATH = args.image_path
    PIPELINE = "caption_to_phrase_grounding"
    CAPTION_TYPE = args.caption_type
    assert CAPTION_TYPE in ["caption", "detailed_caption", "more_detailed_caption", "original"]

    print(f"Running pipeline: {PIPELINE} now.")

    pipeline = FlorenceSAM("cuda:0")

    from glob import glob
    from tqdm import tqdm
    for image_path in tqdm(glob("/mnt/bn/lq-prompt-alignment/personal/chenbowen/code/IPVerse/prompt_alignment/Grounded-SAM-2/notebooks/images/*") * 3):
    # for image_path in tqdm(glob("/mnt/bn/lq-prompt-alignment/personal/chenbowen/code/IPVerse/prompt_alignment/Grounded-SAM-2/outputs/gcg_pipeline/00001.tar_debug/*.png")):
        print(pipeline.TASK_PROMPT, CAPTION_TYPE)
        image = Image.open(image_path).convert("RGB")
        pipeline.caption_phrase_grounding_and_segmentation(
            image=image,
            seg_model="sam",
            caption_task_prompt=pipeline.TASK_PROMPT[CAPTION_TYPE],
            output_dir=f"./outputs/{os.path.basename(image_path)}"
        )
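
# Usage sketch (not executed): the demo above only exercises
# caption_phrase_grounding_and_segmentation. The snippet below is a minimal, assumed
# example of the open-vocabulary detection path; the image path, text prompt, and
# output directory are illustrative placeholders, and it relies on the same
# FLORENCE2_MODEL_PATH / SAM2_MODEL_PATH environment variables as __init__.
#
#   pipeline = FlorenceSAM("cuda:0")
#   image = Image.open("./notebooks/images/cars.jpg").convert("RGB")
#   detections, result_dict = pipeline.od_grounding_and_segmentation(
#       image=image,
#       text_input="car",            # free-form phrase passed to <OPEN_VOCABULARY_DETECTION>
#       seg_model="sam",
#       output_dir="./outputs/od_demo",
#   )
#   # per-instance labels, bboxes, masked crops, and mask scores live in result_dict
#   print(result_dict["instance_labels"], result_dict["instance_bboxes"])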