import os
import sys
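# Make the vendored ultralytics checkout importable for the module imports below.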
sys.path.append(os.path.abspath('./modules/ultralytics'))

from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
from modules.mast3r.model import AsymmetricMASt3R

# from modules.sam2.build_sam import build_sam2_video_predictor
from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
from modules.mobilesamv2 import sam_model_registry

from sam2.sam2_video_predictor import SAM2VideoPredictor
import spaces
import torch

class Models:
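    # @spaces.GPU requests a ZeroGPU worker on Hugging Face Spaces;
    # outside a Space the decorator is a no-op.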
    @spaces.GPU
    def __init__(self, device=None):
        # Honour an explicitly requested device; otherwise prefer CUDA when available.
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # -- mast3r --
        # MAST3R_CKP = './checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
        MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
        self.mast3r = AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)

        # -- sam2 --
        # SAM2_CKP = "./checkpoints/sam2.1_hiera_large.pt"
        # SAM2_CKP = 'hujiecpp/sam2-1-hiera-large'
        # SAM2_CONFIG = "./configs/sam2.1/sam2.1_hiera_l.yaml"
        # self.sam2 = build_sam2_video_predictor(SAM2_CONFIG, SAM2_CKP, device=device, apply_postprocessing=False)
        # self.sam2.eval()
        self.sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)

        # -- mobilesamv2 & sam1 --
        # SAM1_ENCODER_CKP = './checkpoints/sam_vit_h.pt'
        # SAM1_ENCODER_CKP = 'facebook/sam-vit-huge/model.safetensors'
        SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
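        # Build the SAM ViT-H skeleton without loading weights; its image encoder,
        # prompt encoder, and mask decoder are attached below.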
        self.mobilesamv2 = sam_model_registry['sam_vit_h'](None)
        # image_encoder=sam_model_registry['sam_vit_h_encoder'](SAM1_ENCODER_CKP)
        sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
        image_encoder = sam1.vision_encoder

        prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
        self.mobilesamv2.prompt_encoder = prompt_encoder
        self.mobilesamv2.mask_decoder = mask_decoder
        self.mobilesamv2.image_encoder = image_encoder
        self.mobilesamv2.to(device=device)
        self.mobilesamv2.eval()

        # -- yolov8 --
        YOLO8_CKP = './checkpoints/ObjectAwareModel.pt'
        self.yolov8 = ObjectAwareModel(YOLO8_CKP)

        # -- siglip --
        self.siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
        # Tokenizers and processors are not torch modules and take no device_map argument.
        self.siglip_tokenizer = AutoTokenizer.from_pretrained("google/siglip-large-patch16-256")
        self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
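

if __name__ == '__main__':
    # Minimal smoke test: a sketch assuming the local files under ./checkpoints
    # (Prompt_guided_Mask_Decoder.pt, ObjectAwareModel.pt) exist and the
    # Hugging Face Hub weights are reachable for download.
    models = Models()
    for name in ('mast3r', 'sam2', 'mobilesamv2', 'yolov8', 'siglip'):
        print(f'{name}: {type(getattr(models, name)).__name__}')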