Spaces:
Running
Running
File size: 2,028 Bytes
5238467 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
"""Utility for loading the models from HF."""
import os
from pathlib import Path
import typing as tp
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download, login
import torch
from audiocraft.models import builders, MusicGen
# Short model names accepted by `get_pretrained`, mapped to Hugging Face Hub repo ids.
MODEL_CHECKPOINTS_MAP = {
    "small": "facebook/musicgen-small",
    "medium": "facebook/musicgen-medium",
    "large": "facebook/musicgen-large",
    "melody": "facebook/musicgen-melody",
}

# Authenticate with the Hugging Face Hub at import time, but only when a token is
# actually provided: indexing os.environ directly made the whole module fail to
# import with KeyError whenever ACCESS_TOKEN was unset.
# NOTE(review): assumes the repos above can be downloaded without a login when no
# token is configured — confirm for private/gated checkpoints.
if os.environ.get('ACCESS_TOKEN'):
    login(token=os.environ['ACCESS_TOKEN'])
def _get_state_dict(file_or_url: tp.Union[Path, str],
                    filename="state_dict.bin", device='cpu'):
    """Load a checkpoint package with ``torch.load``.

    ``file_or_url`` is resolved in this order:

    1. path to an existing local file -> that file is loaded (``filename`` ignored);
    2. path to an existing local directory -> ``<dir>/<filename>`` is loaded;
    3. otherwise it is treated as a Hugging Face Hub repo id and ``filename`` is
       fetched from that repo with ``hf_hub_download``.

    Args:
        file_or_url: Local file, local directory, or HF Hub repo id.
        filename: Checkpoint file name inside the directory / repo.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``.
    """
    print("loading", file_or_url, filename)
    file_or_url = str(file_or_url)
    if os.path.isfile(file_or_url):
        checkpoint_path = file_or_url
    elif os.path.isdir(file_or_url):
        checkpoint_path = os.path.join(file_or_url, filename)
    else:
        checkpoint_path = hf_hub_download(repo_id=file_or_url, filename=filename)
    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints from
    # sources you trust.
    return torch.load(checkpoint_path, map_location=device)
def load_compression_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Rebuild the compression model saved as ``compression_state_dict.bin``.

    The checkpoint package's ``'xp.cfg'`` entry becomes an OmegaConf config
    (with its ``device`` field overridden) that drives the model builder, and its
    ``'best_state'`` entry provides the weights.

    Args:
        file_or_url: Checkpoint source, passed through to `_get_state_dict`.
        device: Target device; stored on the config as a string.

    Returns:
        The model in eval mode, with its config attached as ``model.cfg``.
    """
    package = _get_state_dict(file_or_url, filename="compression_state_dict.bin")
    weights = package['best_state']

    config = OmegaConf.create(package['xp.cfg'])
    config.device = str(device)

    compression_model = builders.get_compression_model(config)
    compression_model.load_state_dict(weights)
    compression_model.eval()
    compression_model.cfg = config
    return compression_model
def load_lm_model(file_or_url: tp.Union[Path, str], device='cpu'):
    """Rebuild the language model saved as ``state_dict.bin``.

    The config stored under ``'xp.cfg'`` is adjusted for the target device
    before building:
      * on ``'cpu'``: ``transformer_lm.memory_efficient`` is turned off,
        ``transformer_lm.custom`` is turned on and ``dtype`` is ``'float32'``;
      * on any other device: ``dtype`` is ``'float16'``.

    Args:
        file_or_url: Checkpoint source, passed through to `_get_state_dict`.
        device: Target device; compared as a string against ``'cpu'``.

    Returns:
        The model in eval mode, with its config attached as ``model.cfg``.
    """
    package = _get_state_dict(file_or_url)
    lm_cfg = OmegaConf.create(package['xp.cfg'])
    lm_cfg.device = str(device)

    running_on_cpu = lm_cfg.device == 'cpu'
    if running_on_cpu:
        # NOTE(review): presumably the memory-efficient attention path is not
        # usable on CPU, hence the switch to the custom transformer — confirm.
        lm_cfg.transformer_lm.memory_efficient = False
        lm_cfg.transformer_lm.custom = True
    # Full precision on CPU, half precision everywhere else.
    lm_cfg.dtype = 'float32' if running_on_cpu else 'float16'

    language_model = builders.get_lm_model(lm_cfg)
    language_model.load_state_dict(package['best_state'])
    language_model.eval()
    language_model.cfg = lm_cfg
    return language_model
def get_pretrained(name: str = 'small', device='cuda'):
    """Assemble a MusicGen model from one of the predefined Hub checkpoints.

    Args:
        name: Key of ``MODEL_CHECKPOINTS_MAP``.
        device: Device forwarded to both the compression-model and LM loaders.

    Returns:
        A ``MusicGen`` built from ``name`` and the two loaded sub-models.

    Raises:
        KeyError: if ``name`` is not a key of ``MODEL_CHECKPOINTS_MAP``.
    """
    repo_id = MODEL_CHECKPOINTS_MAP[name]
    # Arguments are evaluated left to right: compression model first, then the LM.
    return MusicGen(
        name,
        load_compression_model(repo_id, device=device),
        load_lm_model(repo_id, device=device),
    )
|