import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import xgboost as xgb

from .Modules.conformer import ConformerEncoder, ConformerDecoder
from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
from .kan.fasterkan import FasterKAN


class Sine(nn.Module):
    """Sine activation, f(x) = sin(w0 * x) (SIREN-style periodic activation)."""

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0  # frequency scale applied before the sine

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sin(self.w0 * x)


class MLPEncoder(nn.Module):
    """MLP encoder built from Linear -> BatchNorm1d -> SiLU (-> Dropout) blocks.

    Configuration is read from ``args``:
        args.input_dim (int): dimension of the input features.
        args.hidden_dims (list[int]): output width of each hidden block.
        args.dropout (float): dropout probability; a Dropout layer is appended
            to each block only when > 0.

    Attributes:
        output_dim (int): width of the final hidden layer (or ``input_dim``
            when ``hidden_dims`` is empty).
    """

    def __init__(self, args):
        super().__init__()
        layers = []
        prev_dim = args.input_dim
        for hidden_dim in args.hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.SiLU())
            if args.dropout > 0.0:
                layers.append(nn.Dropout(args.dropout))
            prev_dim = hidden_dim
        self.model = nn.Sequential(*layers)
        # BUGFIX: was `hidden_dim`, which is undefined (NameError) when
        # args.hidden_dims is empty; `prev_dim` is identical otherwise.
        self.output_dim = prev_dim

    def forward(self, x):
        return self.model(x)


class ConvBlock(nn.Module):
    """Length-preserving Conv1d -> BatchNorm1d -> activation block."""

    def __init__(self, args, num_layer) -> None:
        super().__init__()
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        elif args.activation == 'sine':
            self.activation = Sine(w0=args.sine_w0)
        else:
            self.activation = nn.ReLU()
        # Blocks past the end of encoder_dims keep the last width, so
        # num_layers may exceed len(encoder_dims).
        if num_layer < len(args.encoder_dims):
            in_channels = args.encoder_dims[num_layer - 1]
            out_channels = args.encoder_dims[num_layer]
        else:
            in_channels = args.encoder_dims[-1]
            out_channels = args.encoder_dims[-1]
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=args.kernel_size,
                      stride=1,
                      padding='same',  # stride must stay 1 for padding='same'
                      bias=False),
            nn.BatchNorm1d(num_features=out_channels),
            self.activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


class CNNEncoder(nn.Module):
    """1-D CNN encoder: embedding conv, then ConvBlocks with max-pool halving.

    Input of shape (B, T) or (B, T, 1) is reshaped to (B, C, T) before the
    convolutions. When ``args.avg_output`` is set, the output is averaged over
    the time axis to a (B, encoder_dims[-1]) vector.
    """

    def __init__(self, args) -> None:
        super().__init__()
        # BUGFIX: message typo "wit" -> "with".
        print("Using CNN encoder with activation: ", args.activation,
              'args avg_output: ', args.avg_output)
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        elif args.activation == 'sine':
            self.activation = Sine(w0=args.sine_w0)
        else:
            self.activation = nn.ReLU()
        self.embedding = nn.Sequential(
            nn.Conv1d(in_channels=args.in_channels,
                      kernel_size=3,
                      out_channels=args.encoder_dims[0],
                      stride=1,
                      padding='same',
                      bias=False),
            nn.BatchNorm1d(args.encoder_dims[0]),
            self.activation,
        )
        self.layers = nn.ModuleList(
            [ConvBlock(args, i + 1) for i in range(args.num_layers)]
        )
        self.pool = nn.MaxPool1d(2)
        self.output_dim = args.encoder_dims[-1]
        self.min_seq_len = 2  # stop pooling once the sequence is this short
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize input to (B, C, T).
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        if len(x.shape) == 3 and x.shape[-1] == 1:
            x = x.permute(0, 2, 1)
        x = self.embedding(x)
        for block in self.layers:
            x = block(x)
            # Halve the sequence only while it is still long enough.
            if x.shape[-1] > self.min_seq_len:
                x = self.pool(x)
        if self.avg_output:
            x = x.mean(dim=-1)
        return x


class MultiEncoder(nn.Module):
    """CNN backbone followed by a Conformer encoder with rotary embeddings.

    Returns both the conformer output and the raw backbone output.
    """

    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        # Keep the sequence dimension — the conformer consumes the full sequence.
        self.backbone.avg_output = False
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        # Rotary embeddings are applied to half of each attention head.
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        # Keep the backbone output so callers can reuse it.
        backbone_out = self.backbone(x)
        # Reshape to (B, T, C) for the conformer; clone to avoid aliasing.
        if len(backbone_out.shape) == 2:
            x_enc = backbone_out.unsqueeze(1).clone()
        else:
            x_enc = backbone_out.permute(0, 2, 1).clone()
        RoPE = self.pe(x_enc, x_enc.shape[1])
        x_enc = self.encoder(x_enc, RoPE)
        if len(x_enc.shape) == 3:
            if self.avg_output:
                # NOTE(review): sums over time despite the "avg" name —
                # confirm this is intended rather than mean(dim=1).
                x_enc = x_enc.sum(dim=1)
            else:
                x_enc = x_enc.permute(0, 2, 1)
        return x_enc, backbone_out


class DualEncoder(nn.Module):
    """Two-branch regressor: a CNN branch and a CNN+Conformer branch.

    Both branches see the same input; their pooled features are concatenated
    and mapped to a single scalar per sample.
    """

    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        total_output_dim = args_x.encoder_dims[-1] + args_f.encoder_dims[-1]
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim // 2),
            nn.BatchNorm1d(total_output_dim // 2),
            nn.SiLU(),
            nn.Linear(total_output_dim // 2, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        logits = torch.cat([x1, x2], dim=-1)
        # BUGFIX: squeeze(-1) instead of squeeze() so a batch of one is not
        # collapsed to a 0-d tensor; identical for batch sizes > 1.
        return self.regressor(logits).squeeze(-1)


class CNNFeaturesEncoder(nn.Module):
    """CNN features concatenated with XGBoost leaf indices, fed to an MLP head.

    Args:
        xgb_model: a trained XGBoost booster; ``best_iteration`` determines how
            many leaf columns are kept.
        args: CNNEncoder configuration.
        mlp_hidden (int): hidden width of the MLP head.
    """

    def __init__(self, xgb_model, args, mlp_hidden=64):
        super().__init__()
        self.xgb_model = xgb_model
        self.best_xgb_features = xgb_model.best_iteration + 1
        self.backbone = CNNEncoder(args)
        self.total_features = self.best_xgb_features + args.encoder_dims[-1]
        self.mlp = nn.Sequential(
            nn.Linear(self.total_features, mlp_hidden),
            nn.BatchNorm1d(mlp_hidden),
            nn.SiLU(),
            nn.Linear(mlp_hidden, mlp_hidden),
            nn.BatchNorm1d(mlp_hidden),
            nn.SiLU(),
            nn.Linear(mlp_hidden, 1),
        )

    def _create_features_data(self, features):
        """Flatten a batch of per-sample feature dicts into a DataFrame.

        Assumes each value is indexable with a scalar at position 0 (e.g. a
        1-element tensor) — TODO confirm against the caller.
        """
        rows = []
        for sample in features:
            rows.append({
                f"frequency_domain_{k}": v[0].item() for k, v in sample.items()
            })
        return pd.DataFrame(rows)

    def forward(self, x: torch.Tensor, f) -> torch.Tensor:
        x = self.backbone(x)
        x = x.mean(dim=-1)  # pool CNN features over time
        f_np = self._create_features_data(f)
        dtest = xgb.DMatrix(f_np)  # convert tabular features to DMatrix
        # pred_leaf=True yields one leaf index per tree; keep the first
        # best_xgb_features columns.
        xgb_features = self.xgb_model.predict(dtest, pred_leaf=True).astype(np.float32)
        xgb_features = torch.tensor(xgb_features, dtype=torch.float32, device=x.device)
        x_f = torch.cat([x, xgb_features[:, :self.best_xgb_features]], dim=1)
        return self.mlp(x_f)


class CNNKan(nn.Module):
    """CNN backbone with a FasterKAN head.

    Note: ``conformer_args`` is accepted but unused; kept for interface
    compatibility with sibling models.
    """

    def __init__(self, args, conformer_args, kan_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.kan = FasterKAN(**kan_args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.backbone(x)
        # NOTE(review): averages dim=1; with avg_output the backbone already
        # returns (B, C) — confirm the intended axis against the config.
        x = x.mean(dim=1)
        return self.kan(x)


class CNNKanFeaturesEncoder(nn.Module):
    """CNN features + XGBoost leaf indices, fed to a FasterKAN head.

    Warning: mutates ``kan_args['layers_hidden'][0]`` in place to widen the
    KAN input by the number of XGBoost leaf columns — the caller's dict is
    modified.
    """

    def __init__(self, xgb_model, args, kan_args):
        super().__init__()
        self.xgb_model = xgb_model
        self.best_xgb_features = xgb_model.best_iteration + 1
        self.backbone = CNNEncoder(args)
        kan_args['layers_hidden'][0] += self.best_xgb_features
        self.kan = FasterKAN(**kan_args)

    def _create_features_data(self, features):
        """Flatten a batch of per-sample feature dicts into a DataFrame."""
        rows = []
        for sample in features:
            rows.append({f"{k}": v[0].item() for k, v in sample.items()})
        return pd.DataFrame(rows)

    def forward(self, x: torch.Tensor, f) -> torch.Tensor:
        x = self.backbone(x)
        x = x.mean(dim=1)
        f_np = self._create_features_data(f)
        dtest = xgb.DMatrix(f_np)  # convert tabular features to DMatrix
        xgb_features = self.xgb_model.predict(dtest, pred_leaf=True).astype(np.float32)
        xgb_features = torch.tensor(xgb_features, dtype=torch.float32, device=x.device)
        x_f = torch.cat([x, xgb_features[:, :self.best_xgb_features]], dim=1)
        return self.kan(x_f)


class KanEncoder(nn.Module):
    """Two parallel FasterKAN branches fused by a third KAN.

    ``args`` is a dict of FasterKAN kwargs; the fusion KAN takes the
    concatenated branch outputs down to a single scalar.
    """

    def __init__(self, args):
        super().__init__()
        self.kan_x = FasterKAN(**args)
        self.kan_f = FasterKAN(**args)
        self.kan_out = FasterKAN(
            layers_hidden=[args['layers_hidden'][-1] * 2, 8, 8, 1]
        )

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        x = self.kan_x(x)
        f = self.kan_f(f)
        out = torch.cat([x, f], dim=-1)
        return self.kan_out(out)


class MultiGraph(nn.Module):
    """CNN encoder with a linear projection head.

    The graph branch is currently disabled (commented out below); only the
    CNN features reach the projection.
    """

    def __init__(self, graph_net, args):
        super().__init__()
        self.graph_net = graph_net  # kept for interface; unused in forward
        self.cnn = CNNEncoder(args)
        total_output_dim = args.encoder_dims[-1]
        self.projection = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim // 2),
            nn.BatchNorm1d(total_output_dim // 2),
            nn.SiLU(),
            nn.Linear(total_output_dim // 2, 1),
        )

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # g_out = self.graph_net(g)
        x_out = self.cnn(x)
        # g_out = g_out.expand(x.shape[0], -1)
        # features = torch.cat([g_out, x_out], dim=-1)
        return self.projection(x_out)


class ImplicitEncoder(nn.Module):
    """Encodes an input together with the weights of a transform network.

    ``transform_net`` both transforms the input and donates its parameters
    (reshaped) as an additional input to ``encoder_net``.
    """

    def __init__(self, transform_net, encoder_net):
        super().__init__()
        self.transform_net = transform_net
        self.encoder_net = encoder_net

    def get_weights_and_biases(self):
        """Return transform_net's (weights, biases) reshaped for encoder_net.

        Weights become (1, in, out, 1); biases become (1, out, 1).
        """
        state_dict = self.transform_net.state_dict()
        weights = tuple(
            v.permute(1, 0).unsqueeze(-1).unsqueeze(0)
            for name, v in state_dict.items() if "weight" in name
        )
        biases = tuple(
            v.unsqueeze(-1).unsqueeze(0)
            for name, v in state_dict.items() if "bias" in name
        )
        return weights, biases

    # Backward-compatible alias for the original (misspelled) method name.
    get_weights_and_bises = get_weights_and_biases

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        transformed_x = self.transform_net(x.permute(0, 2, 1)).permute(0, 2, 1)
        inputs = self.get_weights_and_biases()
        outputs = self.encoder_net(inputs, transformed_x)
        return outputs