adding support for cpu
Browse files
audiocraft/models/loaders.py
CHANGED
@@ -80,8 +80,6 @@ def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_di
|
|
80 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
81 |
cfg.device = str(device)
|
82 |
if cfg.device == 'cpu':
|
83 |
-
cfg.transformer_lm.memory_efficient = False
|
84 |
-
cfg.transformer_lm.custom = True
|
85 |
cfg.dtype = 'float32'
|
86 |
else:
|
87 |
cfg.dtype = 'float16'
|
|
|
80 |
cfg = OmegaConf.create(pkg['xp.cfg'])
|
81 |
cfg.device = str(device)
|
82 |
if cfg.device == 'cpu':
|
|
|
|
|
83 |
cfg.dtype = 'float32'
|
84 |
else:
|
85 |
cfg.dtype = 'float16'
|
audiocraft/models/musicgen.py
CHANGED
@@ -68,7 +68,7 @@ class MusicGen:
|
|
68 |
return self.compression_model.channels
|
69 |
|
70 |
@staticmethod
|
71 |
-
def get_pretrained(name: str = 'melody', device=
|
72 |
"""Return pretrained model, we provide four models:
|
73 |
- small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
|
74 |
- medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
|
@@ -76,11 +76,17 @@ class MusicGen:
|
|
76 |
- large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
|
77 |
"""
|
78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
if name == 'debug':
|
80 |
# used only for unit tests
|
81 |
compression_model = get_debug_compression_model(device)
|
82 |
lm = get_debug_lm_model(device)
|
83 |
-
return MusicGen(name, compression_model, lm
|
84 |
|
85 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
86 |
raise ValueError(
|
@@ -313,7 +319,6 @@ class MusicGen:
|
|
313 |
all_tokens.append(prompt_tokens)
|
314 |
prompt_length = prompt_tokens.shape[-1]
|
315 |
|
316 |
-
|
317 |
stride_tokens = int(self.frame_rate * self.extend_stride)
|
318 |
|
319 |
while current_gen_offset + prompt_length < total_gen_len:
|
|
|
68 |
return self.compression_model.channels
|
69 |
|
70 |
@staticmethod
|
71 |
+
def get_pretrained(name: str = 'melody', device=None):
|
72 |
"""Return pretrained model, we provide four models:
|
73 |
- small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
|
74 |
- medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
|
|
|
76 |
- large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
|
77 |
"""
|
78 |
|
79 |
+
if device is None:
|
80 |
+
if torch.cuda.device_count():
|
81 |
+
device = 'cuda'
|
82 |
+
else:
|
83 |
+
device = 'cpu'
|
84 |
+
|
85 |
if name == 'debug':
|
86 |
# used only for unit tests
|
87 |
compression_model = get_debug_compression_model(device)
|
88 |
lm = get_debug_lm_model(device)
|
89 |
+
return MusicGen(name, compression_model, lm)
|
90 |
|
91 |
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
92 |
raise ValueError(
|
|
|
319 |
all_tokens.append(prompt_tokens)
|
320 |
prompt_length = prompt_tokens.shape[-1]
|
321 |
|
|
|
322 |
stride_tokens = int(self.frame_rate * self.extend_stride)
|
323 |
|
324 |
while current_gen_offset + prompt_length < total_gen_len:
|
tests/models/test_musicgen.py
CHANGED
@@ -51,6 +51,7 @@ class TestSEANetModel:
|
|
51 |
|
52 |
def test_generate_long(self):
|
53 |
mg = self.get_musicgen()
|
|
|
54 |
mg.set_generation_params(duration=4., stride_extend=2.)
|
55 |
wav = mg.generate(
|
56 |
['youpi', 'lapin dort'])
|
|
|
51 |
|
52 |
def test_generate_long(self):
|
53 |
mg = self.get_musicgen()
|
54 |
+
mg.max_duration = 3.
|
55 |
mg.set_generation_params(duration=4., stride_extend=2.)
|
56 |
wav = mg.generate(
|
57 |
['youpi', 'lapin dort'])
|