File size: 1,775 Bytes
ab854b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
# Ultralytics YOLO 🚀, AGPL-3.0 license
"""
SAM model interface
"""
from pathlib import Path
from ultralytics.engine.model import Model
from ultralytics.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor
class SAM(Model):
"""
SAM model interface.
"""
def __init__(self, model='sam_b.pt') -> None:
if model and Path(model).suffix not in ('.pt', '.pth'):
raise NotImplementedError('SAM prediction requires pre-trained *.pt or *.pth model.')
super().__init__(model=model, task='segment')
def _load(self, weights: str, task=None):
self.model = build_sam(weights)
def predict(self, source, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Predicts and returns segmentation masks for given image or video source."""
overrides = dict(conf=0.25, task='segment', mode='predict', imgsz=1024)
kwargs.update(overrides)
prompts = dict(bboxes=bboxes, points=points, labels=labels)
return super().predict(source, stream, prompts=prompts, **kwargs)
def __call__(self, source=None, stream=False, bboxes=None, points=None, labels=None, **kwargs):
"""Calls the 'predict' function with given arguments to perform object detection."""
return self.predict(source, stream, bboxes, points, labels, **kwargs)
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
@property
def task_map(self):
return {'segment': {'predictor': Predictor}}
|