amaye15 commited on
Commit
fabf416
·
verified ·
1 Parent(s): 85501fc

Training in progress, epoch 1

Browse files
config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "gelu",
3
+ "architectures": [
4
+ "AutoencoderForReconstruction"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_autoencoder.AutoencoderConfig",
8
+ "AutoModel": "modeling_autoencoder.AutoencoderForReconstruction"
9
+ },
10
+ "autoencoder_type": "classic",
11
+ "beta": 1.0,
12
+ "bidirectional": true,
13
+ "dropout_rate": 0.1,
14
+ "flow_coupling_layers": 2,
15
+ "hidden_dims": [
16
+ 64,
17
+ 32
18
+ ],
19
+ "input_dim": 20,
20
+ "latent_dim": 16,
21
+ "learn_inverse_preprocessing": true,
22
+ "model_type": "autoencoder",
23
+ "noise_factor": 0.1,
24
+ "num_layers": 2,
25
+ "preprocessing_hidden_dim": 32,
26
+ "preprocessing_num_layers": 2,
27
+ "preprocessing_type": "robust_scaler",
28
+ "reconstruction_loss": "mse",
29
+ "rnn_type": "lstm",
30
+ "sequence_length": null,
31
+ "teacher_forcing_ratio": 0.5,
32
+ "tie_weights": false,
33
+ "torch_dtype": "float32",
34
+ "transformers_version": "4.55.2",
35
+ "use_batch_norm": true,
36
+ "use_learnable_preprocessing": true
37
+ }
configuration_autoencoder.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Autoencoder configuration for Hugging Face Transformers.
3
+ """
4
+
5
+ from transformers import PretrainedConfig
6
+ from typing import List, Optional
7
+
8
+
9
+ class AutoencoderConfig(PretrainedConfig):
10
+ """
11
+ Configuration class for Autoencoder models.
12
+
13
+ This configuration class stores the configuration of an autoencoder model. It is used to instantiate
14
+ an autoencoder model according to the specified arguments, defining the model architecture.
15
+
16
+ Args:
17
+ input_dim (int, optional): Dimensionality of the input data. Defaults to 784.
18
+ hidden_dims (List[int], optional): List of hidden layer dimensions for the encoder.
19
+ The decoder will use the reverse of this list. Defaults to [512, 256, 128].
20
+ latent_dim (int, optional): Dimensionality of the latent space. Defaults to 64.
21
+ activation (str, optional): Activation function to use. Options: "relu", "tanh", "sigmoid",
22
+ "leaky_relu", "gelu", "swish", "silu", "elu", "prelu", "relu6", "hardtanh",
23
+ "hardsigmoid", "hardswish", "mish", "softplus", "softsign", "tanhshrink", "threshold".
24
+ Defaults to "relu".
25
+ dropout_rate (float, optional): Dropout rate for regularization. Defaults to 0.1.
26
+ use_batch_norm (bool, optional): Whether to use batch normalization. Defaults to True.
27
+ tie_weights (bool, optional): Whether to tie encoder and decoder weights. Defaults to False.
28
+ reconstruction_loss (str, optional): Type of reconstruction loss. Options: "mse", "bce", "l1",
29
+ "huber", "smooth_l1", "kl_div", "cosine", "focal", "dice", "tversky", "ssim", "perceptual".
30
+ Defaults to "mse".
31
+ autoencoder_type (str, optional): Type of autoencoder architecture. Options: "classic",
32
+ "variational", "beta_vae", "denoising", "sparse", "contractive", "recurrent". Defaults to "classic".
33
+ beta (float, optional): Beta parameter for beta-VAE. Defaults to 1.0.
34
+ temperature (float, optional): Temperature parameter for Gumbel softmax or other operations. Defaults to 1.0.
35
+ noise_factor (float, optional): Noise factor for denoising autoencoders. Defaults to 0.1.
36
+ rnn_type (str, optional): Type of RNN cell for recurrent autoencoders. Options: "lstm", "gru", "rnn".
37
+ Defaults to "lstm".
38
+ num_layers (int, optional): Number of RNN layers for recurrent autoencoders. Defaults to 2.
39
+ bidirectional (bool, optional): Whether to use bidirectional RNN for encoding. Defaults to True.
40
+ sequence_length (int, optional): Fixed sequence length. If None, supports variable length sequences.
41
+ Defaults to None.
42
+ teacher_forcing_ratio (float, optional): Ratio of teacher forcing during training for recurrent decoders.
43
+ Defaults to 0.5.
44
+ use_learnable_preprocessing (bool, optional): Whether to use learnable preprocessing. Defaults to False.
45
+ preprocessing_type (str, optional): Type of learnable preprocessing. Options: "none", "neural_scaler",
46
+ "normalizing_flow", "minmax_scaler", "robust_scaler", "yeo_johnson". Defaults to "none".
47
+ preprocessing_hidden_dim (int, optional): Hidden dimension for preprocessing networks. Defaults to 64.
48
+ preprocessing_num_layers (int, optional): Number of layers in preprocessing networks. Defaults to 2.
49
+ learn_inverse_preprocessing (bool, optional): Whether to learn inverse preprocessing for reconstruction.
50
+ Defaults to True.
51
+ flow_coupling_layers (int, optional): Number of coupling layers for normalizing flows. Defaults to 4.
52
+ **kwargs: Additional keyword arguments passed to the parent class.
53
+ """
54
+
55
+ model_type = "autoencoder"
56
+
57
+ def __init__(
58
+ self,
59
+ input_dim: int = 784,
60
+ hidden_dims: List[int] = None,
61
+ latent_dim: int = 64,
62
+ activation: str = "relu",
63
+ dropout_rate: float = 0.1,
64
+ use_batch_norm: bool = True,
65
+ tie_weights: bool = False,
66
+ reconstruction_loss: str = "mse",
67
+ autoencoder_type: str = "classic",
68
+ beta: float = 1.0,
69
+ temperature: float = 1.0,
70
+ noise_factor: float = 0.1,
71
+ # Recurrent autoencoder parameters
72
+ rnn_type: str = "lstm",
73
+ num_layers: int = 2,
74
+ bidirectional: bool = True,
75
+ sequence_length: Optional[int] = None,
76
+ teacher_forcing_ratio: float = 0.5,
77
+ # Deep learning preprocessing parameters
78
+ use_learnable_preprocessing: bool = False,
79
+ preprocessing_type: str = "none",
80
+ preprocessing_hidden_dim: int = 64,
81
+ preprocessing_num_layers: int = 2,
82
+ learn_inverse_preprocessing: bool = True,
83
+ flow_coupling_layers: int = 4,
84
+ **kwargs,
85
+ ):
86
+ # Validate parameters
87
+ if hidden_dims is None:
88
+ hidden_dims = [512, 256, 128]
89
+
90
+ # Extended activation functions
91
+ valid_activations = [
92
+ "relu", "tanh", "sigmoid", "leaky_relu", "gelu", "swish", "silu",
93
+ "elu", "prelu", "relu6", "hardtanh", "hardsigmoid", "hardswish",
94
+ "mish", "softplus", "softsign", "tanhshrink", "threshold"
95
+ ]
96
+ if activation not in valid_activations:
97
+ raise ValueError(
98
+ f"`activation` must be one of {valid_activations}, got {activation}."
99
+ )
100
+
101
+ # Extended loss functions
102
+ valid_losses = [
103
+ "mse", "bce", "l1", "huber", "smooth_l1", "kl_div", "cosine",
104
+ "focal", "dice", "tversky", "ssim", "perceptual"
105
+ ]
106
+ if reconstruction_loss not in valid_losses:
107
+ raise ValueError(
108
+ f"`reconstruction_loss` must be one of {valid_losses}, got {reconstruction_loss}."
109
+ )
110
+
111
+ # Autoencoder types
112
+ valid_types = ["classic", "variational", "beta_vae", "denoising", "sparse", "contractive", "recurrent"]
113
+ if autoencoder_type not in valid_types:
114
+ raise ValueError(
115
+ f"`autoencoder_type` must be one of {valid_types}, got {autoencoder_type}."
116
+ )
117
+
118
+ # RNN types for recurrent autoencoders
119
+ valid_rnn_types = ["lstm", "gru", "rnn"]
120
+ if rnn_type not in valid_rnn_types:
121
+ raise ValueError(
122
+ f"`rnn_type` must be one of {valid_rnn_types}, got {rnn_type}."
123
+ )
124
+
125
+ if not (0.0 <= dropout_rate <= 1.0):
126
+ raise ValueError(f"`dropout_rate` must be between 0.0 and 1.0, got {dropout_rate}.")
127
+
128
+ if input_dim <= 0:
129
+ raise ValueError(f"`input_dim` must be positive, got {input_dim}.")
130
+
131
+ if latent_dim <= 0:
132
+ raise ValueError(f"`latent_dim` must be positive, got {latent_dim}.")
133
+
134
+ if not all(dim > 0 for dim in hidden_dims):
135
+ raise ValueError("All dimensions in `hidden_dims` must be positive.")
136
+
137
+ if beta <= 0:
138
+ raise ValueError(f"`beta` must be positive, got {beta}.")
139
+
140
+ if num_layers <= 0:
141
+ raise ValueError(f"`num_layers` must be positive, got {num_layers}.")
142
+
143
+ if not (0.0 <= teacher_forcing_ratio <= 1.0):
144
+ raise ValueError(f"`teacher_forcing_ratio` must be between 0.0 and 1.0, got {teacher_forcing_ratio}.")
145
+
146
+ if sequence_length is not None and sequence_length <= 0:
147
+ raise ValueError(f"`sequence_length` must be positive when specified, got {sequence_length}.")
148
+
149
+ # Preprocessing validation
150
+ valid_preprocessing = [
151
+ "none",
152
+ "neural_scaler",
153
+ "normalizing_flow",
154
+ "minmax_scaler",
155
+ "robust_scaler",
156
+ "yeo_johnson",
157
+ ]
158
+ if preprocessing_type not in valid_preprocessing:
159
+ raise ValueError(
160
+ f"`preprocessing_type` must be one of {valid_preprocessing}, got {preprocessing_type}."
161
+ )
162
+
163
+ if preprocessing_hidden_dim <= 0:
164
+ raise ValueError(f"`preprocessing_hidden_dim` must be positive, got {preprocessing_hidden_dim}.")
165
+
166
+ if preprocessing_num_layers <= 0:
167
+ raise ValueError(f"`preprocessing_num_layers` must be positive, got {preprocessing_num_layers}.")
168
+
169
+ if flow_coupling_layers <= 0:
170
+ raise ValueError(f"`flow_coupling_layers` must be positive, got {flow_coupling_layers}.")
171
+
172
+ # Set configuration attributes
173
+ self.input_dim = input_dim
174
+ self.hidden_dims = hidden_dims
175
+ self.latent_dim = latent_dim
176
+ self.activation = activation
177
+ self.dropout_rate = dropout_rate
178
+ self.use_batch_norm = use_batch_norm
179
+ self.tie_weights = tie_weights
180
+ self.reconstruction_loss = reconstruction_loss
181
+ self.autoencoder_type = autoencoder_type
182
+ self.beta = beta
183
+ self.temperature = temperature
184
+ self.noise_factor = noise_factor
185
+ self.rnn_type = rnn_type
186
+ self.num_layers = num_layers
187
+ self.bidirectional = bidirectional
188
+ self.sequence_length = sequence_length
189
+ self.teacher_forcing_ratio = teacher_forcing_ratio
190
+ self.use_learnable_preprocessing = use_learnable_preprocessing
191
+ self.preprocessing_type = preprocessing_type
192
+ self.preprocessing_hidden_dim = preprocessing_hidden_dim
193
+ self.preprocessing_num_layers = preprocessing_num_layers
194
+ self.learn_inverse_preprocessing = learn_inverse_preprocessing
195
+ self.flow_coupling_layers = flow_coupling_layers
196
+
197
+ # Call parent constructor
198
+ super().__init__(**kwargs)
199
+
200
+ @property
201
+ def decoder_dims(self) -> List[int]:
202
+ """Get decoder dimensions (reverse of encoder hidden dims)."""
203
+ return list(reversed(self.hidden_dims))
204
+
205
+ @property
206
+ def is_variational(self) -> bool:
207
+ """Check if this is a variational autoencoder."""
208
+ return self.autoencoder_type in ["variational", "beta_vae"]
209
+
210
+ @property
211
+ def is_denoising(self) -> bool:
212
+ """Check if this is a denoising autoencoder."""
213
+ return self.autoencoder_type == "denoising"
214
+
215
+ @property
216
+ def is_sparse(self) -> bool:
217
+ """Check if this is a sparse autoencoder."""
218
+ return self.autoencoder_type == "sparse"
219
+
220
+ @property
221
+ def is_contractive(self) -> bool:
222
+ """Check if this is a contractive autoencoder."""
223
+ return self.autoencoder_type == "contractive"
224
+
225
+ @property
226
+ def is_recurrent(self) -> bool:
227
+ """Check if this is a recurrent autoencoder."""
228
+ return self.autoencoder_type == "recurrent"
229
+
230
+ @property
231
+ def rnn_hidden_size(self) -> int:
232
+ """Get the RNN hidden size (same as latent_dim for recurrent AE)."""
233
+ return self.latent_dim
234
+
235
+ @property
236
+ def rnn_output_size(self) -> int:
237
+ """Get the RNN output size considering bidirectionality."""
238
+ return self.latent_dim * (2 if self.bidirectional else 1)
239
+
240
+ @property
241
+ def has_preprocessing(self) -> bool:
242
+ """Check if learnable preprocessing is enabled."""
243
+ return self.use_learnable_preprocessing and self.preprocessing_type != "none"
244
+
245
+ @property
246
+ def is_neural_scaler(self) -> bool:
247
+ """Check if using neural scaler preprocessing."""
248
+ return self.preprocessing_type == "neural_scaler"
249
+
250
+ @property
251
+ def is_normalizing_flow(self) -> bool:
252
+ """Check if using normalizing flow preprocessing."""
253
+ return self.preprocessing_type == "normalizing_flow"
254
+
255
+ @property
256
+ def is_minmax_scaler(self) -> bool:
257
+ """Check if using learnable MinMax scaler preprocessing."""
258
+ return self.preprocessing_type == "minmax_scaler"
259
+
260
+ @property
261
+ def is_robust_scaler(self) -> bool:
262
+ """Check if using learnable Robust scaler preprocessing."""
263
+ return self.preprocessing_type == "robust_scaler"
264
+
265
+ @property
266
+ def is_yeo_johnson(self) -> bool:
267
+ """Check if using learnable Yeo-Johnson power transform preprocessing."""
268
+ return self.preprocessing_type == "yeo_johnson"
269
+
270
+ def to_dict(self):
271
+ """
272
+ Serializes this instance to a Python dictionary.
273
+ """
274
+ output = super().to_dict()
275
+ return output
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0ee24472ccb835430c31a59e88747df793c96f455940d4f308dd894fd765dcd
3
+ size 59368
modeling_autoencoder.py ADDED
@@ -0,0 +1,1437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PyTorch Autoencoder model for Hugging Face Transformers.
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from typing import Optional, Tuple, Union, Dict, Any, List
9
+ from dataclasses import dataclass
10
+ import random
11
+
12
+ from transformers import PreTrainedModel
13
+ from transformers.modeling_outputs import BaseModelOutput
14
+ from transformers.utils import ModelOutput
15
+
16
+ from configuration_autoencoder import AutoencoderConfig
17
+
18
+
19
+ class NeuralScaler(nn.Module):
20
+ """Learnable alternative to StandardScaler using neural networks."""
21
+
22
+ def __init__(self, config: AutoencoderConfig):
23
+ super().__init__()
24
+ self.config = config
25
+ input_dim = config.input_dim
26
+ hidden_dim = config.preprocessing_hidden_dim
27
+
28
+ # Networks to learn data-dependent statistics
29
+ self.mean_estimator = nn.Sequential(
30
+ nn.Linear(input_dim, hidden_dim),
31
+ nn.ReLU(),
32
+ nn.Linear(hidden_dim, hidden_dim),
33
+ nn.ReLU(),
34
+ nn.Linear(hidden_dim, input_dim)
35
+ )
36
+
37
+ self.std_estimator = nn.Sequential(
38
+ nn.Linear(input_dim, hidden_dim),
39
+ nn.ReLU(),
40
+ nn.Linear(hidden_dim, hidden_dim),
41
+ nn.ReLU(),
42
+ nn.Linear(hidden_dim, input_dim),
43
+ nn.Softplus() # Ensure positive standard deviation
44
+ )
45
+
46
+ # Learnable affine transformation parameters
47
+ self.weight = nn.Parameter(torch.ones(input_dim))
48
+ self.bias = nn.Parameter(torch.zeros(input_dim))
49
+
50
+ # Running statistics for inference (like BatchNorm)
51
+ self.register_buffer('running_mean', torch.zeros(input_dim))
52
+ self.register_buffer('running_std', torch.ones(input_dim))
53
+ self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
54
+
55
+ # Momentum for running statistics
56
+ self.momentum = 0.1
57
+
58
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
59
+ """
60
+ Forward pass through neural scaler.
61
+
62
+ Args:
63
+ x: Input tensor (2D or 3D)
64
+ inverse: Whether to apply inverse transformation
65
+
66
+ Returns:
67
+ Tuple of (transformed_tensor, regularization_loss)
68
+ """
69
+ if inverse:
70
+ return self._inverse_transform(x)
71
+
72
+ # Handle both 2D and 3D tensors
73
+ original_shape = x.shape
74
+ if x.dim() == 3:
75
+ # Reshape (batch, seq, features) -> (batch*seq, features)
76
+ x = x.view(-1, x.size(-1))
77
+
78
+ if self.training:
79
+ # Training mode: learn statistics from current batch
80
+ batch_mean = x.mean(dim=0, keepdim=True)
81
+ batch_std = x.std(dim=0, keepdim=True)
82
+
83
+ # Learn data-dependent adjustments
84
+ learned_mean_adj = self.mean_estimator(batch_mean)
85
+ learned_std_adj = self.std_estimator(batch_std)
86
+
87
+ # Combine batch statistics with learned adjustments
88
+ effective_mean = batch_mean + learned_mean_adj
89
+ effective_std = batch_std + learned_std_adj + 1e-8
90
+
91
+ # Update running statistics
92
+ with torch.no_grad():
93
+ self.num_batches_tracked += 1
94
+ if self.num_batches_tracked == 1:
95
+ self.running_mean.copy_(batch_mean.squeeze())
96
+ self.running_std.copy_(batch_std.squeeze())
97
+ else:
98
+ self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
99
+ self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
100
+ else:
101
+ # Inference mode: use running statistics
102
+ effective_mean = self.running_mean.unsqueeze(0)
103
+ effective_std = self.running_std.unsqueeze(0) + 1e-8
104
+
105
+ # Normalize
106
+ normalized = (x - effective_mean) / effective_std
107
+
108
+ # Apply learnable affine transformation
109
+ transformed = normalized * self.weight + self.bias
110
+
111
+ # Reshape back to original shape if needed
112
+ if len(original_shape) == 3:
113
+ transformed = transformed.view(original_shape)
114
+
115
+ # Regularization loss to encourage meaningful learning
116
+ reg_loss = 0.01 * (self.weight.var() + self.bias.var())
117
+
118
+ return transformed, reg_loss
119
+
120
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
121
+ """Apply inverse transformation to get back original scale."""
122
+ if not self.config.learn_inverse_preprocessing:
123
+ return x, torch.tensor(0.0, device=x.device)
124
+
125
+ # Handle both 2D and 3D tensors
126
+ original_shape = x.shape
127
+ if x.dim() == 3:
128
+ # Reshape (batch, seq, features) -> (batch*seq, features)
129
+ x = x.view(-1, x.size(-1))
130
+
131
+ # Reverse affine transformation
132
+ x = (x - self.bias) / (self.weight + 1e-8)
133
+
134
+ # Reverse normalization using running statistics
135
+ effective_mean = self.running_mean.unsqueeze(0)
136
+ effective_std = self.running_std.unsqueeze(0) + 1e-8
137
+ x = x * effective_std + effective_mean
138
+
139
+ # Reshape back to original shape if needed
140
+ if len(original_shape) == 3:
141
+ x = x.view(original_shape)
142
+
143
+ return x, torch.tensor(0.0, device=x.device)
144
+
145
+
146
+
147
+ class LearnableMinMaxScaler(nn.Module):
148
+ """Learnable MinMax scaler that adapts bounds during training.
149
+
150
+ Scales features to [0, 1] using batch min/range with learnable adjustments and
151
+ a learnable affine transform. Supports 2D (B, F) and 3D (B, T, F) inputs.
152
+ """
153
+
154
+ def __init__(self, config: AutoencoderConfig):
155
+ super().__init__()
156
+ self.config = config
157
+ input_dim = config.input_dim
158
+ hidden_dim = config.preprocessing_hidden_dim
159
+
160
+ # Networks to learn adjustments to batch min and range
161
+ self.min_estimator = nn.Sequential(
162
+ nn.Linear(input_dim, hidden_dim),
163
+ nn.ReLU(),
164
+ nn.Linear(hidden_dim, hidden_dim),
165
+ nn.ReLU(),
166
+ nn.Linear(hidden_dim, input_dim),
167
+ )
168
+ self.range_estimator = nn.Sequential(
169
+ nn.Linear(input_dim, hidden_dim),
170
+ nn.ReLU(),
171
+ nn.Linear(hidden_dim, hidden_dim),
172
+ nn.ReLU(),
173
+ nn.Linear(hidden_dim, input_dim),
174
+ nn.Softplus(), # Ensure positive adjustment to range
175
+ )
176
+
177
+ # Learnable affine transformation parameters
178
+ self.weight = nn.Parameter(torch.ones(input_dim))
179
+ self.bias = nn.Parameter(torch.zeros(input_dim))
180
+
181
+ # Running statistics for inference
182
+ self.register_buffer("running_min", torch.zeros(input_dim))
183
+ self.register_buffer("running_range", torch.ones(input_dim))
184
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
185
+
186
+ self.momentum = 0.1
187
+
188
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
189
+ if inverse:
190
+ return self._inverse_transform(x)
191
+
192
+ original_shape = x.shape
193
+ if x.dim() == 3:
194
+ x = x.view(-1, x.size(-1))
195
+
196
+ eps = 1e-8
197
+ if self.training:
198
+ batch_min = x.min(dim=0, keepdim=True).values
199
+ batch_max = x.max(dim=0, keepdim=True).values
200
+ batch_range = (batch_max - batch_min).clamp_min(eps)
201
+
202
+ # Learn adjustments
203
+ learned_min_adj = self.min_estimator(batch_min)
204
+ learned_range_adj = self.range_estimator(batch_range)
205
+
206
+ effective_min = batch_min + learned_min_adj
207
+ effective_range = batch_range + learned_range_adj + eps
208
+
209
+ # Update running stats with raw batch min/range for stable inversion
210
+ with torch.no_grad():
211
+ self.num_batches_tracked += 1
212
+ if self.num_batches_tracked == 1:
213
+ self.running_min.copy_(batch_min.squeeze())
214
+ self.running_range.copy_(batch_range.squeeze())
215
+ else:
216
+ self.running_min.mul_(1 - self.momentum).add_(batch_min.squeeze(), alpha=self.momentum)
217
+ self.running_range.mul_(1 - self.momentum).add_(batch_range.squeeze(), alpha=self.momentum)
218
+ else:
219
+ effective_min = self.running_min.unsqueeze(0)
220
+ effective_range = self.running_range.unsqueeze(0)
221
+
222
+ # Scale to [0, 1]
223
+ scaled = (x - effective_min) / effective_range
224
+
225
+ # Learnable affine transform
226
+ transformed = scaled * self.weight + self.bias
227
+
228
+ if len(original_shape) == 3:
229
+ transformed = transformed.view(original_shape)
230
+
231
+ # Regularization: encourage non-degenerate range and modest affine params
232
+ reg_loss = 0.01 * (self.weight.var() + self.bias.var())
233
+ if self.training:
234
+ reg_loss = reg_loss + 0.001 * (1.0 / effective_range.clamp_min(1e-3)).mean()
235
+
236
+ return transformed, reg_loss
237
+
238
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
239
+ if not self.config.learn_inverse_preprocessing:
240
+ return x, torch.tensor(0.0, device=x.device)
241
+
242
+ original_shape = x.shape
243
+ if x.dim() == 3:
244
+ x = x.view(-1, x.size(-1))
245
+
246
+ # Reverse affine
247
+ x = (x - self.bias) / (self.weight + 1e-8)
248
+ # Reverse MinMax using running stats
249
+ x = x * self.running_range.unsqueeze(0) + self.running_min.unsqueeze(0)
250
+
251
+ if len(original_shape) == 3:
252
+ x = x.view(original_shape)
253
+
254
+ return x, torch.tensor(0.0, device=x.device)
255
+
256
+
257
+ class LearnableRobustScaler(nn.Module):
258
+ """Learnable Robust scaler using median and IQR with learnable adjustments.
259
+
260
+ Normalizes as (x - median) / IQR with learnable adjustments and an affine head.
261
+ Supports 2D (B, F) and 3D (B, T, F) inputs.
262
+ """
263
+
264
+ def __init__(self, config: AutoencoderConfig):
265
+ super().__init__()
266
+ self.config = config
267
+ input_dim = config.input_dim
268
+ hidden_dim = config.preprocessing_hidden_dim
269
+
270
+ self.median_estimator = nn.Sequential(
271
+ nn.Linear(input_dim, hidden_dim),
272
+ nn.ReLU(),
273
+ nn.Linear(hidden_dim, hidden_dim),
274
+ nn.ReLU(),
275
+ nn.Linear(hidden_dim, input_dim),
276
+ )
277
+ self.iqr_estimator = nn.Sequential(
278
+ nn.Linear(input_dim, hidden_dim),
279
+ nn.ReLU(),
280
+ nn.Linear(hidden_dim, hidden_dim),
281
+ nn.ReLU(),
282
+ nn.Linear(hidden_dim, input_dim),
283
+ nn.Softplus(), # Ensure positive IQR adjustment
284
+ )
285
+
286
+ self.weight = nn.Parameter(torch.ones(input_dim))
287
+ self.bias = nn.Parameter(torch.zeros(input_dim))
288
+
289
+ self.register_buffer("running_median", torch.zeros(input_dim))
290
+ self.register_buffer("running_iqr", torch.ones(input_dim))
291
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
292
+
293
+ self.momentum = 0.1
294
+
295
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
296
+ if inverse:
297
+ return self._inverse_transform(x)
298
+
299
+ original_shape = x.shape
300
+ if x.dim() == 3:
301
+ x = x.view(-1, x.size(-1))
302
+
303
+ eps = 1e-8
304
+ if self.training:
305
+ qs = torch.quantile(x, torch.tensor([0.25, 0.5, 0.75], device=x.device), dim=0)
306
+ q25, med, q75 = qs[0:1, :], qs[1:2, :], qs[2:3, :]
307
+ iqr = (q75 - q25).clamp_min(eps)
308
+
309
+ learned_med_adj = self.median_estimator(med)
310
+ learned_iqr_adj = self.iqr_estimator(iqr)
311
+
312
+ effective_median = med + learned_med_adj
313
+ effective_iqr = iqr + learned_iqr_adj + eps
314
+
315
+ with torch.no_grad():
316
+ self.num_batches_tracked += 1
317
+ if self.num_batches_tracked == 1:
318
+ self.running_median.copy_(med.squeeze())
319
+ self.running_iqr.copy_(iqr.squeeze())
320
+ else:
321
+ self.running_median.mul_(1 - self.momentum).add_(med.squeeze(), alpha=self.momentum)
322
+ self.running_iqr.mul_(1 - self.momentum).add_(iqr.squeeze(), alpha=self.momentum)
323
+ else:
324
+ effective_median = self.running_median.unsqueeze(0)
325
+ effective_iqr = self.running_iqr.unsqueeze(0)
326
+
327
+ normalized = (x - effective_median) / effective_iqr
328
+ transformed = normalized * self.weight + self.bias
329
+
330
+ if len(original_shape) == 3:
331
+ transformed = transformed.view(original_shape)
332
+
333
+ reg_loss = 0.01 * (self.weight.var() + self.bias.var())
334
+ if self.training:
335
+ reg_loss = reg_loss + 0.001 * (1.0 / effective_iqr.clamp_min(1e-3)).mean()
336
+
337
+ return transformed, reg_loss
338
+
339
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
340
+ if not self.config.learn_inverse_preprocessing:
341
+ return x, torch.tensor(0.0, device=x.device)
342
+
343
+ original_shape = x.shape
344
+ if x.dim() == 3:
345
+ x = x.view(-1, x.size(-1))
346
+
347
+ x = (x - self.bias) / (self.weight + 1e-8)
348
+ x = x * self.running_iqr.unsqueeze(0) + self.running_median.unsqueeze(0)
349
+
350
+ if len(original_shape) == 3:
351
+ x = x.view(original_shape)
352
+
353
+ return x, torch.tensor(0.0, device=x.device)
354
+
355
+
356
+ class LearnableYeoJohnsonPreprocessor(nn.Module):
357
+ """Learnable Yeo-Johnson power transform with per-feature λ and affine head.
358
+
359
+ Applies Yeo-Johnson transform elementwise with learnable lambda per feature,
360
+ followed by standardization and a learnable affine transform. Supports 2D and 3D inputs.
361
+ """
362
+
363
+ def __init__(self, config: AutoencoderConfig):
364
+ super().__init__()
365
+ self.config = config
366
+ input_dim = config.input_dim
367
+
368
+ # Learnable lambda per feature (unconstrained). Initialize around 1.0
369
+ self.lmbda = nn.Parameter(torch.ones(input_dim))
370
+
371
+ # Learnable affine parameters after standardization
372
+ self.weight = nn.Parameter(torch.ones(input_dim))
373
+ self.bias = nn.Parameter(torch.zeros(input_dim))
374
+
375
+ # Running stats for transformed data
376
+ self.register_buffer("running_mean", torch.zeros(input_dim))
377
+ self.register_buffer("running_std", torch.ones(input_dim))
378
+ self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long))
379
+ self.momentum = 0.1
380
+
381
+ def _yeo_johnson(self, x: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
382
+ eps = 1e-6
383
+ lmbda = lmbda.unsqueeze(0) # broadcast over batch
384
+ pos = x >= 0
385
+ # For x >= 0
386
+ if_part = torch.where(
387
+ torch.abs(lmbda) > eps,
388
+ ((x + 1.0).clamp_min(eps) ** lmbda - 1.0) / lmbda,
389
+ torch.log((x + 1.0).clamp_min(eps)),
390
+ )
391
+ # For x < 0
392
+ two_minus_lambda = 2.0 - lmbda
393
+ else_part = torch.where(
394
+ torch.abs(two_minus_lambda) > eps,
395
+ -(((1.0 - x).clamp_min(eps)) ** two_minus_lambda - 1.0) / two_minus_lambda,
396
+ -torch.log((1.0 - x).clamp_min(eps)),
397
+ )
398
+ return torch.where(pos, if_part, else_part)
399
+
400
+ def _yeo_johnson_inverse(self, y: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor:
401
+ eps = 1e-6
402
+ lmbda = lmbda.unsqueeze(0)
403
+ pos = y >= 0
404
+ # Inverse for y >= 0
405
+ x_pos = torch.where(
406
+ torch.abs(lmbda) > eps,
407
+ (y * lmbda + 1.0).clamp_min(eps) ** (1.0 / lmbda) - 1.0,
408
+ torch.exp(y) - 1.0,
409
+ )
410
+ # Inverse for y < 0
411
+ two_minus_lambda = 2.0 - lmbda
412
+ x_neg = torch.where(
413
+ torch.abs(two_minus_lambda) > eps,
414
+ 1.0 - (1.0 - y * two_minus_lambda).clamp_min(eps) ** (1.0 / two_minus_lambda),
415
+ 1.0 - torch.exp(-y),
416
+ )
417
+ return torch.where(pos, x_pos, x_neg)
418
+
419
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
420
+ if inverse:
421
+ return self._inverse_transform(x)
422
+
423
+ orig_shape = x.shape
424
+ if x.dim() == 3:
425
+ x = x.view(-1, x.size(-1))
426
+
427
+ # Apply Yeo-Johnson
428
+ y = self._yeo_johnson(x, self.lmbda)
429
+
430
+ # Batch stats and running stats on transformed data
431
+ if self.training:
432
+ batch_mean = y.mean(dim=0, keepdim=True)
433
+ batch_std = y.std(dim=0, keepdim=True).clamp_min(1e-6)
434
+ with torch.no_grad():
435
+ self.num_batches_tracked += 1
436
+ if self.num_batches_tracked == 1:
437
+ self.running_mean.copy_(batch_mean.squeeze())
438
+ self.running_std.copy_(batch_std.squeeze())
439
+ else:
440
+ self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum)
441
+ self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum)
442
+ mean = batch_mean
443
+ std = batch_std
444
+ else:
445
+ mean = self.running_mean.unsqueeze(0)
446
+ std = self.running_std.unsqueeze(0)
447
+
448
+ y_norm = (y - mean) / std
449
+ out = y_norm * self.weight + self.bias
450
+
451
+ if len(orig_shape) == 3:
452
+ out = out.view(orig_shape)
453
+
454
+ # Regularize lambda to avoid extreme values; encourage identity around 1
455
+ reg = 0.001 * (self.lmbda - 1.0).pow(2).mean() + 0.01 * (self.weight.var() + self.bias.var())
456
+ return out, reg
457
+
458
+ def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
459
+ if not self.config.learn_inverse_preprocessing:
460
+ return x, torch.tensor(0.0, device=x.device)
461
+
462
+ orig_shape = x.shape
463
+ if x.dim() == 3:
464
+ x = x.view(-1, x.size(-1))
465
+
466
+ # Reverse affine and normalization with running stats
467
+ y = (x - self.bias) / (self.weight + 1e-8)
468
+ y = y * self.running_std.unsqueeze(0) + self.running_mean.unsqueeze(0)
469
+
470
+ # Inverse Yeo-Johnson
471
+ out = self._yeo_johnson_inverse(y, self.lmbda)
472
+
473
+ if len(orig_shape) == 3:
474
+ out = out.view(orig_shape)
475
+
476
+ return out, torch.tensor(0.0, device=x.device)
477
+
478
+ class CouplingLayer(nn.Module):
479
+ """Coupling layer for normalizing flows."""
480
+
481
+ def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"):
482
+ super().__init__()
483
+ self.input_dim = input_dim
484
+ self.hidden_dim = hidden_dim
485
+
486
+ # Create mask for coupling
487
+ if mask_type == "alternating":
488
+ self.register_buffer('mask', torch.arange(input_dim) % 2)
489
+ elif mask_type == "half":
490
+ mask = torch.zeros(input_dim)
491
+ mask[:input_dim // 2] = 1
492
+ self.register_buffer('mask', mask)
493
+ else:
494
+ raise ValueError(f"Unknown mask type: {mask_type}")
495
+
496
+ # Scale and translation networks
497
+ masked_dim = int(self.mask.sum().item())
498
+ unmasked_dim = input_dim - masked_dim
499
+
500
+ self.scale_net = nn.Sequential(
501
+ nn.Linear(masked_dim, hidden_dim),
502
+ nn.ReLU(),
503
+ nn.Linear(hidden_dim, hidden_dim),
504
+ nn.ReLU(),
505
+ nn.Linear(hidden_dim, unmasked_dim),
506
+ nn.Tanh() # Bounded output for stability
507
+ )
508
+
509
+ self.translate_net = nn.Sequential(
510
+ nn.Linear(masked_dim, hidden_dim),
511
+ nn.ReLU(),
512
+ nn.Linear(hidden_dim, hidden_dim),
513
+ nn.ReLU(),
514
+ nn.Linear(hidden_dim, unmasked_dim)
515
+ )
516
+
517
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
518
+ """
519
+ Forward pass through coupling layer.
520
+
521
+ Args:
522
+ x: Input tensor
523
+ inverse: Whether to apply inverse transformation
524
+
525
+ Returns:
526
+ Tuple of (transformed_tensor, log_determinant)
527
+ """
528
+ mask = self.mask.bool()
529
+ x_masked = x[:, mask]
530
+ x_unmasked = x[:, ~mask]
531
+
532
+ # Compute scale and translation
533
+ s = self.scale_net(x_masked)
534
+ t = self.translate_net(x_masked)
535
+
536
+ if not inverse:
537
+ # Forward transformation
538
+ y_unmasked = x_unmasked * torch.exp(s) + t
539
+ log_det = s.sum(dim=1)
540
+ else:
541
+ # Inverse transformation
542
+ y_unmasked = (x_unmasked - t) * torch.exp(-s)
543
+ log_det = -s.sum(dim=1)
544
+
545
+ # Reconstruct output
546
+ y = torch.zeros_like(x)
547
+ y[:, mask] = x_masked
548
+ y[:, ~mask] = y_unmasked
549
+
550
+ return y, log_det
551
+
552
+
553
+ class NormalizingFlowPreprocessor(nn.Module):
554
+ """Normalizing flow for learnable data preprocessing."""
555
+
556
+ def __init__(self, config: AutoencoderConfig):
557
+ super().__init__()
558
+ self.config = config
559
+ input_dim = config.input_dim
560
+ hidden_dim = config.preprocessing_hidden_dim
561
+ num_layers = config.flow_coupling_layers
562
+
563
+ # Create coupling layers with alternating masks
564
+ self.layers = nn.ModuleList()
565
+ for i in range(num_layers):
566
+ mask_type = "alternating" if i % 2 == 0 else "half"
567
+ self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type))
568
+
569
+ # Optional: Add batch normalization between layers
570
+ if config.use_batch_norm:
571
+ self.batch_norms = nn.ModuleList([
572
+ nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1)
573
+ ])
574
+ else:
575
+ self.batch_norms = None
576
+
577
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
578
+ """
579
+ Forward pass through normalizing flow.
580
+
581
+ Args:
582
+ x: Input tensor (2D or 3D)
583
+ inverse: Whether to apply inverse transformation
584
+
585
+ Returns:
586
+ Tuple of (transformed_tensor, total_log_determinant)
587
+ """
588
+ # Handle both 2D and 3D tensors
589
+ original_shape = x.shape
590
+ if x.dim() == 3:
591
+ # Reshape (batch, seq, features) -> (batch*seq, features)
592
+ x = x.view(-1, x.size(-1))
593
+
594
+ log_det_total = torch.zeros(x.size(0), device=x.device)
595
+
596
+ if not inverse:
597
+ # Forward pass
598
+ for i, layer in enumerate(self.layers):
599
+ x, log_det = layer(x, inverse=False)
600
+ log_det_total += log_det
601
+
602
+ # Apply batch normalization (except for last layer)
603
+ if self.batch_norms and i < len(self.layers) - 1:
604
+ x = self.batch_norms[i](x)
605
+ else:
606
+ # Inverse pass
607
+ for i, layer in enumerate(reversed(self.layers)):
608
+ # Reverse batch normalization (except for first layer in reverse)
609
+ if self.batch_norms and i > 0:
610
+ # Note: This is approximate inverse of batch norm
611
+ bn_idx = len(self.layers) - 1 - i
612
+ x = self.batch_norms[bn_idx](x)
613
+
614
+ x, log_det = layer(x, inverse=True)
615
+ log_det_total += log_det
616
+
617
+ # Reshape back to original shape if needed
618
+ if len(original_shape) == 3:
619
+ x = x.view(original_shape)
620
+
621
+ # Convert log determinant to regularization loss
622
+ # Encourage the flow to preserve information (log_det close to 0)
623
+ reg_loss = 0.01 * log_det_total.abs().mean()
624
+
625
+ return x, reg_loss
626
+
627
+
628
+ class LearnablePreprocessor(nn.Module):
629
+ """Unified interface for learnable preprocessing methods."""
630
+
631
+ def __init__(self, config: AutoencoderConfig):
632
+ super().__init__()
633
+ self.config = config
634
+
635
+ if not config.has_preprocessing:
636
+ self.preprocessor = nn.Identity()
637
+ elif config.is_neural_scaler:
638
+ self.preprocessor = NeuralScaler(config)
639
+ elif config.is_normalizing_flow:
640
+ self.preprocessor = NormalizingFlowPreprocessor(config)
641
+ elif getattr(config, "is_minmax_scaler", False):
642
+ self.preprocessor = LearnableMinMaxScaler(config)
643
+ elif getattr(config, "is_robust_scaler", False):
644
+ self.preprocessor = LearnableRobustScaler(config)
645
+ elif getattr(config, "is_yeo_johnson", False):
646
+ self.preprocessor = LearnableYeoJohnsonPreprocessor(config)
647
+ else:
648
+ raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}")
649
+
650
+ def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
651
+ """
652
+ Apply preprocessing transformation.
653
+
654
+ Args:
655
+ x: Input tensor
656
+ inverse: Whether to apply inverse transformation
657
+
658
+ Returns:
659
+ Tuple of (transformed_tensor, regularization_loss)
660
+ """
661
+ if isinstance(self.preprocessor, nn.Identity):
662
+ return x, torch.tensor(0.0, device=x.device)
663
+
664
+ return self.preprocessor(x, inverse=inverse)
665
+
666
+
667
+ @dataclass
668
+ class AutoencoderOutput(ModelOutput):
669
+ """
670
+ Output type of AutoencoderModel.
671
+
672
+ Args:
673
+ last_hidden_state (torch.FloatTensor): The latent representation of the input.
674
+ reconstructed (torch.FloatTensor, optional): The reconstructed input.
675
+ hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
676
+ attentions (tuple(torch.FloatTensor), optional): Not used in basic autoencoder.
677
+ preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
678
+ """
679
+
680
+ last_hidden_state: torch.FloatTensor = None
681
+ reconstructed: Optional[torch.FloatTensor] = None
682
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
683
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
684
+ preprocessing_loss: Optional[torch.FloatTensor] = None
685
+
686
+
687
+ @dataclass
688
+ class AutoencoderForReconstructionOutput(ModelOutput):
689
+ """
690
+ Output type of AutoencoderForReconstruction.
691
+
692
+ Args:
693
+ loss (torch.FloatTensor, optional): The reconstruction loss.
694
+ reconstructed (torch.FloatTensor): The reconstructed input.
695
+ last_hidden_state (torch.FloatTensor): The latent representation.
696
+ hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers.
697
+ preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing.
698
+ """
699
+
700
+ loss: Optional[torch.FloatTensor] = None
701
+ reconstructed: torch.FloatTensor = None
702
+ last_hidden_state: torch.FloatTensor = None
703
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
704
+ preprocessing_loss: Optional[torch.FloatTensor] = None
705
+
706
+
707
+ class AutoencoderEncoder(nn.Module):
708
+ """Encoder part of the autoencoder."""
709
+
710
+ def __init__(self, config: AutoencoderConfig):
711
+ super().__init__()
712
+ self.config = config
713
+
714
+ # Build encoder layers
715
+ layers = []
716
+ input_dim = config.input_dim
717
+
718
+ for hidden_dim in config.hidden_dims:
719
+ layers.append(nn.Linear(input_dim, hidden_dim))
720
+
721
+ if config.use_batch_norm:
722
+ layers.append(nn.BatchNorm1d(hidden_dim))
723
+
724
+ layers.append(self._get_activation(config.activation))
725
+
726
+ if config.dropout_rate > 0:
727
+ layers.append(nn.Dropout(config.dropout_rate))
728
+
729
+ input_dim = hidden_dim
730
+
731
+ self.encoder = nn.Sequential(*layers)
732
+
733
+ # For variational autoencoders, we need separate layers for mean and log variance
734
+ if config.is_variational:
735
+ self.fc_mu = nn.Linear(input_dim, config.latent_dim)
736
+ self.fc_logvar = nn.Linear(input_dim, config.latent_dim)
737
+ else:
738
+ # Standard encoder output
739
+ self.fc_out = nn.Linear(input_dim, config.latent_dim)
740
+
741
+ def _get_activation(self, activation: str) -> nn.Module:
742
+ """Get activation function by name."""
743
+ activations = {
744
+ "relu": nn.ReLU(),
745
+ "tanh": nn.Tanh(),
746
+ "sigmoid": nn.Sigmoid(),
747
+ "leaky_relu": nn.LeakyReLU(),
748
+ "gelu": nn.GELU(),
749
+ "swish": nn.SiLU(),
750
+ "silu": nn.SiLU(),
751
+ "elu": nn.ELU(),
752
+ "prelu": nn.PReLU(),
753
+ "relu6": nn.ReLU6(),
754
+ "hardtanh": nn.Hardtanh(),
755
+ "hardsigmoid": nn.Hardsigmoid(),
756
+ "hardswish": nn.Hardswish(),
757
+ "mish": nn.Mish(),
758
+ "softplus": nn.Softplus(),
759
+ "softsign": nn.Softsign(),
760
+ "tanhshrink": nn.Tanhshrink(),
761
+ "threshold": nn.Threshold(threshold=0.1, value=0),
762
+ }
763
+ return activations[activation]
764
+
765
+ def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
766
+ """Forward pass through encoder."""
767
+ # Add noise for denoising autoencoders
768
+ if self.config.is_denoising and self.training:
769
+ noise = torch.randn_like(x) * self.config.noise_factor
770
+ x = x + noise
771
+
772
+ encoded = self.encoder(x)
773
+
774
+ if self.config.is_variational:
775
+ # Variational autoencoder: return mean, log variance, and sampled latent
776
+ mu = self.fc_mu(encoded)
777
+ logvar = self.fc_logvar(encoded)
778
+
779
+ # Reparameterization trick
780
+ if self.training:
781
+ std = torch.exp(0.5 * logvar)
782
+ eps = torch.randn_like(std)
783
+ z = mu + eps * std
784
+ else:
785
+ z = mu # Use mean during inference
786
+
787
+ return z, mu, logvar
788
+ else:
789
+ # Standard autoencoder
790
+ latent = self.fc_out(encoded)
791
+
792
+ # Add sparsity constraint for sparse autoencoders
793
+ if self.config.is_sparse and self.training:
794
+ # Apply L1 regularization to encourage sparsity
795
+ latent = F.relu(latent) # Ensure non-negative activations
796
+
797
+ return latent
798
+
799
+
800
+ class AutoencoderDecoder(nn.Module):
801
+ """Decoder part of the autoencoder."""
802
+
803
+ def __init__(self, config: AutoencoderConfig):
804
+ super().__init__()
805
+ self.config = config
806
+
807
+ # Build decoder layers (reverse of encoder)
808
+ layers = []
809
+ input_dim = config.latent_dim
810
+ decoder_dims = config.decoder_dims + [config.input_dim]
811
+
812
+ for i, hidden_dim in enumerate(decoder_dims):
813
+ layers.append(nn.Linear(input_dim, hidden_dim))
814
+
815
+ # Don't add batch norm, activation, or dropout to the final layer
816
+ if i < len(decoder_dims) - 1:
817
+ if config.use_batch_norm:
818
+ layers.append(nn.BatchNorm1d(hidden_dim))
819
+
820
+ layers.append(self._get_activation(config.activation))
821
+
822
+ if config.dropout_rate > 0:
823
+ layers.append(nn.Dropout(config.dropout_rate))
824
+ else:
825
+ # Final layer - add appropriate activation based on reconstruction loss
826
+ if config.reconstruction_loss == "bce":
827
+ layers.append(nn.Sigmoid())
828
+
829
+ input_dim = hidden_dim
830
+
831
+ self.decoder = nn.Sequential(*layers)
832
+
833
+ def _get_activation(self, activation: str) -> nn.Module:
834
+ """Get activation function by name."""
835
+ activations = {
836
+ "relu": nn.ReLU(),
837
+ "tanh": nn.Tanh(),
838
+ "sigmoid": nn.Sigmoid(),
839
+ "leaky_relu": nn.LeakyReLU(),
840
+ "gelu": nn.GELU(),
841
+ "swish": nn.SiLU(),
842
+ "silu": nn.SiLU(),
843
+ "elu": nn.ELU(),
844
+ "prelu": nn.PReLU(),
845
+ "relu6": nn.ReLU6(),
846
+ "hardtanh": nn.Hardtanh(),
847
+ "hardsigmoid": nn.Hardsigmoid(),
848
+ "hardswish": nn.Hardswish(),
849
+ "mish": nn.Mish(),
850
+ "softplus": nn.Softplus(),
851
+ "softsign": nn.Softsign(),
852
+ "tanhshrink": nn.Tanhshrink(),
853
+ "threshold": nn.Threshold(threshold=0.1, value=0),
854
+ }
855
+ return activations[activation]
856
+
857
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
858
+ """Forward pass through decoder."""
859
+ return self.decoder(x)
860
+
861
+
862
+ class RecurrentEncoder(nn.Module):
863
+ """Recurrent encoder for sequence data."""
864
+
865
+ def __init__(self, config: AutoencoderConfig):
866
+ super().__init__()
867
+ self.config = config
868
+
869
+ # Get RNN class
870
+ if config.rnn_type == "lstm":
871
+ rnn_class = nn.LSTM
872
+ elif config.rnn_type == "gru":
873
+ rnn_class = nn.GRU
874
+ elif config.rnn_type == "rnn":
875
+ rnn_class = nn.RNN
876
+ else:
877
+ raise ValueError(f"Unknown RNN type: {config.rnn_type}")
878
+
879
+ # Create RNN layers
880
+ self.rnn = rnn_class(
881
+ input_size=config.input_dim,
882
+ hidden_size=config.latent_dim,
883
+ num_layers=config.num_layers,
884
+ batch_first=True,
885
+ dropout=config.dropout_rate if config.num_layers > 1 else 0,
886
+ bidirectional=config.bidirectional
887
+ )
888
+
889
+ # Projection layer for bidirectional RNN
890
+ if config.bidirectional:
891
+ self.projection = nn.Linear(config.latent_dim * 2, config.latent_dim)
892
+ else:
893
+ self.projection = None
894
+
895
+ # Batch normalization
896
+ if config.use_batch_norm:
897
+ self.batch_norm = nn.BatchNorm1d(config.latent_dim)
898
+ else:
899
+ self.batch_norm = None
900
+
901
+ # Dropout
902
+ if config.dropout_rate > 0:
903
+ self.dropout = nn.Dropout(config.dropout_rate)
904
+ else:
905
+ self.dropout = None
906
+
907
+ def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
908
+ """
909
+ Forward pass through recurrent encoder.
910
+
911
+ Args:
912
+ x: Input tensor of shape (batch_size, seq_len, input_dim)
913
+ lengths: Sequence lengths for packed sequences (optional)
914
+
915
+ Returns:
916
+ Encoded representation or tuple for VAE
917
+ """
918
+ batch_size, seq_len, _ = x.shape
919
+
920
+ # Add noise for denoising autoencoders
921
+ if self.config.is_denoising and self.training:
922
+ noise = torch.randn_like(x) * self.config.noise_factor
923
+ x = x + noise
924
+
925
+ # Pack sequences if lengths provided
926
+ if lengths is not None:
927
+ x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
928
+
929
+ # RNN forward pass
930
+ if self.config.rnn_type == "lstm":
931
+ output, (hidden, cell) = self.rnn(x)
932
+ else:
933
+ output, hidden = self.rnn(x)
934
+ cell = None
935
+
936
+ # Unpack if necessary
937
+ if lengths is not None:
938
+ output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
939
+
940
+ # Use last hidden state as encoding
941
+ if self.config.bidirectional:
942
+ # Concatenate forward and backward hidden states
943
+ hidden = hidden.view(self.config.num_layers, 2, batch_size, self.config.latent_dim)
944
+ hidden = hidden[-1] # Take last layer
945
+ hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1) # Concatenate directions
946
+
947
+ # Project to latent dimension
948
+ if self.projection:
949
+ hidden = self.projection(hidden)
950
+ else:
951
+ hidden = hidden[-1] # Take last layer
952
+
953
+ # Apply batch normalization
954
+ if self.batch_norm:
955
+ hidden = self.batch_norm(hidden)
956
+
957
+ # Apply dropout
958
+ if self.dropout and self.training:
959
+ hidden = self.dropout(hidden)
960
+
961
+ # Handle variational encoding
962
+ if self.config.is_variational:
963
+ # Split hidden into mean and log variance
964
+ mu = hidden[:, :self.config.latent_dim // 2]
965
+ logvar = hidden[:, self.config.latent_dim // 2:]
966
+
967
+ # Reparameterization trick
968
+ if self.training:
969
+ std = torch.exp(0.5 * logvar)
970
+ eps = torch.randn_like(std)
971
+ z = mu + eps * std
972
+ else:
973
+ z = mu
974
+
975
+ return z, mu, logvar
976
+ else:
977
+ return hidden
978
+
979
+
980
+ class RecurrentDecoder(nn.Module):
981
+ """Recurrent decoder for sequence data."""
982
+
983
+ def __init__(self, config: AutoencoderConfig):
984
+ super().__init__()
985
+ self.config = config
986
+
987
+ # Get RNN class
988
+ if config.rnn_type == "lstm":
989
+ rnn_class = nn.LSTM
990
+ elif config.rnn_type == "gru":
991
+ rnn_class = nn.GRU
992
+ elif config.rnn_type == "rnn":
993
+ rnn_class = nn.RNN
994
+ else:
995
+ raise ValueError(f"Unknown RNN type: {config.rnn_type}")
996
+
997
+ # Create RNN layers
998
+ self.rnn = rnn_class(
999
+ input_size=config.latent_dim,
1000
+ hidden_size=config.latent_dim,
1001
+ num_layers=config.num_layers,
1002
+ batch_first=True,
1003
+ dropout=config.dropout_rate if config.num_layers > 1 else 0,
1004
+ bidirectional=False # Decoder is always unidirectional
1005
+ )
1006
+
1007
+ # Output projection
1008
+ self.output_projection = nn.Linear(config.latent_dim, config.input_dim)
1009
+
1010
+ # Batch normalization
1011
+ if config.use_batch_norm:
1012
+ self.batch_norm = nn.BatchNorm1d(config.latent_dim)
1013
+ else:
1014
+ self.batch_norm = None
1015
+
1016
+ # Dropout
1017
+ if config.dropout_rate > 0:
1018
+ self.dropout = nn.Dropout(config.dropout_rate)
1019
+ else:
1020
+ self.dropout = None
1021
+
1022
+ def forward(self, z: torch.Tensor, target_length: int, target_sequence: Optional[torch.Tensor] = None) -> torch.Tensor:
1023
+ """
1024
+ Forward pass through recurrent decoder.
1025
+
1026
+ Args:
1027
+ z: Latent representation of shape (batch_size, latent_dim)
1028
+ target_length: Length of sequence to generate
1029
+ target_sequence: Target sequence for teacher forcing (optional)
1030
+
1031
+ Returns:
1032
+ Decoded sequence of shape (batch_size, seq_len, input_dim)
1033
+ """
1034
+ batch_size = z.size(0)
1035
+ device = z.device
1036
+
1037
+ # Initialize hidden state with latent representation
1038
+ if self.config.rnn_type == "lstm":
1039
+ h_0 = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
1040
+ c_0 = torch.zeros_like(h_0)
1041
+ hidden = (h_0, c_0)
1042
+ else:
1043
+ hidden = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1)
1044
+
1045
+ outputs = []
1046
+
1047
+ # Initialize input (can be learned or zero)
1048
+ current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)
1049
+
1050
+ for t in range(target_length):
1051
+ # Teacher forcing decision
1052
+ use_teacher_forcing = (target_sequence is not None and
1053
+ self.training and
1054
+ random.random() < self.config.teacher_forcing_ratio)
1055
+
1056
+ if use_teacher_forcing and t > 0:
1057
+ # Use previous target as input
1058
+ current_input = target_sequence[:, t-1:t, :]
1059
+ # Project to latent dimension if needed
1060
+ if current_input.size(-1) != self.config.latent_dim:
1061
+ current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)
1062
+
1063
+ # RNN forward step
1064
+ if self.config.rnn_type == "lstm":
1065
+ output, hidden = self.rnn(current_input, hidden)
1066
+ else:
1067
+ output, hidden = self.rnn(current_input, hidden)
1068
+
1069
+ # Apply batch normalization and dropout
1070
+ output_flat = output.squeeze(1) # Remove sequence dimension
1071
+
1072
+ if self.batch_norm:
1073
+ output_flat = self.batch_norm(output_flat)
1074
+
1075
+ if self.dropout and self.training:
1076
+ output_flat = self.dropout(output_flat)
1077
+
1078
+ # Project to output dimension
1079
+ step_output = self.output_projection(output_flat)
1080
+ outputs.append(step_output.unsqueeze(1))
1081
+
1082
+ # Use output as next input (for non-teacher forcing)
1083
+ if not use_teacher_forcing:
1084
+ # Project output back to latent dimension for next step
1085
+ current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device)
1086
+
1087
+ # Concatenate all outputs
1088
+ return torch.cat(outputs, dim=1)
1089
+
1090
+
1091
+ class AutoencoderModel(PreTrainedModel):
1092
+ """
1093
+ The bare Autoencoder Model transformer outputting raw hidden-states without any specific head on top.
1094
+
1095
+ This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the
1096
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1097
+ etc.)
1098
+
1099
+ This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the
1100
+ PyTorch documentation for all matter related to general usage and behavior.
1101
+ """
1102
+
1103
+ config_class = AutoencoderConfig
1104
+ base_model_prefix = "autoencoder"
1105
+ supports_gradient_checkpointing = False
1106
+
1107
+ def __init__(self, config: AutoencoderConfig):
1108
+ super().__init__(config)
1109
+ self.config = config
1110
+
1111
+ # Initialize learnable preprocessing
1112
+ if config.has_preprocessing:
1113
+ self.preprocessor = LearnablePreprocessor(config)
1114
+ else:
1115
+ self.preprocessor = None
1116
+
1117
+ # Initialize encoder and decoder based on type
1118
+ if config.is_recurrent:
1119
+ self.encoder = RecurrentEncoder(config)
1120
+ self.decoder = RecurrentDecoder(config)
1121
+ else:
1122
+ self.encoder = AutoencoderEncoder(config)
1123
+ self.decoder = AutoencoderDecoder(config)
1124
+
1125
+ # Tie weights if specified
1126
+ if config.tie_weights:
1127
+ self._tie_weights()
1128
+
1129
+ # Initialize weights
1130
+ self.post_init()
1131
+
1132
+ def _tie_weights(self):
1133
+ """Tie encoder and decoder weights (transpose relationship)."""
1134
+ # This is a simplified weight tying - in practice, you might want more sophisticated tying
1135
+ pass
1136
+
1137
+ def get_input_embeddings(self):
1138
+ """Get input embeddings (not applicable for basic autoencoder)."""
1139
+ return None
1140
+
1141
+ def set_input_embeddings(self, value):
1142
+ """Set input embeddings (not applicable for basic autoencoder)."""
1143
+ pass
1144
+
+     def forward(
+         self,
+         input_values: torch.Tensor,
+         sequence_lengths: Optional[torch.Tensor] = None,
+         target_length: Optional[int] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], AutoencoderOutput]:
+         """
+         Forward pass through the autoencoder.
+ 
+         Args:
+             input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
+                 - Standard: (batch_size, input_dim)
+                 - Recurrent: (batch_size, seq_len, input_dim)
+             sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
+             target_length (int, optional): Target sequence length for recurrent decoder.
+             output_hidden_states (bool, optional): Whether to return hidden states.
+             return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+ 
+         Returns:
+             AutoencoderOutput or tuple: The model outputs.
+         """
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         # Apply learnable preprocessing
+         preprocessing_loss = torch.tensor(0.0, device=input_values.device)
+         if self.preprocessor is not None:
+             input_values, preprocessing_loss = self.preprocessor(input_values, inverse=False)
+ 
+         # Handle different autoencoder types
+         if self.config.is_recurrent:
+             # Recurrent autoencoder
+             if sequence_lengths is not None:
+                 encoder_output = self.encoder(input_values, sequence_lengths)
+             else:
+                 encoder_output = self.encoder(input_values)
+ 
+             if self.config.is_variational:
+                 # Cache mu/logvar so AutoencoderForReconstruction can add the KL term to the loss
+                 latent, mu, logvar = encoder_output
+                 self._mu = mu
+                 self._logvar = logvar
+             else:
+                 latent = encoder_output
+                 self._mu = None
+                 self._logvar = None
+ 
+             # Determine target length for decoder
+             if target_length is None:
+                 if self.config.sequence_length is not None:
+                     target_length = self.config.sequence_length
+                 else:
+                     target_length = input_values.size(1)  # Use input sequence length
+ 
+             # Decode latent back to sequence space (teacher forcing inputs only during training)
+             reconstructed = self.decoder(latent, target_length, input_values if self.training else None)
+         else:
+             # Standard autoencoder
+             encoder_output = self.encoder(input_values)
+ 
+             if self.config.is_variational:
+                 # Cache mu/logvar so AutoencoderForReconstruction can add the KL term to the loss
+                 latent, mu, logvar = encoder_output
+                 self._mu = mu
+                 self._logvar = logvar
+             else:
+                 latent = encoder_output
+                 self._mu = None
+                 self._logvar = None
+ 
+             # Decode latent back to input space
+             reconstructed = self.decoder(latent)
+ 
+         # Apply inverse preprocessing to the reconstruction
+         if self.preprocessor is not None and self.config.learn_inverse_preprocessing:
+             reconstructed, inverse_loss = self.preprocessor(reconstructed, inverse=True)
+             preprocessing_loss += inverse_loss
+ 
+         hidden_states = None
+         if output_hidden_states:
+             if self.config.is_variational:
+                 hidden_states = (latent, mu, logvar)
+             else:
+                 hidden_states = (latent,)
+ 
+         if not return_dict:
+             return tuple(v for v in [latent, reconstructed, hidden_states] if v is not None)
+ 
+         return AutoencoderOutput(
+             last_hidden_state=latent,
+             reconstructed=reconstructed,
+             hidden_states=hidden_states,
+             preprocessing_loss=preprocessing_loss,
+         )
+ 
+ 
+ class AutoencoderForReconstruction(PreTrainedModel):
+     """
+     Autoencoder Model with a reconstruction head on top for reconstruction tasks.
+ 
+     This model inherits from PreTrainedModel and adds a reconstruction loss calculation.
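+ 
+     Example (illustrative sketch only; tensor sizes are arbitrary):
+ 
+         >>> import torch
+         >>> config = AutoencoderConfig(input_dim=20, hidden_dims=[64, 32], latent_dim=16, reconstruction_loss="mse")
+         >>> model = AutoencoderForReconstruction(config)
+         >>> batch = torch.randn(8, 20)                  # (batch_size, input_dim)
+         >>> outputs = model(input_values=batch)         # labels default to input_values
+         >>> loss = outputs.loss                         # scalar training loss
+         >>> recon = outputs.reconstructed               # (batch_size, input_dim)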
+     """
+ 
+     config_class = AutoencoderConfig
+     base_model_prefix = "autoencoder"
+ 
+     def __init__(self, config: AutoencoderConfig):
+         super().__init__(config)
+         self.config = config
+ 
+         # Initialize the base autoencoder model
+         self.autoencoder = AutoencoderModel(config)
+ 
+         # Initialize weights
+         self.post_init()
+ 
+     def get_input_embeddings(self):
+         """Get input embeddings."""
+         return self.autoencoder.get_input_embeddings()
+ 
+     def set_input_embeddings(self, value):
+         """Set input embeddings."""
+         self.autoencoder.set_input_embeddings(value)
+ 
+     def _compute_reconstruction_loss(
+         self,
+         reconstructed: torch.Tensor,
+         target: torch.Tensor
+     ) -> torch.Tensor:
+         """Compute reconstruction loss based on the configured loss type."""
+         if self.config.reconstruction_loss == "mse":
+             return F.mse_loss(reconstructed, target, reduction="mean")
+         elif self.config.reconstruction_loss == "bce":
+             return F.binary_cross_entropy_with_logits(reconstructed, target, reduction="mean")
+         elif self.config.reconstruction_loss == "l1":
+             return F.l1_loss(reconstructed, target, reduction="mean")
+         elif self.config.reconstruction_loss == "huber":
+             return F.huber_loss(reconstructed, target, reduction="mean")
+         elif self.config.reconstruction_loss == "smooth_l1":
+             return F.smooth_l1_loss(reconstructed, target, reduction="mean")
+         elif self.config.reconstruction_loss == "kl_div":
+             return F.kl_div(F.log_softmax(reconstructed, dim=-1), F.softmax(target, dim=-1), reduction="mean")
+         elif self.config.reconstruction_loss == "cosine":
+             return 1 - F.cosine_similarity(reconstructed, target, dim=-1).mean()
+         elif self.config.reconstruction_loss == "focal":
+             return self._focal_loss(reconstructed, target)
+         elif self.config.reconstruction_loss == "dice":
+             return self._dice_loss(reconstructed, target)
+         elif self.config.reconstruction_loss == "tversky":
+             return self._tversky_loss(reconstructed, target)
+         elif self.config.reconstruction_loss == "ssim":
+             return self._ssim_loss(reconstructed, target)
+         elif self.config.reconstruction_loss == "perceptual":
+             return self._perceptual_loss(reconstructed, target)
+         else:
+             raise ValueError(f"Unknown reconstruction loss: {self.config.reconstruction_loss}")
+ 
+     def _focal_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, gamma: float = 2.0) -> torch.Tensor:
+         """Compute a focal-style loss that down-weights easy (low-error) elements."""
+         # Regression variant: the element-wise MSE plays the role of the cross-entropy term
+         ce_loss = F.mse_loss(pred, target, reduction="none")
+         pt = torch.exp(-ce_loss)
+         focal_loss = alpha * (1 - pt) ** gamma * ce_loss
+         return focal_loss.mean()
+ 
+     def _dice_loss(self, pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor:
+         """Compute Dice loss for segmentation-like tasks."""
+         # Dice = (2 * intersection + smooth) / (|pred| + |target| + smooth)
+         pred_flat = pred.view(-1)
+         target_flat = target.view(-1)
+         intersection = (pred_flat * target_flat).sum()
+         dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth)
+         return 1 - dice
+ 
+     def _tversky_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7, beta: float = 0.3, smooth: float = 1e-6) -> torch.Tensor:
+         """Compute Tversky loss, a generalization of Dice loss that weights false negatives and false positives separately."""
+         pred_flat = pred.view(-1)
+         target_flat = target.view(-1)
+         true_pos = (pred_flat * target_flat).sum()
+         false_neg = (target_flat * (1 - pred_flat)).sum()
+         false_pos = ((1 - target_flat) * pred_flat).sum()
+         tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth)
+         return 1 - tversky
+ 
+     def _ssim_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """Compute an SSIM-based loss (simplified version)."""
+         # Simplified SSIM for 1D feature vectors; c1 and c2 are fixed stabilizing constants
+         # rather than the usual (0.01 * L)**2 and (0.03 * L)**2 used for image SSIM.
+         mu1 = pred.mean(dim=-1, keepdim=True)
+         mu2 = target.mean(dim=-1, keepdim=True)
+         sigma1_sq = ((pred - mu1) ** 2).mean(dim=-1, keepdim=True)
+         sigma2_sq = ((target - mu2) ** 2).mean(dim=-1, keepdim=True)
+         sigma12 = ((pred - mu1) * (target - mu2)).mean(dim=-1, keepdim=True)
+ 
+         c1, c2 = 0.01, 0.03
+         ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ((mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2))
+         return 1 - ssim.mean()
+ 
+     def _perceptual_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+         """Compute a perceptual-style loss (simplified version using feature differences)."""
+         # For simplicity, use L2 loss on L2-normalized features
+         pred_norm = F.normalize(pred, p=2, dim=-1)
+         target_norm = F.normalize(target, p=2, dim=-1)
+         return F.mse_loss(pred_norm, target_norm)
+ 
+     def forward(
+         self,
+         input_values: torch.Tensor,
+         labels: Optional[torch.Tensor] = None,
+         sequence_lengths: Optional[torch.Tensor] = None,
+         target_length: Optional[int] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], AutoencoderForReconstructionOutput]:
+         """
+         Forward pass with reconstruction loss calculation.
+ 
+         Args:
+             input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type:
+                 - Standard: (batch_size, input_dim)
+                 - Recurrent: (batch_size, seq_len, input_dim)
+             labels (torch.Tensor, optional): Target tensor for reconstruction. If None, uses input_values.
+             sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE.
+             target_length (int, optional): Target sequence length for recurrent decoder.
+             output_hidden_states (bool, optional): Whether to return hidden states.
+             return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple.
+ 
+         Returns:
+             AutoencoderForReconstructionOutput or tuple: The model outputs including loss.
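+ 
+         Example (illustrative sketch only; assumes a model built as in the class docstring and arbitrary tensors):
+ 
+             >>> x_noisy = torch.randn(8, 20)
+             >>> x_clean = torch.randn(8, 20)
+             >>> out = model(input_values=x_noisy)                    # labels default to input_values
+             >>> out = model(input_values=x_noisy, labels=x_clean)    # e.g. a denoising-style target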
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+ 
+         # If no labels are provided, use the input as target (standard autoencoder)
+         if labels is None:
+             labels = input_values
+ 
+         # Forward pass through the base autoencoder
+         outputs = self.autoencoder(
+             input_values=input_values,
+             sequence_lengths=sequence_lengths,
+             target_length=target_length,
+             output_hidden_states=output_hidden_states,
+             return_dict=True,
+         )
+ 
+         reconstructed = outputs.reconstructed
+         latent = outputs.last_hidden_state
+         hidden_states = outputs.hidden_states
+ 
+         # Compute reconstruction loss
+         recon_loss = self._compute_reconstruction_loss(reconstructed, labels)
+ 
+         # Add regularization losses based on autoencoder type
+         total_loss = recon_loss
+ 
+         # Add preprocessing loss if available
+         if hasattr(outputs, 'preprocessing_loss') and outputs.preprocessing_loss is not None:
+             total_loss += outputs.preprocessing_loss
+ 
+         if self.config.is_variational and hasattr(self.autoencoder, '_mu') and self.autoencoder._mu is not None:
+             # Analytic KL divergence between N(mu, sigma^2) and N(0, 1):
+             # KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+             kl_loss = -0.5 * torch.sum(1 + self.autoencoder._logvar - self.autoencoder._mu.pow(2) - self.autoencoder._logvar.exp())
+             kl_loss = kl_loss / (self.autoencoder._mu.size(0) * self.autoencoder._mu.size(1))  # Normalize by batch size and latent dim
+             # Build on total_loss so any preprocessing term already accumulated is kept
+             total_loss = total_loss + self.config.beta * kl_loss
+ 
+         elif self.config.is_sparse:
+             # Sparsity loss for sparse autoencoders (L1 penalty on the latent code)
+             latent = outputs.last_hidden_state
+             sparsity_loss = torch.mean(torch.abs(latent))  # L1 sparsity
+             total_loss = total_loss + 0.1 * sparsity_loss  # Sparsity weight
+ 
+         elif self.config.is_contractive:
+             # Contractive loss - penalize large gradients of the hidden representation w.r.t. the input.
+             # Note: latent.grad is generally not populated during the forward pass, so this term is
+             # effectively skipped unless gradients for the latent were computed beforehand; a full
+             # contractive penalty would require computing the Jacobian of the latent w.r.t. the input.
+             latent = outputs.last_hidden_state
+             latent.retain_grad()
+             if latent.grad is not None:
+                 contractive_loss = torch.sum(latent.grad ** 2)
+                 total_loss = total_loss + 0.1 * contractive_loss
+ 
+         loss = total_loss
+ 
+         if not return_dict:
+             output = (reconstructed, latent)
+             if hidden_states is not None:
+                 output = output + (hidden_states,)
+             return ((loss,) + output) if loss is not None else output
+ 
+         return AutoencoderForReconstructionOutput(
+             loss=loss,
+             reconstructed=reconstructed,
+             last_hidden_state=latent,
+             hidden_states=hidden_states,
+             preprocessing_loss=outputs.preprocessing_loss if hasattr(outputs, 'preprocessing_loss') else None,
+         )
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:14e56e5ad0c4b49490b81fd03efb444c425ac02e5b4a9dc8cb26ecb1764b2c3d
+ size 5777