Spaces:
Build error
Build error
import numpy.typing as npt | |
import time | |
# from .sam import build_sam, SamPredictor | |
# from .sam_hq import build_sam as build_sam_hq, SamPredictor as SamHqPredictor | |
# from .mobile_sam import ( | |
# build_sam_vit_t as build_mobile_sam, | |
# SamPredictor as MobileSamPredictor, | |
# ) | |
# from .per_sam import train, PerSAM | |
# from .configs import DEVICE | |
from app.sam import build_sam, SamPredictor | |
from app.sam_hq import build_sam as build_sam_hq, SamPredictor as SamHqPredictor | |
from app.mobile_sam import ( | |
build_sam_vit_t as build_mobile_sam, | |
SamPredictor as MobileSamPredictor, | |
) | |
from app.per_sam import train, PerSAM | |
from app.configs import DEVICE | |
def build_sam_predictor(checkpoint: str | None = None):
    """Return a ``SamPredictor`` wrapping a SAM model placed on ``DEVICE``.

    ``checkpoint`` is forwarded unchanged to ``build_sam``; how a ``None``
    value is handled is decided by that builder.
    """
    return SamPredictor(build_sam(checkpoint).to(DEVICE))
def build_sam_hq_predictor(checkpoint: str | None = None):
    """Return a ``SamHqPredictor`` wrapping a SAM-HQ model placed on ``DEVICE``.

    ``checkpoint`` is forwarded unchanged to ``build_sam_hq``; how a ``None``
    value is handled is decided by that builder.
    """
    hq_model = build_sam_hq(checkpoint).to(DEVICE)
    return SamHqPredictor(hq_model)
def build_mobile_sam_predictor(checkpoint: str | None = None):
    """Return a ``MobileSamPredictor`` wrapping a MobileSAM model on ``DEVICE``.

    ``checkpoint`` is forwarded unchanged to ``build_mobile_sam``; how a
    ``None`` value is handled is decided by that builder.
    """
    mobile_model = build_mobile_sam(checkpoint)
    return MobileSamPredictor(mobile_model.to(DEVICE))
def get_multi_label_predictor(
    sam: MobileSamPredictor,
    image: npt.NDArray,
    mask: npt.NDArray,
) -> PerSAM:
    """Fit a ``PerSAM`` model from a single reference image and mask.

    The image and mask are each wrapped in a one-element list and passed to
    ``train``, whose ``(weights, target_feat)`` result is used to construct
    the returned ``PerSAM``. The wall-clock training time is printed.

    Args:
        sam: Predictor handed to both ``train`` and ``PerSAM``.
        image: Reference image array.
        mask: Mask array for ``image``.

    Returns:
        The constructed ``PerSAM`` model.
    """
    t0 = time.perf_counter()
    reference_images, reference_masks = [image], [mask]
    weights, target_feat = train(sam, reference_images, reference_masks)
    elapsed = time.perf_counter() - t0
    print(f"training time {elapsed}")
    # NOTE(review): 10, 0.4 and 0.2 are positional PerSAM settings whose
    # meaning is defined in app.per_sam — confirm before changing them.
    return PerSAM(sam, target_feat, 10, 0.4, 0.2, weights)
if __name__ == "__main__":
    # Demo / timing script: fit a PerSAM predictor on one image+mask pair,
    # then time a single prediction on the same image. Input paths are CLI
    # options whose defaults are the original hard-coded locations, so running
    # the script without arguments behaves as before.
    import argparse

    import numpy as np
    from PIL import Image
    from torchvision.transforms import InterpolationMode
    from torchvision.transforms.functional import resize
    from app.transforms import ResizeLongestSide

    parser = argparse.ArgumentParser(
        description="Train a PerSAM predictor on one image/mask pair and time it."
    )
    parser.add_argument(
        "--image",
        default="/Users/dillonlaird/code/instance_labeler/seals-labeled/img2.png",
        help="Reference image (converted to RGB).",
    )
    parser.add_argument(
        "--mask",
        default="/Users/dillonlaird/code/instance_labeler/seals-labeled/img2.seal.10.png",
        help="Mask image for the reference image (converted to L).",
    )
    parser.add_argument(
        "--checkpoint",
        default="/Users/dillonlaird/code/instance_labeler/mobile_sam.pth",
        help="MobileSAM checkpoint passed to build_mobile_sam_predictor.",
    )
    args = parser.parse_args()

    T = ResizeLongestSide(1024)

    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it once the converted copy has been made.
    with Image.open(args.image) as image_file:
        image = image_file.convert("RGB")
    # PIL's .size is (width, height), so this passes (height, width).
    target_size = T.get_preprocess_shape(image.size[1], image.size[0], T.target_length)
    image_np = np.array(resize(image, target_size))

    with Image.open(args.mask) as mask_file:
        target_size = T.get_preprocess_shape(
            mask_file.size[1], mask_file.size[0], T.target_length
        )
        # Nearest-neighbour resampling keeps mask values unblended; the
        # default bilinear mode would mix foreground/background at edges.
        mask_resized = resize(
            mask_file, target_size, interpolation=InterpolationMode.NEAREST
        )
    mask_np = np.array(mask_resized.convert("L"))

    model = build_mobile_sam_predictor(args.checkpoint)

    # NOTE: get_multi_label_predictor prints its own training time as well;
    # this measurement wraps the whole call.
    start = time.perf_counter()
    per_sam_model = get_multi_label_predictor(model, image_np, mask_np)
    print(f"training time {time.perf_counter() - start}")

    start = time.perf_counter()
    masks, bboxes, _ = per_sam_model(image_np)
    print(f"prediction time {time.perf_counter() - start}")