from typing import Any, Dict

import torch

from sam2.build_sam import build_sam2

# Model variant -> (Hydra config name, checkpoint path), as expected by build_sam2.
CHECKPOINTS = {
    "tiny": ("sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"),
    "small": ("sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"),
    "base_plus": ("sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt"),
    "large": ("sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"),
}

# Derive the variant names from CHECKPOINTS so the two cannot drift apart.
CHECKPOINT_NAMES = list(CHECKPOINTS)


def load_models(device: torch.device) -> Dict[str, Any]:
    """Build every SAM 2 variant on ``device``, keyed by variant name."""
    models = {}
    for key, (config, checkpoint) in CHECKPOINTS.items():
        # apply_postprocessing=False returns the raw model outputs,
        # skipping build_sam2's optional mask post-processing.
        models[key] = build_sam2(config, checkpoint, device=device, apply_postprocessing=False)
    return models
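

# A minimal usage sketch, not part of the module above: it assumes the four
# sam2_hiera_* checkpoint files have already been downloaded into checkpoints/,
# and falls back to CPU when no CUDA device is available.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    models = load_models(device)
    for name in CHECKPOINT_NAMES:
        print(f"loaded {name}: {type(models[name]).__name__}")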