thinh-researcher commited on
Commit
9b8ec7b
·
1 Parent(s): a025d2f

Fix downloaded file path

Browse files
streamlit_apps/app_utils/depth_model.py CHANGED
@@ -39,13 +39,16 @@ class DPTDepth(BaseDepthModel):
39
  weights_path = os.path.join("weights", weights_fname)
40
  if not os.path.isfile(weights_path):
41
  from huggingface_hub import hf_hub_download
42
- hf_hub_download(repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname)
43
- os.system(f"mv {weights_fname} weights")
 
 
 
44
  omnidata_ckpt = torch.load(
45
  weights_path,
46
  map_location="cpu",
47
  )
48
-
49
  self.model = DPTDepthModel()
50
  self.model.load_state_dict(omnidata_ckpt)
51
  self.model: DPTDepthModel = self.model.to(device).eval()
 
39
  weights_path = os.path.join("weights", weights_fname)
40
  if not os.path.isfile(weights_path):
41
  from huggingface_hub import hf_hub_download
42
+
43
+ downloaded_filepath = hf_hub_download(
44
+ repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname
45
+ )
46
+ os.system(f'mv "{downloaded_filepath}" weights')
47
  omnidata_ckpt = torch.load(
48
  weights_path,
49
  map_location="cpu",
50
  )
51
+
52
  self.model = DPTDepthModel()
53
  self.model.load_state_dict(omnidata_ckpt)
54
  self.model: DPTDepthModel = self.model.to(device).eval()
streamlit_apps/app_utils/sod_selection_ui.py CHANGED
@@ -29,14 +29,15 @@ def load_smultimae_model(
29
  cfg = arg_cfg[sod_model_config_key]()
30
 
31
  weights_fname = f"s-multimae-{cfg.experiment_name}-top{top}.pth"
32
- ckpt_path = os.path.join(
33
- "weights", weights_fname
34
- )
35
  print(ckpt_path)
36
  if not os.path.isfile(ckpt_path):
37
  from huggingface_hub import hf_hub_download
38
- hf_hub_download(repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname)
39
- os.system(f"mv {weights_fname} weights")
 
 
 
40
  assert os.path.isfile(ckpt_path)
41
 
42
  # sod_model = ModelPL.load_from_checkpoint(
 
29
  cfg = arg_cfg[sod_model_config_key]()
30
 
31
  weights_fname = f"s-multimae-{cfg.experiment_name}-top{top}.pth"
32
+ ckpt_path = os.path.join("weights", weights_fname)
 
 
33
  print(ckpt_path)
34
  if not os.path.isfile(ckpt_path):
35
  from huggingface_hub import hf_hub_download
36
+
37
+ downloaded_filepath = hf_hub_download(
38
+ repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname
39
+ )
40
+ os.system(f'mv "{downloaded_filepath}" weights')
41
  assert os.path.isfile(ckpt_path)
42
 
43
  # sod_model = ModelPL.load_from_checkpoint(