victan commited on
Commit
eca0c8d
·
1 Parent(s): 8e14688

Upload seamless_communication/models/generator/ecapa_tdnn_builder.py with huggingface_hub

Browse files
seamless_communication/models/generator/ecapa_tdnn_builder.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+ from dataclasses import dataclass
8
+ from typing import List, Optional
9
+
10
+ from fairseq2.models.utils.arch_registry import ArchitectureRegistry
11
+ from fairseq2.typing import DataType, Device
12
+
13
+ from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
14
+
15
+
16
+ @dataclass
17
+ class EcapaTDNNConfig:
18
+ channels: List[int]
19
+ kernel_sizes: List[int]
20
+ dilations: List[int]
21
+ attention_channels: int
22
+ res2net_scale: int
23
+ se_channels: int
24
+ global_context: bool
25
+ groups: List[int]
26
+ embed_dim: int
27
+ input_dim: int
28
+
29
+
30
+ ecapa_tdnn_archs = ArchitectureRegistry[EcapaTDNNConfig]("ecapa_tdnn")
31
+
32
+ ecapa_tdnn_arch = ecapa_tdnn_archs.decorator
33
+
34
+
35
+ @ecapa_tdnn_arch("base")
36
+ def _base_ecapa_tdnn() -> EcapaTDNNConfig:
37
+ return EcapaTDNNConfig(
38
+ channels=[512, 512, 512, 512, 1536],
39
+ kernel_sizes=[5, 3, 3, 3, 1],
40
+ dilations=[1, 2, 3, 4, 1],
41
+ attention_channels=128,
42
+ res2net_scale=8,
43
+ se_channels=128,
44
+ global_context=True,
45
+ groups=[1, 1, 1, 1, 1],
46
+ embed_dim=512,
47
+ input_dim=80,
48
+ )
49
+
50
+
51
+ class EcapaTDNNBuilder:
52
+ """
53
+ Builder module for ECAPA_TDNN model
54
+ """
55
+
56
+ config: EcapaTDNNConfig
57
+ device: Optional[Device]
58
+ dtype: Optional[DataType]
59
+
60
+ def __init__(
61
+ self,
62
+ config: EcapaTDNNConfig,
63
+ *,
64
+ device: Optional[Device] = None,
65
+ dtype: Optional[DataType] = None,
66
+ ) -> None:
67
+ """
68
+ :param config:
69
+ The configuration to use.
70
+ :param devicev:
71
+ The device on which to initialize modules.
72
+ :param dtype:
73
+ The data type of module parameters and buffers.
74
+ """
75
+ self.config = config
76
+
77
+ self.device, self.dtype = device, dtype
78
+
79
+ def build_model(self) -> ECAPA_TDNN:
80
+ """Build a model."""
81
+ model = ECAPA_TDNN(
82
+ self.config.channels,
83
+ self.config.kernel_sizes,
84
+ self.config.dilations,
85
+ self.config.attention_channels,
86
+ self.config.res2net_scale,
87
+ self.config.se_channels,
88
+ self.config.global_context,
89
+ self.config.groups,
90
+ self.config.embed_dim,
91
+ self.config.input_dim,
92
+ )
93
+ model.to(device=self.device, dtype=self.dtype)
94
+ return model
95
+
96
+
97
+ def create_ecapa_tdnn_model(
98
+ config: EcapaTDNNConfig,
99
+ device: Optional[Device] = None,
100
+ dtype: Optional[DataType] = None,
101
+ ) -> ECAPA_TDNN:
102
+ """Create a ECAPA_TDNN model.
103
+
104
+ :param config:
105
+ The configuration to use.
106
+ :param device:
107
+ The device on which to initialize modules.
108
+ :param dtype:
109
+ The data type of module parameters and buffers.
110
+ """
111
+
112
+ return EcapaTDNNBuilder(config, device=device, dtype=dtype).build_model()