|
import os |
|
import sys |
|
sys.path.append(os.path.abspath('./modules/ultralytics')) |
|
|
|
from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel |
|
from modules.mast3r.model import AsymmetricMASt3R |
|
|
|
|
|
from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel |
|
from modules.mobilesamv2 import sam_model_registry |
|
|
|
from sam2.sam2_video_predictor import SAM2VideoPredictor |
|
|
|
class Models:
    """Load and hold the pretrained models used by the pipeline.

    Every model is loaded eagerly when the instance is constructed. The
    checkpoint identifiers default to the values this class has always used;
    pass keyword arguments to point at other checkpoints.

    Attributes set by ``__init__``:
        mast3r: ``AsymmetricMASt3R`` loaded from ``mast3r_ckpt`` and moved to
            ``device``.
        sam2: ``SAM2VideoPredictor`` loaded from ``sam2_ckpt`` on ``device``.
        mobilesamv2: MobileSAMv2 model built from a ``sam_vit_h`` skeleton.
            Its image encoder is replaced by the vision encoder of the
            Hugging Face SAM model ``sam1_ckpt``, and its prompt encoder and
            mask decoder by those loaded from ``sam1_decoder_ckpt``. It is
            moved to ``device`` and put in eval mode.
        yolov8: ``ObjectAwareModel`` loaded from ``yolo8_ckpt``.
    """

    def __init__(
        self,
        device,
        *,
        mast3r_ckpt='naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric',
        sam2_ckpt='facebook/sam2.1-hiera-large',
        sam1_ckpt='facebook/sam-vit-huge',
        sam1_decoder_ckpt='./checkpoints/Prompt_guided_Mask_Decoder.pt',
        yolo8_ckpt='./checkpoints/ObjectAwareModel.pt',
    ):
        """Load all models.

        Args:
            device: Target device passed to ``.to(...)`` / ``device=``.
            mast3r_ckpt: Checkpoint id/path for ``AsymmetricMASt3R.from_pretrained``.
            sam2_ckpt: Checkpoint id for ``SAM2VideoPredictor.from_pretrained``.
            sam1_ckpt: Hugging Face SAM checkpoint whose vision encoder is
                reused as the MobileSAMv2 image encoder.
            sam1_decoder_ckpt: Path of the prompt-guided decoder checkpoint
                for MobileSAMv2. Relative paths resolve against the current
                working directory.
            yolo8_ckpt: Path of the ``ObjectAwareModel`` checkpoint. Relative
                paths resolve against the current working directory.
        """
        self.mast3r = AsymmetricMASt3R.from_pretrained(mast3r_ckpt).to(device)

        self.sam2 = SAM2VideoPredictor.from_pretrained(sam2_ckpt, device=device)

        self.mobilesamv2 = self._build_mobilesamv2(
            device, sam1_ckpt, sam1_decoder_ckpt
        )

        # NOTE(review): this model is not moved to `device` here; presumably
        # the device is chosen at inference time — confirm at call sites.
        self.yolov8 = ObjectAwareModel(yolo8_ckpt)

    @staticmethod
    def _build_mobilesamv2(device, sam1_ckpt, sam1_decoder_ckpt):
        """Assemble MobileSAMv2 from a SAM ViT-H skeleton and swapped-in parts."""
        # Build the 'sam_vit_h' architecture. The None argument presumably
        # means "no checkpoint" — confirm against the registry. Its image
        # encoder, prompt encoder and mask decoder are all replaced below.
        model = sam_model_registry['sam_vit_h'](None)

        # Only the vision encoder of the Hugging Face SAM model is kept; the
        # rest of that model is dropped when this function returns.
        image_encoder = SamModel.from_pretrained(sam1_ckpt).vision_encoder

        prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](
            sam1_decoder_ckpt
        )

        model.prompt_encoder = prompt_encoder
        model.mask_decoder = mask_decoder
        model.image_encoder = image_encoder

        model.to(device=device)
        model.eval()
        return model
|
|
|
|
|
|
|
|
|
|