ScouterAI / modal_apps /segment_anything.py
stevenbucaille's picture
Enhance app.py with improved user interface and instructions, update model ID in llm.py, and add image classification capabilities across various components. Introduce segment anything functionality and refine README for clarity on model capabilities.
518d841
from typing import List, Tuple
import modal
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import SamHQModel, SamHQProcessor
from .app import app
from .image import image
def get_detections_from_segment_anything(bounding_boxes, list_of_masks, iou_scores):
    """Bundle segmentation results into a single ``sv.Detections`` object.

    Every argument is converted with ``np.array`` and entries are matched by
    position: box ``i``, mask ``i`` and score ``i`` describe the same object.

    Args:
        bounding_boxes: One box per object; stored as ``xyxy``.
        list_of_masks: One mask per box; stored as ``mask``.
        iou_scores: One score per box; stored as ``confidence``.

    Returns:
        The ``sv.Detections`` built from the converted arrays.
    """
    boxes_xyxy = np.array(bounding_boxes)
    masks = np.array(list_of_masks)
    # class_id is simply the positional index of each box (0 .. n-1).
    positional_ids = np.array(list(range(len(bounding_boxes))))
    scores = np.array(iou_scores)
    return sv.Detections(
        xyxy=boxes_xyxy,
        mask=masks,
        class_id=positional_ids,
        confidence=scores,
    )
@app.cls(gpu="T4", image=image)
class SegmentAnythingModalApp:
    """Modal class that runs box-prompted SAM-HQ segmentation on a T4 GPU."""

    @modal.enter()
    def setup(self):
        """Load the SAM-HQ model and its processor.

        Registered through ``modal.enter()``. The model is moved to CUDA and
        switched to eval mode; both objects are stored on ``self`` for use by
        ``forward``.
        """
        # Alternative checkpoint: "syscv-community/sam-hq-vit-base"
        checkpoint = "syscv-community/sam-hq-vit-huge"
        self.model = SamHQModel.from_pretrained(checkpoint).to("cuda").eval()
        self.processor: SamHQProcessor = SamHQProcessor.from_pretrained(checkpoint)

    @modal.method()
    def forward(self, image, bounding_boxes: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Segment the content of each bounding box in a single image.

        Args:
            image: The image handed to the processor.
            bounding_boxes: Box prompts for that image, one inner list per box.

        Returns:
            A ``(masks, iou_scores)`` tuple of NumPy arrays holding, for each
            box, the first mask candidate and the first IoU score.
        """
        # The boxes are wrapped in an outer list before being passed as
        # ``input_boxes``; only one image is processed per call.
        model_inputs = self.processor(
            image,
            input_boxes=[bounding_boxes],
            return_tensors="pt",
        ).to("cuda")

        with torch.no_grad():
            outputs = self.model(**model_inputs)

        per_image_masks = self.processor.post_process_masks(
            outputs.pred_masks,
            model_inputs["original_sizes"],
            model_inputs["reshaped_input_sizes"],
        )

        # [0] selects the only image; [:, 0] keeps candidate 0 for every box.
        # Assumes iou_scores is laid out as (image, box, candidate) -- TODO confirm.
        masks = per_image_masks[0][:, 0].cpu().numpy()
        scores = outputs.iou_scores[0, :, 0].cpu().numpy()
        return masks, scores
if __name__ == "__main__":
    # Manual smoke test: segment a few hand-picked boxes on a local image via
    # the deployed Modal class and save the annotated result.
    # ``sv`` and ``Image`` come from the module-level imports.
    remote_segmenter = modal.Cls.from_name(app.name, SegmentAnythingModalApp.__name__)()

    source_image = Image.open("images/image.png")
    box_prompts = [
        [2449, 2021, 2758, 2359],
        [436, 1942, 1002, 2193],
        [2785, 1945, 3259, 2374],
        [3285, 1996, 3721, 2405],
        [1968, 2035, 2451, 2474],
        [1741, 1909, 2098, 2320],
    ]

    masks, iou_scores = remote_segmenter.forward.remote(
        image=source_image,
        bounding_boxes=box_prompts,
    )
    detections = get_detections_from_segment_anything(box_prompts, masks, iou_scores)

    annotated_image = sv.MaskAnnotator().annotate(scene=source_image, detections=detections)
    annotated_image.save("images/resized_image_with_detections_annotated.png")