File size: 7,201 Bytes
83940d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import os
from dataclasses import fields
from hashlib import sha1
from pathlib import Path
from typing import (  # type: ignore[attr-defined]
    Any,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)
from urllib.parse import urlparse  # noqa: F401

import torch
from omegaconf import DictConfig, OmegaConf

import audioseal
from audioseal.builder import (
    AudioSealDetectorConfig,
    AudioSealWMConfig,
    create_detector,
    create_generator,
)
from audioseal.models import AudioSealDetector, AudioSealWM

AudioSealT = TypeVar("AudioSealT", AudioSealWMConfig, AudioSealDetectorConfig)


class ModelLoadError(RuntimeError):
    """Signal that a model checkpoint could not be located or loaded."""


def _get_path_from_env(var_name: str) -> Optional[Path]:
    """Read the environment variable `var_name` and return it as a `Path`.

    Returns `None` when the variable is unset or set to an empty string.
    A `ValueError` raised while building the path is re-raised as a
    `RuntimeError` that mentions the offending value.
    """
    raw_value = os.environ.get(var_name)
    if raw_value is None or raw_value == "":
        return None

    try:
        env_path = Path(raw_value)
    except ValueError as ex:
        raise RuntimeError(f"Expect valid pathname, get '{raw_value}'.") from ex
    return env_path


def _get_cache_dir(env_names: List[str]):
    """Pick a base cache directory and return its `audioseal` sub-directory.

    The environment variables in `env_names` are consulted in order and the
    first one that yields a path wins; later names are not looked up. When
    none of them is set, `~/.cache` (expanded and resolved) is used instead.
    The returned sub-directory is created if it does not exist yet.
    """
    env_paths = (_get_path_from_env(name) for name in env_names)
    base_dir = next((path for path in env_paths if path), None)
    if base_dir is None:
        base_dir = Path("~/.cache").expanduser().resolve()

    # Use a dedicated sub-dir so the shared cache is not polluted
    audioseal_dir = base_dir / "audioseal"
    audioseal_dir.mkdir(exist_ok=True, parents=True)
    return audioseal_dir


def load_model_checkpoint(
    model_path: Union[Path, str],
    device: Union[str, torch.device] = "cpu",
):
    """Load a raw checkpoint from a local file, an HTTPS URL or the HF Hub.

    Resolution order:

    1. If `model_path` is an existing file, it is loaded with `torch.load`.
    2. If it is an ``https`` URL, it is fetched with
       `torch.hub.load_state_dict_from_url` into the AudioSeal cache dir. The
       cached file name is the first 24 hex characters of the SHA-1 of the
       URL path.
    3. If it starts with ``facebook/audioseal/``, the remainder is treated as
       a file name in the ``facebook/audioseal`` Hugging Face repository and
       downloaded with `huggingface_hub.hf_hub_download`.

    The cache directory is taken from the first set variable among
    ``AUDIOSEAL_CACHE_DIR``, ``AUDIOCRAFT_CACHE_DIR`` and ``XDG_CACHE_HOME``.

    Args:
        model_path: Local path, HTTPS URL or ``facebook/audioseal/<file>``.
        device: Passed as `map_location` to the torch loading functions.

    Returns:
        Whatever object the checkpoint file deserializes to.

    Raises:
        ModelLoadError: If `model_path` matches none of the supported forms,
            or if a Hugging Face path is given but `huggingface_hub` is not
            installed.
    """
    # NOTE(security): torch.load / load_state_dict_from_url unpickle the file;
    # only point this loader at checkpoints from a trusted source.
    if Path(model_path).is_file():
        return torch.load(model_path, map_location=device)

    cache_dir = _get_cache_dir(
        ["AUDIOSEAL_CACHE_DIR", "AUDIOCRAFT_CACHE_DIR", "XDG_CACHE_HOME"]
    )
    model_uri = str(model_path)
    parts = urlparse(model_uri)
    if parts.scheme == "https":
        # Short, stable cache file name derived from the URL path
        hash_ = sha1(parts.path.encode()).hexdigest()[:24]
        return torch.hub.load_state_dict_from_url(
            model_uri, model_dir=cache_dir, map_location=device, file_name=hash_
        )

    hf_prefix = "facebook/audioseal/"
    if model_uri.startswith(hf_prefix):
        hf_filename = model_uri[len(hf_prefix) :]

        try:
            from huggingface_hub import hf_hub_download
        except ModuleNotFoundError as ex:
            # Fail with an actionable error instead of printing a hint and
            # then crashing on the unbound `hf_hub_download` name below.
            raise ModelLoadError(
                f"The model path {model_path} seems to be a direct HF path, "
                "but you do not install Huggingface_hub. Install with for example "
                "`pip install huggingface_hub` to use this feature."
            ) from ex

        file = hf_hub_download(
            repo_id="facebook/audioseal",
            filename=hf_filename,
            cache_dir=cache_dir,
            library_name="audioseal",
            library_version=audioseal.__version__,
        )
        return torch.load(file, map_location=device)

    raise ModelLoadError(f"Path or uri {model_path} is unknown or does not exist")


def load_local_model_config(model_card: str) -> Optional[DictConfig]:
    """Load the bundled YAML model card named `model_card`, if there is one.

    The card is looked up as ``cards/<model_card>.yaml`` next to this module.
    Returns the parsed config, or `None` when no such file exists.
    """
    cards_dir = Path(__file__).parent / "cards"
    card_file = cards_dir / (model_card + ".yaml")
    if not card_file.is_file():
        return None
    return cast(DictConfig, OmegaConf.load(card_file.resolve()))


class AudioSeal:
    """Static helpers that turn a model card or checkpoint path into a model."""

    @staticmethod
    def parse_model(
        model_card_or_path: str,
        model_type: Type[AudioSealT],
        nbits: Optional[int] = None,
    ) -> Tuple[Dict[str, Any], AudioSealT]:
        """
        Resolve `model_card_or_path` into a checkpoint and a typed config.

        A bundled model card is tried first: its ``checkpoint`` entry names
        the checkpoint to load and its remaining entries act as config
        overrides. Otherwise the argument itself is loaded as a checkpoint.
        A config stored in the checkpoint under ``xp.cfg`` is used as the
        base, with the card entries taking precedence. The result is
        validated against the schema `model_type`.

        Returns the checkpoint (its ``model`` entry when present) together
        with the parsed config.
        """
        card_config = load_local_model_config(model_card_or_path)

        if not card_config:
            # No (or empty) local card: treat the argument as a checkpoint path
            overrides = {}
            raw_checkpoint = load_model_checkpoint(model_card_or_path)
        else:
            # Local card found: it points to the checkpoint and carries overrides
            assert (
                "checkpoint" in card_config
            ), f"Checkpoint missing in {model_card_or_path}"
            overrides = OmegaConf.to_container(card_config)
            assert isinstance(
                overrides, dict
            ), f"Cannot parse config from {model_card_or_path}"
            checkpoint_uri = overrides.pop("checkpoint")
            raw_checkpoint = load_model_checkpoint(checkpoint_uri)

        # Config embedded in the checkpoint is the base; card entries win
        if "xp.cfg" in raw_checkpoint:
            overrides = {**raw_checkpoint["xp.cfg"], **overrides}  # type: ignore

        parsed_config = AudioSeal.parse_config(overrides, config_type=model_type, nbits=nbits)  # type: ignore

        weights = raw_checkpoint["model"] if "model" in raw_checkpoint else raw_checkpoint
        return weights, parsed_config

    @staticmethod
    def parse_config(
        config: Dict[str, Any],
        config_type: Type[AudioSealT],
        nbits: Optional[int] = None,
    ) -> AudioSealT:
        """
        Normalize a raw config dict and fit it into the schema `config_type`.

        Interpolations are resolved, the ``encoder`` / ``decoder`` /
        ``detector`` sections nested under ``seanet`` are moved to the top
        level, `nbits` is filled in when given and not already configured,
        and only the fields declared by `config_type` are kept.
        """
        assert "seanet" in config, f"missing seanet backbone config in {config}"

        # Step 1: resolve interpolated variables via an OmegaConf round trip
        resolved = OmegaConf.create(config)  # type: ignore
        OmegaConf.resolve(resolved)  # type: ignore
        plain = OmegaConf.to_container(resolved)  # type: ignore

        # Step 2: hoist the sub-module sections out of "seanet"; values that
        # come from inside "seanet" override same-named top-level entries
        seanet_section = plain["seanet"]
        for section in ("encoder", "decoder", "detector"):
            if section not in seanet_section:
                continue
            top_level = plain.get(section) or {}
            plain[section] = {**top_level, **seanet_section.pop(section)}

        plain["seanet"] = seanet_section

        # Step 3: apply the caller's nbits unless the config already has one
        if nbits and "nbits" not in plain:
            plain["nbits"] = nbits

        assert plain, "Empty config"

        # Keep only the attributes that the schema actually declares
        selected = {
            spec.name: plain[spec.name]
            for spec in fields(config_type)
            if spec.name in plain
        }

        schema = OmegaConf.structured(config_type)
        schema.merge_with(selected)
        return schema

    @staticmethod
    def load_generator(
        model_card_or_path: str,
        nbits: Optional[int] = None,
    ) -> AudioSealWM:
        """Build the AudioSeal generator and load its weights"""
        state_dict, generator_config = AudioSeal.parse_model(
            model_card_or_path, AudioSealWMConfig, nbits=nbits
        )

        generator = create_generator(generator_config)
        generator.load_state_dict(state_dict)
        return generator

    @staticmethod
    def load_detector(
        model_card_or_path: str,
        nbits: Optional[int] = None,
    ) -> AudioSealDetector:
        """Build the AudioSeal detector and load its weights"""
        state_dict, detector_config = AudioSeal.parse_model(
            model_card_or_path, AudioSealDetectorConfig, nbits=nbits
        )

        detector = create_detector(detector_config)
        detector.load_state_dict(state_dict)
        return detector