File size: 3,224 Bytes
0209786 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import asdict, dataclass, field, is_dataclass
from typing import Any, Dict, List, Optional
from omegaconf import DictConfig, OmegaConf
from torch import device, dtype
from typing_extensions import TypeAlias
from audioseal.libs import audiocraft
from audioseal.models import AudioSealDetector, AudioSealWM, MsgProcessor
Device: TypeAlias = device
DataType: TypeAlias = dtype
@dataclass
class SEANetConfig:
    """Hyper-parameters shared by the SEANet encoder and decoder.

    Each field name is forwarded verbatim as a keyword argument (after
    conversion through ``as_dict``) to ``audiocraft.modules.SEANetEncoder``,
    ``audiocraft.modules.SEANetDecoder`` and ``AudioSealDetector``. The names
    must therefore match those constructors' parameters. ``dimension`` is also
    used as the message processor's ``hidden_size`` in ``create_generator``.

    Field order is part of the positional constructor signature; keep it stable.
    """

    channels: int
    dimension: int
    n_filters: int
    n_residual_layers: int
    ratios: List[int]
    activation: str
    activation_params: Dict[str, float]
    norm: str
    norm_params: Dict[str, Any]
    kernel_size: int
    last_kernel_size: int
    residual_kernel_size: int
    dilation_base: int
    causal: bool
    pad_mode: str
    true_skip: bool
    # NOTE(review): the three flags below are typed ``int``; their meaning is
    # defined by the audiocraft SEANet modules, not here.
    compress: int
    lstm: int
    disable_norm_outer_blocks: int
@dataclass
class DecoderConfig:
    """Decoder-only hyper-parameters.

    ``create_generator`` merges these over the ``SEANetConfig`` values and
    passes the result as keyword arguments to
    ``audiocraft.modules.SEANetDecoder``. On a name collision these values win.
    """

    final_activation: Optional[str]
    final_activation_params: Optional[dict]
    trim_right_ratio: float
@dataclass
class DetectorConfig:
    """Detector-only hyper-parameters.

    ``create_detector`` merges these over the ``SEANetConfig`` values and
    passes them to ``AudioSealDetector`` as keyword arguments.
    """

    output_dim: int = 32  # default used when no value is supplied
@dataclass
class AudioSealWMConfig:
    """Complete hyper-parameter set consumed by ``create_generator``.

    Attributes:
        nbits: forwarded to ``MsgProcessor`` as ``nbits``.
        seanet: SEANet hyper-parameters used for the encoder, and as the base
            layer of the decoder's keyword arguments.
        decoder: decoder-only hyper-parameters merged over ``seanet`` when
            building the decoder.
    """

    nbits: int
    seanet: SEANetConfig
    decoder: DecoderConfig
@dataclass
class AudioSealDetectorConfig:
    """Complete hyper-parameter set consumed by ``create_detector``.

    Attributes:
        nbits: forwarded to ``AudioSealDetector`` as ``nbits``.
        seanet: SEANet hyper-parameters passed to the detector as keyword
            arguments.
        detector: detector-only hyper-parameters merged over ``seanet``. By
            default each instance gets its own fresh ``DetectorConfig()``.
    """

    nbits: int
    seanet: SEANetConfig
    # The class itself is a valid zero-argument factory; equivalent to
    # ``lambda: DetectorConfig()``.
    detector: DetectorConfig = field(default_factory=DetectorConfig)
def as_dict(obj: Any) -> Dict[str, Any]:
    """Normalise a config object into a plain ``dict``.

    Accepted inputs:
      * a ``dict``, which is returned as-is (the same object, not a copy);
      * a dataclass *instance*, converted recursively with
        ``dataclasses.asdict``;
      * an OmegaConf ``DictConfig``, converted with ``OmegaConf.to_container``.

    Raises:
        NotImplementedError: for any other type, including dataclass classes
            (as opposed to instances).
    """
    if isinstance(obj, dict):
        return obj
    if is_dataclass(obj) and not isinstance(obj, type):
        return asdict(obj)
    if isinstance(obj, DictConfig):
        # Called without ``resolve``; per the OmegaConf docs the default leaves
        # ``${...}`` interpolations unresolved.
        return OmegaConf.to_container(obj)  # type: ignore
    raise NotImplementedError(f"Unsupported type for config: {type(obj)}")
def create_generator(
    config: AudioSealWMConfig,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> AudioSealWM:
    """Build an ``AudioSealWM`` watermark generator from its hyper-parameters.

    The SEANet encoder is constructed from ``config.seanet`` alone. The SEANet
    decoder receives the same SEANet values overlaid with ``config.decoder``,
    and decoder keys win on collision. The message processor is sized by
    ``config.nbits`` and ``config.seanet.dimension``.

    Args:
        config: generator hyper-parameters.
        device: forwarded to ``.to()`` on every sub-module.
        dtype: forwarded to ``.to()`` on every sub-module.

    Returns:
        An ``AudioSealWM`` wrapping the encoder, decoder and message processor.
    """

    def _place(module):
        # Each sub-module is moved/cast individually before being wrapped.
        return module.to(device=device, dtype=dtype)

    # For now the encoder hparams are exactly the SEANet hparams; this may
    # diverge in the future.
    encoder = _place(audiocraft.modules.SEANetEncoder(**as_dict(config.seanet)))

    # ``as_dict`` always yields a plain dict, so the merged mapping can be
    # unpacked directly.
    decoder_kwargs = {**as_dict(config.seanet), **as_dict(config.decoder)}
    decoder = _place(audiocraft.modules.SEANetDecoder(**decoder_kwargs))

    msg_processor = _place(
        MsgProcessor(nbits=config.nbits, hidden_size=config.seanet.dimension)
    )
    return AudioSealWM(encoder=encoder, decoder=decoder, msg_processor=msg_processor)
def create_detector(
    config: AudioSealDetectorConfig,
    *,
    device: Optional[Device] = None,
    dtype: Optional[DataType] = None,
) -> AudioSealDetector:
    """Build an ``AudioSealDetector`` from its hyper-parameters.

    The detector receives ``config.nbits`` plus the SEANet hyper-parameters
    overlaid with ``config.detector``, all as keyword arguments. Detector keys
    take precedence on collision.

    Args:
        config: detector hyper-parameters.
        device: forwarded to ``.to()`` on the constructed detector.
        dtype: forwarded to ``.to()`` on the constructed detector.

    Returns:
        The constructed detector, as returned by ``.to(device=..., dtype=...)``.
    """
    # Copy first: ``as_dict`` may hand back the caller's own dict unchanged,
    # and it must not be mutated by the ``update`` below.
    detector_kwargs: Dict[str, Any] = dict(as_dict(config.seanet))
    detector_kwargs.update(as_dict(config.detector))
    return AudioSealDetector(nbits=config.nbits, **detector_kwargs).to(
        device=device, dtype=dtype
    )
|