Commit
·
cfcb62d
1
Parent(s):
3bd7ccd
Update visual_chatgpt.py
Browse files- visual_chatgpt.py +3 -4
visual_chatgpt.py
CHANGED
@@ -797,7 +797,7 @@ class Segmenting:
|
|
797 |
print(f"Inintializing Segmentation to {device}")
|
798 |
self.device = device
|
799 |
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
800 |
-
self.model_checkpoint_path =
|
801 |
|
802 |
self.download_parameters()
|
803 |
self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
|
@@ -813,7 +813,6 @@ class Segmenting:
|
|
813 |
print("finddir",os.system("find /repository -type d -iname 'checkpoints'"))
|
814 |
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
815 |
if not os.path.exists(path):
|
816 |
-
os.makedirs(path)
|
817 |
print("我进来了!")
|
818 |
# wget.download(url,out=self.model_checkpoint_path)
|
819 |
wget.download(url,out=path)
|
@@ -918,8 +917,8 @@ class Text2Box:
|
|
918 |
print(f"Initializing ObjectDetection to {device}")
|
919 |
self.device = device
|
920 |
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
921 |
-
self.model_checkpoint_path =
|
922 |
-
self.model_config_path =
|
923 |
self.download_parameters()
|
924 |
self.box_threshold = 0.3
|
925 |
self.text_threshold = 0.25
|
|
|
797 |
print(f"Inintializing Segmentation to {device}")
|
798 |
self.device = device
|
799 |
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
800 |
+
self.model_checkpoint_path = "/repository/checkpoints/sam"
|
801 |
|
802 |
self.download_parameters()
|
803 |
self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
|
|
|
813 |
print("finddir",os.system("find /repository -type d -iname 'checkpoints'"))
|
814 |
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
|
815 |
if not os.path.exists(path):
|
|
|
816 |
print("我进来了!")
|
817 |
# wget.download(url,out=self.model_checkpoint_path)
|
818 |
wget.download(url,out=path)
|
|
|
917 |
print(f"Initializing ObjectDetection to {device}")
|
918 |
self.device = device
|
919 |
self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
|
920 |
+
self.model_checkpoint_path = "repository/checkpoints/groundingdino"
|
921 |
+
self.model_config_path = "repository/checkpoints/grounding_config.py"
|
922 |
self.download_parameters()
|
923 |
self.box_threshold = 0.3
|
924 |
self.text_threshold = 0.25
|