ZeqiangLai commited on
Commit
cc9db82
·
verified ·
1 Parent(s): 8c7926b

Update hy3dgen/shapegen/utils.py

Browse files
Files changed (1) hide show
  1. hy3dgen/shapegen/utils.py +37 -0
hy3dgen/shapegen/utils.py CHANGED
@@ -70,3 +70,40 @@ class synchronize_timer:
70
  return result
71
 
72
  return wrapper
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  return result
71
 
72
  return wrapper
73
+
74
+
75
+ def smart_load_model(
76
+ model_path,
77
+ subfolder,
78
+ use_safetensors,
79
+ variant,
80
+ ):
81
+ original_model_path = model_path
82
+ # try local path
83
+ base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
84
+ model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
85
+ logger.info(f'Try to load model from local path: {model_path}')
86
+ if not os.path.exists(model_path):
87
+ logger.info('Model path not exists, try to download from huggingface')
88
+ try:
89
+ import huggingface_hub
90
+ # download from huggingface
91
+ path = huggingface_hub.snapshot_download(repo_id=original_model_path)
92
+ model_path = os.path.join(path, subfolder)
93
+ except ImportError:
94
+ logger.warning(
95
+ "You need to install HuggingFace Hub to load models from the hub."
96
+ )
97
+ raise RuntimeError(f"Model path {model_path} not found")
98
+ except Exception as e:
99
+ raise e
100
+
101
+ if not os.path.exists(model_path):
102
+ raise FileNotFoundError(f"Model path {original_model_path} not found")
103
+
104
+ extension = 'ckpt' if not use_safetensors else 'safetensors'
105
+ variant = '' if variant is None else f'.{variant}'
106
+ ckpt_name = f'model{variant}.{extension}'
107
+ config_path = os.path.join(model_path, 'config.yaml')
108
+ ckpt_path = os.path.join(model_path, ckpt_name)
109
+ return config_path, ckpt_path