Spaces:
Sleeping
Sleeping
Commit
·
9b8ec7b
1
Parent(s):
a025d2f
Fix downloaded file path
Browse files
streamlit_apps/app_utils/depth_model.py
CHANGED
@@ -39,13 +39,16 @@ class DPTDepth(BaseDepthModel):
|
|
39 |
weights_path = os.path.join("weights", weights_fname)
|
40 |
if not os.path.isfile(weights_path):
|
41 |
from huggingface_hub import hf_hub_download
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
44 |
omnidata_ckpt = torch.load(
|
45 |
weights_path,
|
46 |
map_location="cpu",
|
47 |
)
|
48 |
-
|
49 |
self.model = DPTDepthModel()
|
50 |
self.model.load_state_dict(omnidata_ckpt)
|
51 |
self.model: DPTDepthModel = self.model.to(device).eval()
|
|
|
39 |
weights_path = os.path.join("weights", weights_fname)
|
40 |
if not os.path.isfile(weights_path):
|
41 |
from huggingface_hub import hf_hub_download
|
42 |
+
|
43 |
+
downloaded_filepath = hf_hub_download(
|
44 |
+
repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname
|
45 |
+
)
|
46 |
+
os.system(f'mv "{downloaded_filepath}" weights')
|
47 |
omnidata_ckpt = torch.load(
|
48 |
weights_path,
|
49 |
map_location="cpu",
|
50 |
)
|
51 |
+
|
52 |
self.model = DPTDepthModel()
|
53 |
self.model.load_state_dict(omnidata_ckpt)
|
54 |
self.model: DPTDepthModel = self.model.to(device).eval()
|
streamlit_apps/app_utils/sod_selection_ui.py
CHANGED
@@ -29,14 +29,15 @@ def load_smultimae_model(
|
|
29 |
cfg = arg_cfg[sod_model_config_key]()
|
30 |
|
31 |
weights_fname = f"s-multimae-{cfg.experiment_name}-top{top}.pth"
|
32 |
-
ckpt_path = os.path.join(
|
33 |
-
"weights", weights_fname
|
34 |
-
)
|
35 |
print(ckpt_path)
|
36 |
if not os.path.isfile(ckpt_path):
|
37 |
from huggingface_hub import hf_hub_download
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
40 |
assert os.path.isfile(ckpt_path)
|
41 |
|
42 |
# sod_model = ModelPL.load_from_checkpoint(
|
|
|
29 |
cfg = arg_cfg[sod_model_config_key]()
|
30 |
|
31 |
weights_fname = f"s-multimae-{cfg.experiment_name}-top{top}.pth"
|
32 |
+
ckpt_path = os.path.join("weights", weights_fname)
|
|
|
|
|
33 |
print(ckpt_path)
|
34 |
if not os.path.isfile(ckpt_path):
|
35 |
from huggingface_hub import hf_hub_download
|
36 |
+
|
37 |
+
downloaded_filepath = hf_hub_download(
|
38 |
+
repo_id="RGBD-SOD/S-MultiMAE", filename=weights_fname
|
39 |
+
)
|
40 |
+
os.system(f'mv "{downloaded_filepath}" weights')
|
41 |
assert os.path.isfile(ckpt_path)
|
42 |
|
43 |
# sod_model = ModelPL.load_from_checkpoint(
|