""" PyTorch Autoencoder model for Hugging Face Transformers. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, Tuple, Union, Dict, Any, List from dataclasses import dataclass import random from transformers import PreTrainedModel from transformers.modeling_outputs import BaseModelOutput from transformers.utils import ModelOutput from configuration_autoencoder import AutoencoderConfig class NeuralScaler(nn.Module): """Learnable alternative to StandardScaler using neural networks.""" def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config input_dim = config.input_dim hidden_dim = config.preprocessing_hidden_dim # Networks to learn data-dependent statistics self.mean_estimator = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim) ) self.std_estimator = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Softplus() # Ensure positive standard deviation ) # Learnable affine transformation parameters self.weight = nn.Parameter(torch.ones(input_dim)) self.bias = nn.Parameter(torch.zeros(input_dim)) # Running statistics for inference (like BatchNorm) self.register_buffer('running_mean', torch.zeros(input_dim)) self.register_buffer('running_std', torch.ones(input_dim)) self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) # Momentum for running statistics self.momentum = 0.1 def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass through neural scaler. Args: x: Input tensor (2D or 3D) inverse: Whether to apply inverse transformation Returns: Tuple of (transformed_tensor, regularization_loss) """ if inverse: return self._inverse_transform(x) # Handle both 2D and 3D tensors original_shape = x.shape if x.dim() == 3: # Reshape (batch, seq, features) -> (batch*seq, features) x = x.view(-1, x.size(-1)) if self.training: # Training mode: learn statistics from current batch batch_mean = x.mean(dim=0, keepdim=True) batch_std = x.std(dim=0, keepdim=True) # Learn data-dependent adjustments learned_mean_adj = self.mean_estimator(batch_mean) learned_std_adj = self.std_estimator(batch_std) # Combine batch statistics with learned adjustments effective_mean = batch_mean + learned_mean_adj effective_std = batch_std + learned_std_adj + 1e-8 # Update running statistics with torch.no_grad(): self.num_batches_tracked += 1 if self.num_batches_tracked == 1: self.running_mean.copy_(batch_mean.squeeze()) self.running_std.copy_(batch_std.squeeze()) else: self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum) self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum) else: # Inference mode: use running statistics effective_mean = self.running_mean.unsqueeze(0) effective_std = self.running_std.unsqueeze(0) + 1e-8 # Normalize normalized = (x - effective_mean) / effective_std # Apply learnable affine transformation transformed = normalized * self.weight + self.bias # Reshape back to original shape if needed if len(original_shape) == 3: transformed = transformed.view(original_shape) # Regularization loss to encourage meaningful learning reg_loss = 0.01 * (self.weight.var() + self.bias.var()) return transformed, reg_loss def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """Apply inverse transformation to get back original scale.""" if not self.config.learn_inverse_preprocessing: return x, torch.tensor(0.0, device=x.device) # Handle both 2D and 3D tensors original_shape = x.shape if x.dim() == 3: # Reshape (batch, seq, features) -> (batch*seq, features) x = x.view(-1, x.size(-1)) # Reverse affine transformation x = (x - self.bias) / (self.weight + 1e-8) # Reverse normalization using running statistics effective_mean = self.running_mean.unsqueeze(0) effective_std = self.running_std.unsqueeze(0) + 1e-8 x = x * effective_std + effective_mean # Reshape back to original shape if needed if len(original_shape) == 3: x = x.view(original_shape) return x, torch.tensor(0.0, device=x.device) class LearnableMinMaxScaler(nn.Module): """Learnable MinMax scaler that adapts bounds during training. Scales features to [0, 1] using batch min/range with learnable adjustments and a learnable affine transform. Supports 2D (B, F) and 3D (B, T, F) inputs. """ def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config input_dim = config.input_dim hidden_dim = config.preprocessing_hidden_dim # Networks to learn adjustments to batch min and range self.min_estimator = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), ) self.range_estimator = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Softplus(), # Ensure positive adjustment to range ) # Learnable affine transformation parameters self.weight = nn.Parameter(torch.ones(input_dim)) self.bias = nn.Parameter(torch.zeros(input_dim)) # Running statistics for inference self.register_buffer("running_min", torch.zeros(input_dim)) self.register_buffer("running_range", torch.ones(input_dim)) self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long)) self.momentum = 0.1 def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: if inverse: return self._inverse_transform(x) original_shape = x.shape if x.dim() == 3: x = x.view(-1, x.size(-1)) eps = 1e-8 if self.training: batch_min = x.min(dim=0, keepdim=True).values batch_max = x.max(dim=0, keepdim=True).values batch_range = (batch_max - batch_min).clamp_min(eps) # Learn adjustments learned_min_adj = self.min_estimator(batch_min) learned_range_adj = self.range_estimator(batch_range) effective_min = batch_min + learned_min_adj effective_range = batch_range + learned_range_adj + eps # Update running stats with raw batch min/range for stable inversion with torch.no_grad(): self.num_batches_tracked += 1 if self.num_batches_tracked == 1: self.running_min.copy_(batch_min.squeeze()) self.running_range.copy_(batch_range.squeeze()) else: self.running_min.mul_(1 - self.momentum).add_(batch_min.squeeze(), alpha=self.momentum) self.running_range.mul_(1 - self.momentum).add_(batch_range.squeeze(), alpha=self.momentum) else: effective_min = self.running_min.unsqueeze(0) effective_range = self.running_range.unsqueeze(0) # Scale to [0, 1] scaled = (x - effective_min) / effective_range # Learnable affine transform transformed = scaled * self.weight + self.bias if len(original_shape) == 3: transformed = transformed.view(original_shape) # Regularization: encourage non-degenerate range and modest affine params reg_loss = 0.01 * (self.weight.var() + self.bias.var()) if self.training: reg_loss = reg_loss + 0.001 * (1.0 / effective_range.clamp_min(1e-3)).mean() return transformed, reg_loss def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if not self.config.learn_inverse_preprocessing: return x, torch.tensor(0.0, device=x.device) original_shape = x.shape if x.dim() == 3: x = x.view(-1, x.size(-1)) # Reverse affine x = (x - self.bias) / (self.weight + 1e-8) # Reverse MinMax using running stats x = x * self.running_range.unsqueeze(0) + self.running_min.unsqueeze(0) if len(original_shape) == 3: x = x.view(original_shape) return x, torch.tensor(0.0, device=x.device) class LearnableRobustScaler(nn.Module): """Learnable Robust scaler using median and IQR with learnable adjustments. Normalizes as (x - median) / IQR with learnable adjustments and an affine head. Supports 2D (B, F) and 3D (B, T, F) inputs. """ def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config input_dim = config.input_dim hidden_dim = config.preprocessing_hidden_dim self.median_estimator = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), ) self.iqr_estimator = nn.Sequential( nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, input_dim), nn.Softplus(), # Ensure positive IQR adjustment ) self.weight = nn.Parameter(torch.ones(input_dim)) self.bias = nn.Parameter(torch.zeros(input_dim)) self.register_buffer("running_median", torch.zeros(input_dim)) self.register_buffer("running_iqr", torch.ones(input_dim)) self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long)) self.momentum = 0.1 def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: if inverse: return self._inverse_transform(x) original_shape = x.shape if x.dim() == 3: x = x.view(-1, x.size(-1)) eps = 1e-8 if self.training: qs = torch.quantile(x, torch.tensor([0.25, 0.5, 0.75], device=x.device), dim=0) q25, med, q75 = qs[0:1, :], qs[1:2, :], qs[2:3, :] iqr = (q75 - q25).clamp_min(eps) learned_med_adj = self.median_estimator(med) learned_iqr_adj = self.iqr_estimator(iqr) effective_median = med + learned_med_adj effective_iqr = iqr + learned_iqr_adj + eps with torch.no_grad(): self.num_batches_tracked += 1 if self.num_batches_tracked == 1: self.running_median.copy_(med.squeeze()) self.running_iqr.copy_(iqr.squeeze()) else: self.running_median.mul_(1 - self.momentum).add_(med.squeeze(), alpha=self.momentum) self.running_iqr.mul_(1 - self.momentum).add_(iqr.squeeze(), alpha=self.momentum) else: effective_median = self.running_median.unsqueeze(0) effective_iqr = self.running_iqr.unsqueeze(0) normalized = (x - effective_median) / effective_iqr transformed = normalized * self.weight + self.bias if len(original_shape) == 3: transformed = transformed.view(original_shape) reg_loss = 0.01 * (self.weight.var() + self.bias.var()) if self.training: reg_loss = reg_loss + 0.001 * (1.0 / effective_iqr.clamp_min(1e-3)).mean() return transformed, reg_loss def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if not self.config.learn_inverse_preprocessing: return x, torch.tensor(0.0, device=x.device) original_shape = x.shape if x.dim() == 3: x = x.view(-1, x.size(-1)) x = (x - self.bias) / (self.weight + 1e-8) x = x * self.running_iqr.unsqueeze(0) + self.running_median.unsqueeze(0) if len(original_shape) == 3: x = x.view(original_shape) return x, torch.tensor(0.0, device=x.device) class LearnableYeoJohnsonPreprocessor(nn.Module): """Learnable Yeo-Johnson power transform with per-feature λ and affine head. Applies Yeo-Johnson transform elementwise with learnable lambda per feature, followed by standardization and a learnable affine transform. Supports 2D and 3D inputs. """ def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config input_dim = config.input_dim # Learnable lambda per feature (unconstrained). Initialize around 1.0 self.lmbda = nn.Parameter(torch.ones(input_dim)) # Learnable affine parameters after standardization self.weight = nn.Parameter(torch.ones(input_dim)) self.bias = nn.Parameter(torch.zeros(input_dim)) # Running stats for transformed data self.register_buffer("running_mean", torch.zeros(input_dim)) self.register_buffer("running_std", torch.ones(input_dim)) self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long)) self.momentum = 0.1 def _yeo_johnson(self, x: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor: eps = 1e-6 lmbda = lmbda.unsqueeze(0) # broadcast over batch pos = x >= 0 # For x >= 0 if_part = torch.where( torch.abs(lmbda) > eps, ((x + 1.0).clamp_min(eps) ** lmbda - 1.0) / lmbda, torch.log((x + 1.0).clamp_min(eps)), ) # For x < 0 two_minus_lambda = 2.0 - lmbda else_part = torch.where( torch.abs(two_minus_lambda) > eps, -(((1.0 - x).clamp_min(eps)) ** two_minus_lambda - 1.0) / two_minus_lambda, -torch.log((1.0 - x).clamp_min(eps)), ) return torch.where(pos, if_part, else_part) def _yeo_johnson_inverse(self, y: torch.Tensor, lmbda: torch.Tensor) -> torch.Tensor: eps = 1e-6 lmbda = lmbda.unsqueeze(0) pos = y >= 0 # Inverse for y >= 0 x_pos = torch.where( torch.abs(lmbda) > eps, (y * lmbda + 1.0).clamp_min(eps) ** (1.0 / lmbda) - 1.0, torch.exp(y) - 1.0, ) # Inverse for y < 0 two_minus_lambda = 2.0 - lmbda x_neg = torch.where( torch.abs(two_minus_lambda) > eps, 1.0 - (1.0 - y * two_minus_lambda).clamp_min(eps) ** (1.0 / two_minus_lambda), 1.0 - torch.exp(-y), ) return torch.where(pos, x_pos, x_neg) def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: if inverse: return self._inverse_transform(x) orig_shape = x.shape if x.dim() == 3: x = x.view(-1, x.size(-1)) # Apply Yeo-Johnson y = self._yeo_johnson(x, self.lmbda) # Batch stats and running stats on transformed data if self.training: batch_mean = y.mean(dim=0, keepdim=True) batch_std = y.std(dim=0, keepdim=True).clamp_min(1e-6) with torch.no_grad(): self.num_batches_tracked += 1 if self.num_batches_tracked == 1: self.running_mean.copy_(batch_mean.squeeze()) self.running_std.copy_(batch_std.squeeze()) else: self.running_mean.mul_(1 - self.momentum).add_(batch_mean.squeeze(), alpha=self.momentum) self.running_std.mul_(1 - self.momentum).add_(batch_std.squeeze(), alpha=self.momentum) mean = batch_mean std = batch_std else: mean = self.running_mean.unsqueeze(0) std = self.running_std.unsqueeze(0) y_norm = (y - mean) / std out = y_norm * self.weight + self.bias if len(orig_shape) == 3: out = out.view(orig_shape) # Regularize lambda to avoid extreme values; encourage identity around 1 reg = 0.001 * (self.lmbda - 1.0).pow(2).mean() + 0.01 * (self.weight.var() + self.bias.var()) return out, reg def _inverse_transform(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if not self.config.learn_inverse_preprocessing: return x, torch.tensor(0.0, device=x.device) orig_shape = x.shape if x.dim() == 3: x = x.view(-1, x.size(-1)) # Reverse affine and normalization with running stats y = (x - self.bias) / (self.weight + 1e-8) y = y * self.running_std.unsqueeze(0) + self.running_mean.unsqueeze(0) # Inverse Yeo-Johnson out = self._yeo_johnson_inverse(y, self.lmbda) if len(orig_shape) == 3: out = out.view(orig_shape) return out, torch.tensor(0.0, device=x.device) class CouplingLayer(nn.Module): """Coupling layer for normalizing flows.""" def __init__(self, input_dim: int, hidden_dim: int = 64, mask_type: str = "alternating"): super().__init__() self.input_dim = input_dim self.hidden_dim = hidden_dim # Create mask for coupling if mask_type == "alternating": self.register_buffer('mask', torch.arange(input_dim) % 2) elif mask_type == "half": mask = torch.zeros(input_dim) mask[:input_dim // 2] = 1 self.register_buffer('mask', mask) else: raise ValueError(f"Unknown mask type: {mask_type}") # Scale and translation networks masked_dim = int(self.mask.sum().item()) unmasked_dim = input_dim - masked_dim self.scale_net = nn.Sequential( nn.Linear(masked_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, unmasked_dim), nn.Tanh() # Bounded output for stability ) self.translate_net = nn.Sequential( nn.Linear(masked_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, unmasked_dim) ) def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass through coupling layer. Args: x: Input tensor inverse: Whether to apply inverse transformation Returns: Tuple of (transformed_tensor, log_determinant) """ mask = self.mask.bool() x_masked = x[:, mask] x_unmasked = x[:, ~mask] # Compute scale and translation s = self.scale_net(x_masked) t = self.translate_net(x_masked) if not inverse: # Forward transformation y_unmasked = x_unmasked * torch.exp(s) + t log_det = s.sum(dim=1) else: # Inverse transformation y_unmasked = (x_unmasked - t) * torch.exp(-s) log_det = -s.sum(dim=1) # Reconstruct output y = torch.zeros_like(x) y[:, mask] = x_masked y[:, ~mask] = y_unmasked return y, log_det class NormalizingFlowPreprocessor(nn.Module): """Normalizing flow for learnable data preprocessing.""" def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config input_dim = config.input_dim hidden_dim = config.preprocessing_hidden_dim num_layers = config.flow_coupling_layers # Create coupling layers with alternating masks self.layers = nn.ModuleList() for i in range(num_layers): mask_type = "alternating" if i % 2 == 0 else "half" self.layers.append(CouplingLayer(input_dim, hidden_dim, mask_type)) # Optional: Add batch normalization between layers if config.use_batch_norm: self.batch_norms = nn.ModuleList([ nn.BatchNorm1d(input_dim) for _ in range(num_layers - 1) ]) else: self.batch_norms = None def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass through normalizing flow. Args: x: Input tensor (2D or 3D) inverse: Whether to apply inverse transformation Returns: Tuple of (transformed_tensor, total_log_determinant) """ # Handle both 2D and 3D tensors original_shape = x.shape if x.dim() == 3: # Reshape (batch, seq, features) -> (batch*seq, features) x = x.view(-1, x.size(-1)) log_det_total = torch.zeros(x.size(0), device=x.device) if not inverse: # Forward pass for i, layer in enumerate(self.layers): x, log_det = layer(x, inverse=False) log_det_total += log_det # Apply batch normalization (except for last layer) if self.batch_norms and i < len(self.layers) - 1: x = self.batch_norms[i](x) else: # Inverse pass for i, layer in enumerate(reversed(self.layers)): # Reverse batch normalization (except for first layer in reverse) if self.batch_norms and i > 0: # Note: This is approximate inverse of batch norm bn_idx = len(self.layers) - 1 - i x = self.batch_norms[bn_idx](x) x, log_det = layer(x, inverse=True) log_det_total += log_det # Reshape back to original shape if needed if len(original_shape) == 3: x = x.view(original_shape) # Convert log determinant to regularization loss # Encourage the flow to preserve information (log_det close to 0) reg_loss = 0.01 * log_det_total.abs().mean() return x, reg_loss class LearnablePreprocessor(nn.Module): """Unified interface for learnable preprocessing methods.""" def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config if not config.has_preprocessing: self.preprocessor = nn.Identity() elif config.is_neural_scaler: self.preprocessor = NeuralScaler(config) elif config.is_normalizing_flow: self.preprocessor = NormalizingFlowPreprocessor(config) elif getattr(config, "is_minmax_scaler", False): self.preprocessor = LearnableMinMaxScaler(config) elif getattr(config, "is_robust_scaler", False): self.preprocessor = LearnableRobustScaler(config) elif getattr(config, "is_yeo_johnson", False): self.preprocessor = LearnableYeoJohnsonPreprocessor(config) else: raise ValueError(f"Unknown preprocessing type: {config.preprocessing_type}") def forward(self, x: torch.Tensor, inverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply preprocessing transformation. Args: x: Input tensor inverse: Whether to apply inverse transformation Returns: Tuple of (transformed_tensor, regularization_loss) """ if isinstance(self.preprocessor, nn.Identity): return x, torch.tensor(0.0, device=x.device) return self.preprocessor(x, inverse=inverse) @dataclass class AutoencoderOutput(ModelOutput): """ Output type of AutoencoderModel. Args: last_hidden_state (torch.FloatTensor): The latent representation of the input. reconstructed (torch.FloatTensor, optional): The reconstructed input. hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers. attentions (tuple(torch.FloatTensor), optional): Not used in basic autoencoder. preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing. """ last_hidden_state: torch.FloatTensor = None reconstructed: Optional[torch.FloatTensor] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None preprocessing_loss: Optional[torch.FloatTensor] = None @dataclass class AutoencoderForReconstructionOutput(ModelOutput): """ Output type of AutoencoderForReconstruction. Args: loss (torch.FloatTensor, optional): The reconstruction loss. reconstructed (torch.FloatTensor): The reconstructed input. last_hidden_state (torch.FloatTensor): The latent representation. hidden_states (tuple(torch.FloatTensor), optional): Hidden states of the encoder layers. preprocessing_loss (torch.FloatTensor, optional): Loss from learnable preprocessing. """ loss: Optional[torch.FloatTensor] = None reconstructed: torch.FloatTensor = None last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None preprocessing_loss: Optional[torch.FloatTensor] = None class AutoencoderEncoder(nn.Module): """Encoder part of the autoencoder.""" def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config # Build encoder layers layers = [] input_dim = config.input_dim for hidden_dim in config.hidden_dims: layers.append(nn.Linear(input_dim, hidden_dim)) if config.use_batch_norm: layers.append(nn.BatchNorm1d(hidden_dim)) layers.append(self._get_activation(config.activation)) if config.dropout_rate > 0: layers.append(nn.Dropout(config.dropout_rate)) input_dim = hidden_dim self.encoder = nn.Sequential(*layers) # For variational autoencoders, we need separate layers for mean and log variance if config.is_variational: self.fc_mu = nn.Linear(input_dim, config.latent_dim) self.fc_logvar = nn.Linear(input_dim, config.latent_dim) else: # Standard encoder output self.fc_out = nn.Linear(input_dim, config.latent_dim) def _get_activation(self, activation: str) -> nn.Module: """Get activation function by name.""" activations = { "relu": nn.ReLU(), "tanh": nn.Tanh(), "sigmoid": nn.Sigmoid(), "leaky_relu": nn.LeakyReLU(), "gelu": nn.GELU(), "swish": nn.SiLU(), "silu": nn.SiLU(), "elu": nn.ELU(), "prelu": nn.PReLU(), "relu6": nn.ReLU6(), "hardtanh": nn.Hardtanh(), "hardsigmoid": nn.Hardsigmoid(), "hardswish": nn.Hardswish(), "mish": nn.Mish(), "softplus": nn.Softplus(), "softsign": nn.Softsign(), "tanhshrink": nn.Tanhshrink(), "threshold": nn.Threshold(threshold=0.1, value=0), } return activations[activation] def forward(self, x: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Forward pass through encoder.""" # Add noise for denoising autoencoders if self.config.is_denoising and self.training: noise = torch.randn_like(x) * self.config.noise_factor x = x + noise encoded = self.encoder(x) if self.config.is_variational: # Variational autoencoder: return mean, log variance, and sampled latent mu = self.fc_mu(encoded) logvar = self.fc_logvar(encoded) # Reparameterization trick if self.training: std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) z = mu + eps * std else: z = mu # Use mean during inference return z, mu, logvar else: # Standard autoencoder latent = self.fc_out(encoded) # Add sparsity constraint for sparse autoencoders if self.config.is_sparse and self.training: # Apply L1 regularization to encourage sparsity latent = F.relu(latent) # Ensure non-negative activations return latent class AutoencoderDecoder(nn.Module): """Decoder part of the autoencoder.""" def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config # Build decoder layers (reverse of encoder) layers = [] input_dim = config.latent_dim decoder_dims = config.decoder_dims + [config.input_dim] for i, hidden_dim in enumerate(decoder_dims): layers.append(nn.Linear(input_dim, hidden_dim)) # Don't add batch norm, activation, or dropout to the final layer if i < len(decoder_dims) - 1: if config.use_batch_norm: layers.append(nn.BatchNorm1d(hidden_dim)) layers.append(self._get_activation(config.activation)) if config.dropout_rate > 0: layers.append(nn.Dropout(config.dropout_rate)) else: # Final layer - add appropriate activation based on reconstruction loss if config.reconstruction_loss == "bce": layers.append(nn.Sigmoid()) input_dim = hidden_dim self.decoder = nn.Sequential(*layers) def _get_activation(self, activation: str) -> nn.Module: """Get activation function by name.""" activations = { "relu": nn.ReLU(), "tanh": nn.Tanh(), "sigmoid": nn.Sigmoid(), "leaky_relu": nn.LeakyReLU(), "gelu": nn.GELU(), "swish": nn.SiLU(), "silu": nn.SiLU(), "elu": nn.ELU(), "prelu": nn.PReLU(), "relu6": nn.ReLU6(), "hardtanh": nn.Hardtanh(), "hardsigmoid": nn.Hardsigmoid(), "hardswish": nn.Hardswish(), "mish": nn.Mish(), "softplus": nn.Softplus(), "softsign": nn.Softsign(), "tanhshrink": nn.Tanhshrink(), "threshold": nn.Threshold(threshold=0.1, value=0), } return activations[activation] def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through decoder.""" return self.decoder(x) class RecurrentEncoder(nn.Module): """Recurrent encoder for sequence data.""" def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config # Get RNN class if config.rnn_type == "lstm": rnn_class = nn.LSTM elif config.rnn_type == "gru": rnn_class = nn.GRU elif config.rnn_type == "rnn": rnn_class = nn.RNN else: raise ValueError(f"Unknown RNN type: {config.rnn_type}") # Create RNN layers self.rnn = rnn_class( input_size=config.input_dim, hidden_size=config.latent_dim, num_layers=config.num_layers, batch_first=True, dropout=config.dropout_rate if config.num_layers > 1 else 0, bidirectional=config.bidirectional ) # Projection layer for bidirectional RNN if config.bidirectional: self.projection = nn.Linear(config.latent_dim * 2, config.latent_dim) else: self.projection = None # Batch normalization if config.use_batch_norm: self.batch_norm = nn.BatchNorm1d(config.latent_dim) else: self.batch_norm = None # Dropout if config.dropout_rate > 0: self.dropout = nn.Dropout(config.dropout_rate) else: self.dropout = None def forward(self, x: torch.Tensor, lengths: Optional[torch.Tensor] = None) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """ Forward pass through recurrent encoder. Args: x: Input tensor of shape (batch_size, seq_len, input_dim) lengths: Sequence lengths for packed sequences (optional) Returns: Encoded representation or tuple for VAE """ batch_size, seq_len, _ = x.shape # Add noise for denoising autoencoders if self.config.is_denoising and self.training: noise = torch.randn_like(x) * self.config.noise_factor x = x + noise # Pack sequences if lengths provided if lengths is not None: x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False) # RNN forward pass if self.config.rnn_type == "lstm": output, (hidden, cell) = self.rnn(x) else: output, hidden = self.rnn(x) cell = None # Unpack if necessary if lengths is not None: output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True) # Use last hidden state as encoding if self.config.bidirectional: # Concatenate forward and backward hidden states hidden = hidden.view(self.config.num_layers, 2, batch_size, self.config.latent_dim) hidden = hidden[-1] # Take last layer hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1) # Concatenate directions # Project to latent dimension if self.projection: hidden = self.projection(hidden) else: hidden = hidden[-1] # Take last layer # Apply batch normalization if self.batch_norm: hidden = self.batch_norm(hidden) # Apply dropout if self.dropout and self.training: hidden = self.dropout(hidden) # Handle variational encoding if self.config.is_variational: # Split hidden into mean and log variance mu = hidden[:, :self.config.latent_dim // 2] logvar = hidden[:, self.config.latent_dim // 2:] # Reparameterization trick if self.training: std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) z = mu + eps * std else: z = mu return z, mu, logvar else: return hidden class RecurrentDecoder(nn.Module): """Recurrent decoder for sequence data.""" def __init__(self, config: AutoencoderConfig): super().__init__() self.config = config # Get RNN class if config.rnn_type == "lstm": rnn_class = nn.LSTM elif config.rnn_type == "gru": rnn_class = nn.GRU elif config.rnn_type == "rnn": rnn_class = nn.RNN else: raise ValueError(f"Unknown RNN type: {config.rnn_type}") # Create RNN layers self.rnn = rnn_class( input_size=config.latent_dim, hidden_size=config.latent_dim, num_layers=config.num_layers, batch_first=True, dropout=config.dropout_rate if config.num_layers > 1 else 0, bidirectional=False # Decoder is always unidirectional ) # Output projection self.output_projection = nn.Linear(config.latent_dim, config.input_dim) # Batch normalization if config.use_batch_norm: self.batch_norm = nn.BatchNorm1d(config.latent_dim) else: self.batch_norm = None # Dropout if config.dropout_rate > 0: self.dropout = nn.Dropout(config.dropout_rate) else: self.dropout = None def forward(self, z: torch.Tensor, target_length: int, target_sequence: Optional[torch.Tensor] = None) -> torch.Tensor: """ Forward pass through recurrent decoder. Args: z: Latent representation of shape (batch_size, latent_dim) target_length: Length of sequence to generate target_sequence: Target sequence for teacher forcing (optional) Returns: Decoded sequence of shape (batch_size, seq_len, input_dim) """ batch_size = z.size(0) device = z.device # Initialize hidden state with latent representation if self.config.rnn_type == "lstm": h_0 = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1) c_0 = torch.zeros_like(h_0) hidden = (h_0, c_0) else: hidden = z.unsqueeze(0).repeat(self.config.num_layers, 1, 1) outputs = [] # Initialize input (can be learned or zero) current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device) for t in range(target_length): # Teacher forcing decision use_teacher_forcing = (target_sequence is not None and self.training and random.random() < self.config.teacher_forcing_ratio) if use_teacher_forcing and t > 0: # Use previous target as input current_input = target_sequence[:, t-1:t, :] # Project to latent dimension if needed if current_input.size(-1) != self.config.latent_dim: current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device) # RNN forward step if self.config.rnn_type == "lstm": output, hidden = self.rnn(current_input, hidden) else: output, hidden = self.rnn(current_input, hidden) # Apply batch normalization and dropout output_flat = output.squeeze(1) # Remove sequence dimension if self.batch_norm: output_flat = self.batch_norm(output_flat) if self.dropout and self.training: output_flat = self.dropout(output_flat) # Project to output dimension step_output = self.output_projection(output_flat) outputs.append(step_output.unsqueeze(1)) # Use output as next input (for non-teacher forcing) if not use_teacher_forcing: # Project output back to latent dimension for next step current_input = torch.zeros(batch_size, 1, self.config.latent_dim, device=device) # Concatenate all outputs return torch.cat(outputs, dim=1) class AutoencoderModel(PreTrainedModel): """ The bare Autoencoder Model transformer outputting raw hidden-states without any specific head on top. This model inherits from PreTrainedModel. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads etc.) This model is also a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior. """ config_class = AutoencoderConfig base_model_prefix = "autoencoder" supports_gradient_checkpointing = False def __init__(self, config: AutoencoderConfig): super().__init__(config) self.config = config # Initialize learnable preprocessing if config.has_preprocessing: self.preprocessor = LearnablePreprocessor(config) else: self.preprocessor = None # Initialize encoder and decoder based on type if config.is_recurrent: self.encoder = RecurrentEncoder(config) self.decoder = RecurrentDecoder(config) else: self.encoder = AutoencoderEncoder(config) self.decoder = AutoencoderDecoder(config) # Tie weights if specified if config.tie_weights: self._tie_weights() # Initialize weights self.post_init() def _tie_weights(self): """Tie encoder and decoder weights (transpose relationship).""" # This is a simplified weight tying - in practice, you might want more sophisticated tying pass def get_input_embeddings(self): """Get input embeddings (not applicable for basic autoencoder).""" return None def set_input_embeddings(self, value): """Set input embeddings (not applicable for basic autoencoder).""" pass def forward( self, input_values: torch.Tensor, sequence_lengths: Optional[torch.Tensor] = None, target_length: Optional[int] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], AutoencoderOutput]: """ Forward pass through the autoencoder. Args: input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type: - Standard: (batch_size, input_dim) - Recurrent: (batch_size, seq_len, input_dim) sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE. target_length (int, optional): Target sequence length for recurrent decoder. output_hidden_states (bool, optional): Whether to return hidden states. return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple. Returns: AutoencoderOutput or tuple: The model outputs. """ output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Apply learnable preprocessing preprocessing_loss = torch.tensor(0.0, device=input_values.device) if self.preprocessor is not None: input_values, preprocessing_loss = self.preprocessor(input_values, inverse=False) # Handle different autoencoder types if self.config.is_recurrent: # Recurrent autoencoder if sequence_lengths is not None: encoder_output = self.encoder(input_values, sequence_lengths) else: encoder_output = self.encoder(input_values) if self.config.is_variational: latent, mu, logvar = encoder_output self._mu = mu self._logvar = logvar else: latent = encoder_output self._mu = None self._logvar = None # Determine target length for decoder if target_length is None: if self.config.sequence_length is not None: target_length = self.config.sequence_length else: target_length = input_values.size(1) # Use input sequence length # Decode latent back to sequence space reconstructed = self.decoder(latent, target_length, input_values if self.training else None) else: # Standard autoencoder encoder_output = self.encoder(input_values) if self.config.is_variational: latent, mu, logvar = encoder_output self._mu = mu self._logvar = logvar else: latent = encoder_output self._mu = None self._logvar = None # Decode latent back to input space reconstructed = self.decoder(latent) # Apply inverse preprocessing to reconstruction if self.preprocessor is not None and self.config.learn_inverse_preprocessing: reconstructed, inverse_loss = self.preprocessor(reconstructed, inverse=True) preprocessing_loss += inverse_loss hidden_states = None if output_hidden_states: if self.config.is_variational: hidden_states = (latent, mu, logvar) else: hidden_states = (latent,) if not return_dict: return tuple(v for v in [latent, reconstructed, hidden_states] if v is not None) return AutoencoderOutput( last_hidden_state=latent, reconstructed=reconstructed, hidden_states=hidden_states, preprocessing_loss=preprocessing_loss, ) class AutoencoderForReconstruction(PreTrainedModel): """ Autoencoder Model with a reconstruction head on top for reconstruction tasks. This model inherits from PreTrainedModel and adds a reconstruction loss calculation. """ config_class = AutoencoderConfig base_model_prefix = "autoencoder" def __init__(self, config: AutoencoderConfig): super().__init__(config) self.config = config # Initialize the base autoencoder model self.autoencoder = AutoencoderModel(config) # Initialize weights self.post_init() def get_input_embeddings(self): """Get input embeddings.""" return self.autoencoder.get_input_embeddings() def set_input_embeddings(self, value): """Set input embeddings.""" self.autoencoder.set_input_embeddings(value) def _compute_reconstruction_loss( self, reconstructed: torch.Tensor, target: torch.Tensor ) -> torch.Tensor: """Compute reconstruction loss based on the configured loss type.""" if self.config.reconstruction_loss == "mse": return F.mse_loss(reconstructed, target, reduction="mean") elif self.config.reconstruction_loss == "bce": return F.binary_cross_entropy_with_logits(reconstructed, target, reduction="mean") elif self.config.reconstruction_loss == "l1": return F.l1_loss(reconstructed, target, reduction="mean") elif self.config.reconstruction_loss == "huber": return F.huber_loss(reconstructed, target, reduction="mean") elif self.config.reconstruction_loss == "smooth_l1": return F.smooth_l1_loss(reconstructed, target, reduction="mean") elif self.config.reconstruction_loss == "kl_div": return F.kl_div(F.log_softmax(reconstructed, dim=-1), F.softmax(target, dim=-1), reduction="mean") elif self.config.reconstruction_loss == "cosine": return 1 - F.cosine_similarity(reconstructed, target, dim=-1).mean() elif self.config.reconstruction_loss == "focal": return self._focal_loss(reconstructed, target) elif self.config.reconstruction_loss == "dice": return self._dice_loss(reconstructed, target) elif self.config.reconstruction_loss == "tversky": return self._tversky_loss(reconstructed, target) elif self.config.reconstruction_loss == "ssim": return self._ssim_loss(reconstructed, target) elif self.config.reconstruction_loss == "perceptual": return self._perceptual_loss(reconstructed, target) else: raise ValueError(f"Unknown reconstruction loss: {self.config.reconstruction_loss}") def _focal_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 1.0, gamma: float = 2.0) -> torch.Tensor: """Compute focal loss for handling class imbalance.""" ce_loss = F.mse_loss(pred, target, reduction="none") pt = torch.exp(-ce_loss) focal_loss = alpha * (1 - pt) ** gamma * ce_loss return focal_loss.mean() def _dice_loss(self, pred: torch.Tensor, target: torch.Tensor, smooth: float = 1e-6) -> torch.Tensor: """Compute Dice loss for segmentation-like tasks.""" pred_flat = pred.view(-1) target_flat = target.view(-1) intersection = (pred_flat * target_flat).sum() dice = (2.0 * intersection + smooth) / (pred_flat.sum() + target_flat.sum() + smooth) return 1 - dice def _tversky_loss(self, pred: torch.Tensor, target: torch.Tensor, alpha: float = 0.7, beta: float = 0.3, smooth: float = 1e-6) -> torch.Tensor: """Compute Tversky loss, a generalization of Dice loss.""" pred_flat = pred.view(-1) target_flat = target.view(-1) true_pos = (pred_flat * target_flat).sum() false_neg = (target_flat * (1 - pred_flat)).sum() false_pos = ((1 - target_flat) * pred_flat).sum() tversky = (true_pos + smooth) / (true_pos + alpha * false_neg + beta * false_pos + smooth) return 1 - tversky def _ssim_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Compute SSIM-based loss (simplified version).""" # Simplified SSIM for 1D data mu1 = pred.mean(dim=-1, keepdim=True) mu2 = target.mean(dim=-1, keepdim=True) sigma1_sq = ((pred - mu1) ** 2).mean(dim=-1, keepdim=True) sigma2_sq = ((target - mu2) ** 2).mean(dim=-1, keepdim=True) sigma12 = ((pred - mu1) * (target - mu2)).mean(dim=-1, keepdim=True) c1, c2 = 0.01, 0.03 ssim = ((2 * mu1 * mu2 + c1) * (2 * sigma12 + c2)) / ((mu1**2 + mu2**2 + c1) * (sigma1_sq + sigma2_sq + c2)) return 1 - ssim.mean() def _perceptual_loss(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """Compute perceptual loss (simplified version using feature differences).""" # For simplicity, use L2 loss on normalized features pred_norm = F.normalize(pred, p=2, dim=-1) target_norm = F.normalize(target, p=2, dim=-1) return F.mse_loss(pred_norm, target_norm) def forward( self, input_values: torch.Tensor, labels: Optional[torch.Tensor] = None, sequence_lengths: Optional[torch.Tensor] = None, target_length: Optional[int] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple[torch.Tensor], AutoencoderForReconstructionOutput]: """ Forward pass with reconstruction loss calculation. Args: input_values (torch.Tensor): Input tensor. Shape depends on autoencoder type: - Standard: (batch_size, input_dim) - Recurrent: (batch_size, seq_len, input_dim) labels (torch.Tensor, optional): Target tensor for reconstruction. If None, uses input_values. sequence_lengths (torch.Tensor, optional): Sequence lengths for recurrent AE. target_length (int, optional): Target sequence length for recurrent decoder. output_hidden_states (bool, optional): Whether to return hidden states. return_dict (bool, optional): Whether to return a ModelOutput instead of a plain tuple. Returns: AutoencoderForReconstructionOutput or tuple: The model outputs including loss. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # If no labels provided, use input as target (standard autoencoder) if labels is None: labels = input_values # Forward pass through autoencoder outputs = self.autoencoder( input_values=input_values, sequence_lengths=sequence_lengths, target_length=target_length, output_hidden_states=output_hidden_states, return_dict=True, ) reconstructed = outputs.reconstructed latent = outputs.last_hidden_state hidden_states = outputs.hidden_states # Compute reconstruction loss recon_loss = self._compute_reconstruction_loss(reconstructed, labels) # Add regularization losses based on autoencoder type total_loss = recon_loss # Add preprocessing loss if available if hasattr(outputs, 'preprocessing_loss') and outputs.preprocessing_loss is not None: total_loss += outputs.preprocessing_loss if self.config.is_variational and hasattr(self.autoencoder, '_mu') and self.autoencoder._mu is not None: # KL divergence loss for variational autoencoders kl_loss = -0.5 * torch.sum(1 + self.autoencoder._logvar - self.autoencoder._mu.pow(2) - self.autoencoder._logvar.exp()) kl_loss = kl_loss / (self.autoencoder._mu.size(0) * self.autoencoder._mu.size(1)) # Normalize by batch size and latent dim total_loss = recon_loss + self.config.beta * kl_loss elif self.config.is_sparse: # Sparsity loss for sparse autoencoders latent = outputs.last_hidden_state sparsity_loss = torch.mean(torch.abs(latent)) # L1 sparsity total_loss = recon_loss + 0.1 * sparsity_loss # Sparsity weight elif self.config.is_contractive: # Contractive loss - penalize large gradients of hidden representation w.r.t. input latent = outputs.last_hidden_state latent.retain_grad() if latent.grad is not None: contractive_loss = torch.sum(latent.grad ** 2) total_loss = recon_loss + 0.1 * contractive_loss loss = total_loss if not return_dict: output = (reconstructed, latent) if hidden_states is not None: output = output + (hidden_states,) return ((loss,) + output) if loss is not None else output return AutoencoderForReconstructionOutput( loss=loss, reconstructed=reconstructed, last_hidden_state=latent, hidden_states=hidden_states, preprocessing_loss=outputs.preprocessing_loss if hasattr(outputs, 'preprocessing_loss') else None, )