import os
import sys
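# allow imports from the bundled ./modules/ultralytics directory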
sys.path.append(os.path.abspath('./modules/ultralytics'))
from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
from modules.mast3r.model import AsymmetricMASt3R
# from modules.sam2.build_sam import build_sam2_video_predictor
from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
from modules.mobilesamv2 import sam_model_registry
from sam2.sam2_video_predictor import SAM2VideoPredictor
import spaces
import torch


class Models:
    @spaces.GPU
    def __init__(self, device='cpu'):
        # NOTE: the device argument is overridden below; CUDA is used whenever it is available.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # -- mast3r --
        # MAST3R_CKP = './checkpoints/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth'
        MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
        self.mast3r = AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)
        # -- sam2 --
        # SAM2_CKP = "./checkpoints/sam2.1_hiera_large.pt"
        # SAM2_CKP = 'hujiecpp/sam2-1-hiera-large'
        # SAM2_CONFIG = "./configs/sam2.1/sam2.1_hiera_l.yaml"
        # self.sam2 = build_sam2_video_predictor(SAM2_CONFIG, SAM2_CKP, device=device, apply_postprocessing=False)
        # self.sam2.eval()
        self.sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)
        # -- mobilesamv2 & sam1 --
        # SAM1_ENCODER_CKP = './checkpoints/sam_vit_h.pt'
        # SAM1_ENCODER_CKP = 'facebook/sam-vit-huge/model.safetensors'
        SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
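        # build the MobileSAMv2 wrapper without weights; the encoder and decoder are attached below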
        self.mobilesamv2 = sam_model_registry['sam_vit_h'](None)
        # image_encoder = sam_model_registry['sam_vit_h_encoder'](SAM1_ENCODER_CKP)
        sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
        image_encoder = sam1.vision_encoder
        prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
        self.mobilesamv2.prompt_encoder = prompt_encoder
        self.mobilesamv2.mask_decoder = mask_decoder
        self.mobilesamv2.image_encoder = image_encoder
        self.mobilesamv2.to(device=device)
        self.mobilesamv2.eval()
        # -- yolov8 --
        YOLO8_CKP = './checkpoints/ObjectAwareModel.pt'
        self.yolov8 = ObjectAwareModel(YOLO8_CKP)
        # -- siglip --
        self.siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
        # the tokenizer and processor are CPU-side preprocessing objects; device_map does not apply to them
        self.siglip_tokenizer = AutoTokenizer.from_pretrained("google/siglip-large-patch16-256")
        self.siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
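

# A minimal usage sketch (illustrative; not part of the original file): build the
# model bundle and embed a text query with SigLIP. The checkpoints are large, so
# the example is guarded behind __main__.
if __name__ == '__main__':
    models = Models()
    inputs = models.siglip_processor(text=['a photo of a chair'],
                                     padding='max_length', return_tensors='pt')
    # move the tokenized inputs to wherever the SigLIP model was placed
    inputs = {k: v.to(models.siglip.device) for k, v in inputs.items()}
    with torch.no_grad():
        text_emb = models.siglip.get_text_features(**inputs)  # pooled text embedding
    print(text_emb.shape)  # (1, hidden_size)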