Model can be loaded from local directory (#69)
Browse files* Model can be loaded from local directory
* applied flake8 linter
audiocraft/models/loaders.py
CHANGED
@@ -51,6 +51,10 @@ def _get_state_dict(
|
|
51 |
if os.path.isfile(file_or_url_or_id):
|
52 |
return torch.load(file_or_url_or_id, map_location=device)
|
53 |
|
|
|
|
|
|
|
|
|
54 |
elif file_or_url_or_id.startswith('https://'):
|
55 |
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
|
56 |
|
|
|
51 |
if os.path.isfile(file_or_url_or_id):
|
52 |
return torch.load(file_or_url_or_id, map_location=device)
|
53 |
|
54 |
+
if os.path.isdir(file_or_url_or_id):
|
55 |
+
file = f"{file_or_url_or_id}/{filename}"
|
56 |
+
return torch.load(file, map_location=device)
|
57 |
+
|
58 |
elif file_or_url_or_id.startswith('https://'):
|
59 |
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
|
60 |
|
audiocraft/models/musicgen.py
CHANGED
@@ -89,10 +89,11 @@ class MusicGen:
|
|
89 |
return MusicGen(name, compression_model, lm)
|
90 |
|
91 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
96 |
|
97 |
cache_dir = os.environ.get('MUSICGEN_ROOT', None)
|
98 |
compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
|
|
|
89 |
return MusicGen(name, compression_model, lm)
|
90 |
|
91 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
92 |
+
if not os.path.isfile(name) and not os.path.isdir(name):
|
93 |
+
raise ValueError(
|
94 |
+
f"{name} is not a valid checkpoint name. "
|
95 |
+
f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
|
96 |
+
)
|
97 |
|
98 |
cache_dir = os.environ.get('MUSICGEN_ROOT', None)
|
99 |
compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
|