import os
from dataclasses import fields
from hashlib import sha1
from pathlib import Path
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)
from urllib.parse import urlparse

import torch
from omegaconf import DictConfig, OmegaConf

import audioseal
from audioseal.builder import (
    AudioSealDetectorConfig,
    AudioSealWMConfig,
    create_detector,
    create_generator,
)
from audioseal.models import AudioSealDetector, AudioSealWM

AudioSealT = TypeVar("AudioSealT", AudioSealWMConfig, AudioSealDetectorConfig)


class ModelLoadError(RuntimeError):
    """Raised when model loading fails."""


def _get_path_from_env(var_name: str) -> Optional[Path]:
    pathname = os.getenv(var_name)
    if not pathname:
        return None

    try:
        return Path(pathname)
    except ValueError as ex:
        raise RuntimeError(f"Expected a valid pathname, got '{pathname}'.") from ex


def _get_cache_dir(env_names: List[str]) -> Path:
    """Re-use the cache dir from the first cache environment variable that is set."""
    for env in env_names:
        cache_dir = _get_path_from_env(env)
        if cache_dir:
            break
    else:
        cache_dir = Path("~/.cache").expanduser().resolve()

    cache_dir = cache_dir / "audioseal"
    cache_dir.mkdir(exist_ok=True, parents=True)

    return cache_dir
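

# A sketch of the fallback behaviour (the environment values here are hypothetical):
# with none of the variables set, _get_cache_dir(["AUDIOSEAL_CACHE_DIR",
# "XDG_CACHE_HOME"]) resolves to ~/.cache/audioseal, while
# AUDIOSEAL_CACHE_DIR=/tmp/models would make it /tmp/models/audioseal.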


def load_model_checkpoint(
    model_path: Union[Path, str],
    device: Union[str, torch.device] = "cpu",
):
    if Path(model_path).is_file():
        return torch.load(model_path, map_location=device)

    cache_dir = _get_cache_dir(
        ["AUDIOSEAL_CACHE_DIR", "AUDIOCRAFT_CACHE_DIR", "XDG_CACHE_HOME"]
    )
    parts = urlparse(str(model_path))
    if parts.scheme == "https":
        # Cache the remote checkpoint under a stable name derived from the URL path.
        hash_ = sha1(parts.path.encode()).hexdigest()[:24]
        return torch.hub.load_state_dict_from_url(
            str(model_path), model_dir=cache_dir, map_location=device, file_name=hash_
        )
    elif str(model_path).startswith("facebook/audioseal/"):
        hf_filename = str(model_path)[len("facebook/audioseal/") :]

        try:
            from huggingface_hub import hf_hub_download
        except ModuleNotFoundError as ex:
            raise ModelLoadError(
                f"The model path {model_path} looks like a Hugging Face path, "
                "but huggingface_hub is not installed. Install it, for example with "
                "`pip install huggingface_hub`, to use this feature."
            ) from ex

        file = hf_hub_download(
            repo_id="facebook/audioseal",
            filename=hf_filename,
            cache_dir=cache_dir,
            library_name="audioseal",
            library_version=audioseal.__version__,
        )
        return torch.load(file, map_location=device)
    else:
        raise ModelLoadError(f"Path or URI {model_path} is unknown or does not exist")


def load_local_model_config(model_card: str) -> Optional[DictConfig]:
    config_file = Path(__file__).parent / "cards" / (model_card + ".yaml")
    if config_file.is_file():
        return cast(DictConfig, OmegaConf.load(config_file.resolve()))
    else:
        return None


class AudioSeal:

    @staticmethod
    def parse_model(
        model_card_or_path: str,
        model_type: Type[AudioSealT],
        nbits: Optional[int] = None,
    ) -> Tuple[Dict[str, Any], AudioSealT]:
        """
        Parse the checkpoint and the model configuration from a model card or a
        checkpoint path, using the schema `model_type` that defines the model type.
        """
        config = load_local_model_config(model_card_or_path)

        if config:
            assert "checkpoint" in config, f"Checkpoint missing in {model_card_or_path}"
            config_dict = OmegaConf.to_container(config)
            assert isinstance(
                config_dict, dict
            ), f"Cannot parse config from {model_card_or_path}"
            checkpoint = config_dict.pop("checkpoint")
            checkpoint = load_model_checkpoint(checkpoint)
        else:
            config_dict = {}
            checkpoint = load_model_checkpoint(model_card_or_path)

        # Training checkpoints may embed their experiment config under "xp.cfg";
        # values from the model card take precedence.
        if "xp.cfg" in checkpoint:
            config_dict = {**checkpoint["xp.cfg"], **config_dict}

        model_config = AudioSeal.parse_config(
            config_dict, config_type=model_type, nbits=nbits
        )

        if "model" in checkpoint:
            checkpoint = checkpoint["model"]

        return checkpoint, model_config
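
    # A sketch of what parse_model returns (the card name "audioseal_wm_16bits" is
    # assumed to be one of the YAML cards under audioseal/cards/): a raw state dict
    # plus a typed config that the builders in audioseal.builder accept.
    #
    #   state_dict, wm_config = AudioSeal.parse_model(
    #       "audioseal_wm_16bits", AudioSealWMConfig, nbits=16
    #   )
    #   model = create_generator(wm_config)
    #   model.load_state_dict(state_dict)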

    @staticmethod
    def parse_config(
        config: Dict[str, Any],
        config_type: Type[AudioSealT],
        nbits: Optional[int] = None,
    ) -> AudioSealT:

        assert "seanet" in config, f"missing seanet backbone config in {config}"

        # Resolve interpolations in the raw config before converting it back to
        # a plain dict.
        config = OmegaConf.create(config)
        OmegaConf.resolve(config)
        config = OmegaConf.to_container(config)

        # Hoist the encoder/decoder/detector blocks nested under "seanet" to the
        # top level, merging them over any existing top-level entries.
        seanet_config = config["seanet"]
        for key_to_patch in ["encoder", "decoder", "detector"]:
            if key_to_patch in seanet_config:
                config_to_patch = config.get(key_to_patch) or {}
                config[key_to_patch] = {
                    **config_to_patch,
                    **seanet_config.pop(key_to_patch),
                }

        config["seanet"] = seanet_config

        if nbits and "nbits" not in config:
            config["nbits"] = nbits

        # Keep only the fields declared by the target dataclass schema.
        result_config = {}
        assert config, "Empty config"
        for field in fields(config_type):
            if field.name in config:
                result_config[field.name] = config[field.name]

        schema = OmegaConf.structured(config_type)
        schema.merge_with(result_config)
        return schema
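
    # A minimal sketch of the "seanet" patching above (the dict contents are made
    # up): given
    #
    #   {"seanet": {"dimension": 128, "encoder": {"ratios": [8, 5, 4, 2]}}}
    #
    # the nested "encoder" block is popped out of "seanet" and becomes a top-level
    # "encoder" entry (merged over any existing one), which is the shape the
    # structured `config_type` schema expects.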

    @staticmethod
    def load_generator(
        model_card_or_path: str,
        nbits: Optional[int] = None,
    ) -> AudioSealWM:
        """Load the AudioSeal generator from a model card or checkpoint path."""
        checkpoint, config = AudioSeal.parse_model(
            model_card_or_path,
            AudioSealWMConfig,
            nbits=nbits,
        )

        model = create_generator(config)
        model.load_state_dict(checkpoint)
        return model
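
    # Usage sketch (the card name "audioseal_wm_16bits" is assumed; any path
    # accepted by load_model_checkpoint also works):
    #
    #   generator = AudioSeal.load_generator("audioseal_wm_16bits")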

    @staticmethod
    def load_detector(
        model_card_or_path: str,
        nbits: Optional[int] = None,
    ) -> AudioSealDetector:
        """Load the AudioSeal detector from a model card or checkpoint path."""
        checkpoint, config = AudioSeal.parse_model(
            model_card_or_path,
            AudioSealDetectorConfig,
            nbits=nbits,
        )
        model = create_detector(config)
        model.load_state_dict(checkpoint)
        return model
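
    # Usage sketch (card name assumed, mirroring load_generator above):
    #
    #   detector = AudioSeal.load_detector("audioseal_detector_16bits")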