init project
Browse files- modules/pe3r/models.py +7 -4
modules/pe3r/models.py
CHANGED
@@ -9,6 +9,8 @@ from modules.sam2.build_sam import build_sam2_video_predictor
|
|
9 |
from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
|
10 |
from modules.mobilesamv2 import sam_model_registry
|
11 |
|
|
|
|
|
12 |
class Models:
|
13 |
def __init__(self, device):
|
14 |
# -- mast3r --
|
@@ -18,10 +20,11 @@ class Models:
|
|
18 |
|
19 |
# -- sam2 --
|
20 |
# SAM2_CKP = "./checkpoints/sam2.1_hiera_large.pt"
|
21 |
-
SAM2_CKP = 'hujiecpp/sam2-1-hiera-large'
|
22 |
-
SAM2_CONFIG = "./configs/sam2.1/sam2.1_hiera_l.yaml"
|
23 |
-
self.sam2 = build_sam2_video_predictor(SAM2_CONFIG, SAM2_CKP, device=device, apply_postprocessing=False)
|
24 |
-
self.sam2.eval()
|
|
|
25 |
|
26 |
# -- mobilesamv2 & sam1 --
|
27 |
# SAM1_ENCODER_CKP = './checkpoints/sam_vit_h.pt'
|
|
|
9 |
from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
|
10 |
from modules.mobilesamv2 import sam_model_registry
|
11 |
|
12 |
+
from sam2.sam2_video_predictor import SAM2VideoPredictor
|
13 |
+
|
14 |
class Models:
|
15 |
def __init__(self, device):
|
16 |
# -- mast3r --
|
|
|
20 |
|
21 |
# -- sam2 --
|
22 |
# SAM2_CKP = "./checkpoints/sam2.1_hiera_large.pt"
|
23 |
+
# SAM2_CKP = 'hujiecpp/sam2-1-hiera-large'
|
24 |
+
# SAM2_CONFIG = "./configs/sam2.1/sam2.1_hiera_l.yaml"
|
25 |
+
# self.sam2 = build_sam2_video_predictor(SAM2_CONFIG, SAM2_CKP, device=device, apply_postprocessing=False)
|
26 |
+
# self.sam2.eval()
|
27 |
+
self.sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large')
|
28 |
|
29 |
# -- mobilesamv2 & sam1 --
|
30 |
# SAM1_ENCODER_CKP = './checkpoints/sam_vit_h.pt'
|