# File: src/audioseal/loader.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
from dataclasses import fields
from hashlib import sha1
from pathlib import Path
from typing import ( # type: ignore[attr-defined]
Any,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
cast,
)
from urllib.parse import urlparse # noqa: F401
import torch
from omegaconf import DictConfig, OmegaConf
import audioseal
from audioseal.builder import (
AudioSealDetectorConfig,
AudioSealWMConfig,
create_detector,
create_generator,
)
from audioseal.models import AudioSealDetector, AudioSealWM
# Type variable constrained to the two AudioSeal config dataclasses, so the
# same parsing helpers can produce either a generator or a detector config.
AudioSealT = TypeVar("AudioSealT", AudioSealWMConfig, AudioSealDetectorConfig)
class ModelLoadError(RuntimeError):
    """Raised when a model checkpoint or model card cannot be located or loaded."""
def _get_path_from_env(var_name: str) -> Optional[Path]:
pathname = os.getenv(var_name)
if not pathname:
return None
try:
return Path(pathname)
except ValueError as ex:
raise RuntimeError(f"Expect valid pathname, get '{pathname}'.") from ex
def _get_cache_dir(env_names: List[str]):
    """Pick an existing cache root from the given env vars and return an
    ``audioseal`` sub-directory inside it (created if necessary).

    Falls back to ``~/.cache`` when none of the variables point anywhere.
    """
    base_dir = next(
        (path for path in map(_get_path_from_env, env_names) if path),
        None,
    )
    if base_dir is None:
        base_dir = Path("~/.cache").expanduser().resolve()
    # Create a sub-dir to not mess up with existing caches
    audioseal_dir = base_dir / "audioseal"
    audioseal_dir.mkdir(exist_ok=True, parents=True)
    return audioseal_dir
def load_model_checkpoint(
    model_path: Union[Path, str],
    device: Union[str, torch.device] = "cpu",
):
    """Load a raw checkpoint from a local file, an https URL, or a Hugging
    Face path of the form ``facebook/audioseal/<filename>``.

    Args:
        model_path: Local file path, https URL, or HF repo path.
        device: Device the checkpoint tensors are mapped onto (default CPU).

    Returns:
        Whatever object was serialized in the checkpoint (typically a dict).

    Raises:
        ModelLoadError: If ``model_path`` matches none of the supported forms,
            or an HF path is given while ``huggingface_hub`` is not installed.
    """
    if Path(model_path).is_file():
        # NOTE(review): torch.load un-pickles arbitrary objects; only load
        # checkpoints from trusted sources.
        return torch.load(model_path, map_location=device)

    cache_dir = _get_cache_dir(
        ["AUDIOSEAL_CACHE_DIR", "AUDIOCRAFT_CACHE_DIR", "XDG_CACHE_HOME"]
    )
    parts = urlparse(str(model_path))
    if parts.scheme == "https":
        # Short content-address of the URL path, used as the cached file name.
        hash_ = sha1(parts.path.encode()).hexdigest()[:24]
        return torch.hub.load_state_dict_from_url(
            str(model_path), model_dir=cache_dir, map_location=device, file_name=hash_
        )
    elif str(model_path).startswith("facebook/audioseal/"):
        hf_filename = str(model_path)[len("facebook/audioseal/") :]
        try:
            from huggingface_hub import hf_hub_download
        except ModuleNotFoundError as ex:
            # Bug fix: the original only printed this message and then fell
            # through to call the undefined ``hf_hub_download`` (NameError).
            raise ModelLoadError(
                f"The model path {model_path} seems to be a direct HF path, "
                "but you do not install Huggingface_hub. Install with for example "
                "`pip install huggingface_hub` to use this feature."
            ) from ex
        file = hf_hub_download(
            repo_id="facebook/audioseal",
            filename=hf_filename,
            cache_dir=cache_dir,
            library_name="audioseal",
            library_version=audioseal.__version__,
        )
        return torch.load(file, map_location=device)
    else:
        raise ModelLoadError(f"Path or uri {model_path} is unknown or does not exist")
def load_local_model_config(model_card: str) -> Optional[DictConfig]:
    """Load the bundled YAML config for ``model_card``, or ``None`` when no
    matching card file ships with the package."""
    card_file = Path(__file__).parent / "cards" / f"{model_card}.yaml"
    if not card_file.is_file():
        return None
    return cast(DictConfig, OmegaConf.load(card_file.resolve()))
class AudioSeal:
    """Factory for AudioSeal models: resolves a model card or checkpoint
    path into a typed config plus state dict, and builds the generator or
    detector from them."""

    @staticmethod
    def parse_model(
        model_card_or_path: str,
        model_type: Type[AudioSealT],
        nbits: Optional[int] = None,
    ) -> Tuple[Dict[str, Any], AudioSealT]:
        """
        Parse the information from the model card or checkpoint path using
        the schema `model_type` that defines the model type.

        Returns:
            A tuple of (raw state dict, parsed typed config).
        """
        # Get the raw checkpoint and config from the local model cards
        config = load_local_model_config(model_card_or_path)

        if config:
            assert "checkpoint" in config, f"Checkpoint missing in {model_card_or_path}"
            config_dict = OmegaConf.to_container(config)
            assert isinstance(
                config_dict, dict
            ), f"Cannot parse config from {model_card_or_path}"
            checkpoint = config_dict.pop("checkpoint")
            checkpoint = load_model_checkpoint(checkpoint)

        # Get the raw checkpoint and config from the checkpoint path
        else:
            config_dict = {}
            checkpoint = load_model_checkpoint(model_card_or_path)

        # Training checkpoints embed their config under "xp.cfg"; values from
        # the model card (config_dict) take precedence over embedded ones.
        if "xp.cfg" in checkpoint:
            config_dict = {**checkpoint["xp.cfg"], **config_dict}  # type: ignore

        model_config = AudioSeal.parse_config(config_dict, config_type=model_type, nbits=nbits)  # type: ignore

        # Training checkpoints wrap the weights under a "model" key.
        if "model" in checkpoint:
            checkpoint = checkpoint["model"]

        return checkpoint, model_config

    @staticmethod
    def parse_config(
        config: Dict[str, Any],
        config_type: Type[AudioSealT],
        nbits: Optional[int] = None,
    ) -> AudioSealT:
        """Normalize a raw config dict into a structured ``config_type``.

        Applies three patches to accommodate training-time checkpoints:
        resolve OmegaConf interpolations, hoist encoder/decoder/detector out
        of the seanet sub-config, and inject ``nbits`` when provided.
        """
        assert "seanet" in config, f"missing seanet backbone config in {config}"

        # Patch 1: Resolve the variables in the checkpoint
        config = OmegaConf.create(config)  # type: ignore
        OmegaConf.resolve(config)  # type: ignore
        config = OmegaConf.to_container(config)  # type: ignore

        # Patch 2: Put decoder, encoder and detector outside seanet
        seanet_config = config["seanet"]
        for key_to_patch in ["encoder", "decoder", "detector"]:
            if key_to_patch in seanet_config:
                config_to_patch = config.get(key_to_patch) or {}
                config[key_to_patch] = {
                    **config_to_patch,
                    **seanet_config.pop(key_to_patch),
                }

        config["seanet"] = seanet_config

        # Patch 3: Put nbits into config if specified
        if nbits and "nbits" not in config:
            config["nbits"] = nbits

        # remove attributes not related to the model_type
        result_config = {}
        assert config, "Empty config"  # fix: was a placeholder-less f-string
        for field in fields(config_type):
            if field.name in config:
                result_config[field.name] = config[field.name]

        schema = OmegaConf.structured(config_type)
        schema.merge_with(result_config)
        return schema

    @staticmethod
    def load_generator(
        model_card_or_path: str,
        nbits: Optional[int] = None,
    ) -> AudioSealWM:
        """Load the AudioSeal generator from the model card"""
        checkpoint, config = AudioSeal.parse_model(
            model_card_or_path,
            AudioSealWMConfig,
            nbits=nbits,
        )

        model = create_generator(config)
        model.load_state_dict(checkpoint)
        return model

    @staticmethod
    def load_detector(
        model_card_or_path: str,
        nbits: Optional[int] = None,
    ) -> AudioSealDetector:
        """Load the AudioSeal detector from the model card"""
        checkpoint, config = AudioSeal.parse_model(
            model_card_or_path,
            AudioSealDetectorConfig,
            nbits=nbits,
        )
        model = create_detector(config)
        model.load_state_dict(checkpoint)
        return model