File size: 7,322 Bytes
1034391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aa0f34
 
 
 
 
 
1034391
 
 
 
 
4aa0f34
 
 
1034391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aa0f34
 
 
 
 
 
1034391
 
 
4aa0f34
1034391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aa0f34
1034391
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""Configuration management module for the Dia model.

This module provides comprehensive configuration management for the Dia model,
utilizing Pydantic for validation. It defines configurations for data processing,
model architecture (encoder and decoder), and training settings.

Key components:
- DataConfig: Parameters for data loading and preprocessing.
- EncoderConfig: Architecture details for the encoder module.
- DecoderConfig: Architecture details for the decoder module.
- ModelConfig: Combined model architecture settings.
- TrainingConfig: Training process settings (currently an empty placeholder kept for backwards compatibility).
- DiaConfig: Master configuration combining all components.
"""

import os
from typing import Annotated

from pydantic import BaseModel, BeforeValidator, Field


def _round_up_to_multiple_of_128(value: int) -> int:
    """Round *value* up to the nearest multiple of 128 (identity for exact multiples)."""
    return (value + 127) // 128 * 128


class DataConfig(BaseModel, frozen=True):
    """Configuration for data loading and preprocessing.

    Attributes:
        text_length: Maximum length of text sequences (rounded up to a multiple of 128).
        audio_length: Maximum length of audio sequences (rounded up to a multiple of 128).
        channels: Number of audio channels.
        text_pad_value: Value used for padding text sequences.
        audio_eos_value: Value representing the end of audio sequences.
        audio_bos_value: Value representing the beginning of audio sequences.
        audio_pad_value: Value used for padding audio sequences.
        delay_pattern: List of delay values for each audio channel.
    """

    # The BeforeValidator rounds raw input up to a multiple of 128; the field
    # constraints then re-check the result (gt=0 rejects non-positive inputs,
    # which round to 0).
    text_length: Annotated[
        int, BeforeValidator(_round_up_to_multiple_of_128)
    ] = Field(gt=0, multiple_of=128)
    audio_length: Annotated[
        int, BeforeValidator(_round_up_to_multiple_of_128)
    ] = Field(gt=0, multiple_of=128)
    channels: int = Field(default=9, gt=0)  # multiple_of=1 dropped: no-op for ints
    text_pad_value: int = Field(default=0)
    audio_eos_value: int = Field(default=1024)
    audio_pad_value: int = Field(default=1025)
    audio_bos_value: int = Field(default=1026)
    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
    )

    def __hash__(self) -> int:
        """Hash over all fields so frozen configs can serve as dict/cache keys."""
        return hash(
            (
                self.text_length,
                self.audio_length,
                self.channels,
                self.text_pad_value,
                self.audio_pad_value,
                self.audio_bos_value,
                self.audio_eos_value,
                tuple(self.delay_pattern),  # lists are unhashable; freeze first
            )
        )


class EncoderConfig(BaseModel, frozen=True):
    """Configuration for the encoder component of the Dia model.

    Frozen (immutable) pydantic model; all fields are required and must be
    positive integers.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        n_head: Number of attention heads.
        head_dim: Dimension per attention head.
    """

    n_layer: int = Field(gt=0)  # transformer depth
    n_embd: int = Field(gt=0)  # model/embedding width
    n_hidden: int = Field(gt=0)  # MLP inner width
    n_head: int = Field(gt=0)  # attention head count
    head_dim: int = Field(gt=0)  # width of each attention head


class DecoderConfig(BaseModel, frozen=True):
    """Configuration for the decoder component of the Dia model.

    Frozen (immutable) pydantic model; all fields are required and must be
    positive integers. Unlike the encoder, the decoder uses grouped-query
    self-attention plus a separate cross-attention over encoder outputs.

    Attributes:
        n_layer: Number of transformer layers.
        n_embd: Embedding dimension.
        n_hidden: Hidden dimension size in the MLP layers.
        gqa_query_heads: Number of query heads for grouped-query self-attention.
        kv_heads: Number of key/value heads for grouped-query self-attention.
        gqa_head_dim: Dimension per query head for grouped-query self-attention.
        cross_query_heads: Number of query heads for cross-attention.
        cross_head_dim: Dimension per cross-attention head.
    """

    n_layer: int = Field(gt=0)  # transformer depth
    n_embd: int = Field(gt=0)  # model/embedding width
    n_hidden: int = Field(gt=0)  # MLP inner width
    gqa_query_heads: int = Field(gt=0)  # self-attention query heads
    kv_heads: int = Field(gt=0)  # shared key/value heads (GQA)
    gqa_head_dim: int = Field(gt=0)  # per-head width for self-attention
    cross_query_heads: int = Field(gt=0)  # cross-attention query heads
    cross_head_dim: int = Field(gt=0)  # per-head width for cross-attention


class ModelConfig(BaseModel, frozen=True):
    """Main configuration container for the Dia model architecture.

    Attributes:
        encoder: Configuration for the encoder component.
        decoder: Configuration for the decoder component.
        src_vocab_size: Size of the source (text) vocabulary.
        tgt_vocab_size: Size of the target (audio code) vocabulary.
        dropout: Dropout probability applied within the model.
        normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
        weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
        rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
        rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
    """

    encoder: EncoderConfig
    decoder: DecoderConfig
    src_vocab_size: int = Field(default=128, gt=0)
    # NOTE(review): 1028 presumably covers 1024 audio codes plus the special
    # eos/pad/bos values (1024-1026) defined in DataConfig — confirm.
    tgt_vocab_size: int = Field(default=1028, gt=0)
    dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
    normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
    weight_dtype: str = Field(default="float32", description="Weight precision")
    rope_min_timescale: int = Field(
        default=1, description="Timescale For global Attention"
    )
    rope_max_timescale: int = Field(
        default=10_000, description="Timescale For global Attention"
    )


class TrainingConfig(BaseModel, frozen=True):
    """Training process configuration.

    Currently an empty placeholder; retained only so existing configuration
    files containing a "training" section still validate.
    """

    pass


class DiaConfig(BaseModel, frozen=True):
    """Master configuration for the Dia model.

    Combines all sub-configurations into a single validated object.

    Attributes:
        version: Configuration version string.
        model: Model architecture configuration.
        training: Training process configuration (kept for backwards compatibility).
        data: Data loading and processing configuration.
    """

    version: str = Field(default="1.0")
    model: ModelConfig
    # TODO: remove training. this is just for backwards-compatability
    training: TrainingConfig
    data: DataConfig

    def save(self, path: str) -> None:
        """Save the current configuration instance to a JSON file.

        Creates the parent directory if it does not already exist.

        Args:
            path: The target file path to save the configuration.
        """
        # Bug fix: os.makedirs("") raises FileNotFoundError, so saving to a
        # bare filename in the current directory used to crash. Only create
        # the parent when the path actually has a directory component.
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        config_json = self.model_dump_json(indent=2)
        with open(path, "w") as f:
            f.write(config_json)

    @classmethod
    def load(cls, path: str) -> "DiaConfig | None":
        """Load and validate a Dia configuration from a JSON file.

        Args:
            path: The path to the configuration file.

        Returns:
            A validated DiaConfig instance if the file exists and is valid,
            otherwise None if the file is not found.

        Raises:
            pydantic.ValidationError: If the JSON content fails validation
                against the DiaConfig schema.
        """
        try:
            with open(path, "r") as f:
                content = f.read()
            return cls.model_validate_json(content)
        except FileNotFoundError:
            # Missing file is an expected condition; callers check for None.
            return None