Zhibinhong commited on
Commit
cfcb62d
·
1 Parent(s): 3bd7ccd

Update visual_chatgpt.py

Browse files
Files changed (1) hide show
  1. visual_chatgpt.py +3 -4
visual_chatgpt.py CHANGED
@@ -797,7 +797,7 @@ class Segmenting:
797
  print(f"Inintializing Segmentation to {device}")
798
  self.device = device
799
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
800
- self.model_checkpoint_path = os.path.join("checkpoints","sam")
801
 
802
  self.download_parameters()
803
  self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
@@ -813,7 +813,6 @@ class Segmenting:
813
  print("finddir",os.system("find /repository -type d -iname 'checkpoints'"))
814
  url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
815
  if not os.path.exists(path):
816
- os.makedirs(path)
817
  print("我进来了!")
818
  # wget.download(url,out=self.model_checkpoint_path)
819
  wget.download(url,out=path)
@@ -918,8 +917,8 @@ class Text2Box:
918
  print(f"Initializing ObjectDetection to {device}")
919
  self.device = device
920
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
921
- self.model_checkpoint_path = os.path.join("checkpoints","groundingdino")
922
- self.model_config_path = os.path.join("checkpoints","grounding_config.py")
923
  self.download_parameters()
924
  self.box_threshold = 0.3
925
  self.text_threshold = 0.25
 
797
  print(f"Inintializing Segmentation to {device}")
798
  self.device = device
799
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
800
+ self.model_checkpoint_path = "/repository/checkpoints/sam"
801
 
802
  self.download_parameters()
803
  self.sam = build_sam(checkpoint=self.model_checkpoint_path).to(device)
 
813
  print("finddir",os.system("find /repository -type d -iname 'checkpoints'"))
814
  url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
815
  if not os.path.exists(path):
 
816
  print("我进来了!")
817
  # wget.download(url,out=self.model_checkpoint_path)
818
  wget.download(url,out=path)
 
917
  print(f"Initializing ObjectDetection to {device}")
918
  self.device = device
919
  self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
920
+ self.model_checkpoint_path = "repository/checkpoints/groundingdino"
921
+ self.model_config_path = "repository/checkpoints/grounding_config.py"
922
  self.download_parameters()
923
  self.box_threshold = 0.3
924
  self.text_threshold = 0.25