SkalskiP's picture
initial version of SAM2 space
ec0b3c1
raw
history blame
698 Bytes
import torch
from typing import Dict, Any
from sam2.build_sam import build_sam2
# Maps checkpoint size name -> [model config YAML, checkpoint file path],
# as expected positionally by build_sam2(config, checkpoint, ...).
CHECKPOINTS = {
    "tiny": ["sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"],
}
# Derived from CHECKPOINTS so the two can never drift out of sync;
# dict insertion order (Python 3.7+) preserves the original name ordering.
CHECKPOINT_NAMES = list(CHECKPOINTS)
def load_models(device: torch.device) -> Dict[str, Any]:
    """Build one SAM2 model per configured checkpoint.

    Args:
        device: Torch device onto which each model is loaded.

    Returns:
        Mapping from checkpoint name (e.g. ``"tiny"``) to the model
        returned by ``build_sam2`` (post-processing disabled).
    """
    return {
        name: build_sam2(config_path, checkpoint_path,
                         device=device, apply_postprocessing=False)
        for name, (config_path, checkpoint_path) in CHECKPOINTS.items()
    }