hujiecpp commited on
Commit
67b7ca3
·
1 Parent(s): 735791c

init project

Browse files
Files changed (1) hide show
  1. modules/pe3r/models.py +6 -3
modules/pe3r/models.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import sys
3
  sys.path.append(os.path.abspath('./modules/ultralytics'))
4
 
5
- from transformers import AutoTokenizer, AutoModel, AutoProcessor
6
  from modules.mast3r.model import AsymmetricMASt3R
7
 
8
  # from modules.sam2.build_sam import build_sam2_video_predictor
@@ -32,10 +32,13 @@ class Models:
32
 
33
  # -- mobilesamv2 & sam1 --
34
  # SAM1_ENCODER_CKP = './checkpoints/sam_vit_h.pt'
35
- SAM1_ENCODER_CKP = 'facebook/sam-vit-huge/model.safetensors'
36
  SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
37
  self.mobilesamv2 = sam_model_registry['sam_vit_h'](None)
38
- image_encoder=sam_model_registry['sam_vit_h_encoder'](SAM1_ENCODER_CKP)
 
 
 
39
  prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
40
  self.mobilesamv2.prompt_encoder = prompt_encoder
41
  self.mobilesamv2.mask_decoder = mask_decoder
 
2
  import sys
3
  sys.path.append(os.path.abspath('./modules/ultralytics'))
4
 
5
+ from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
6
  from modules.mast3r.model import AsymmetricMASt3R
7
 
8
  # from modules.sam2.build_sam import build_sam2_video_predictor
 
32
 
33
  # -- mobilesamv2 & sam1 --
34
  # SAM1_ENCODER_CKP = './checkpoints/sam_vit_h.pt'
35
+ # SAM1_ENCODER_CKP = 'facebook/sam-vit-huge/model.safetensors'
36
  SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
37
  self.mobilesamv2 = sam_model_registry['sam_vit_h'](None)
38
+ # image_encoder=sam_model_registry['sam_vit_h_encoder'](SAM1_ENCODER_CKP)
39
+ sam1 = SamModel.from_pretrained("facebook/sam-vit-huge", device=device)
40
+ image_encoder = sam1.image_encoder
41
+
42
  prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
43
  self.mobilesamv2.prompt_encoder = prompt_encoder
44
  self.mobilesamv2.mask_decoder = mask_decoder