import torch
import torch.nn as nn
from .Modules.conformer import ConformerEncoder, ConformerDecoder
from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
from .kan.fasterkan import FasterKAN
from kan import KAN


class ConvBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> activation block used inside ``CNNEncoder``.

    Channel widths come from ``args.encoder_dims``. While
    ``num_layer < len(args.encoder_dims)``, the block maps
    ``encoder_dims[num_layer - 1]`` channels to ``encoder_dims[num_layer]``.
    Otherwise both the input and the output width are ``encoder_dims[-1]``,
    so extra layers keep the final width.

    Args:
        args: config object; reads ``activation``, ``encoder_dims`` and
            ``kernel_size``.
        num_layer: 1-based index of this block within the encoder.
    """

    def __init__(self, args, num_layer) -> None:
        super().__init__()
        # Any activation name other than 'silu' falls back to ReLU.
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        in_range = num_layer < len(args.encoder_dims)
        in_channels = args.encoder_dims[num_layer - 1] if in_range else args.encoder_dims[-1]
        out_channels = args.encoder_dims[num_layer] if in_range else args.encoder_dims[-1]
        self.layers = nn.Sequential(
            # stride=1 with padding='same' keeps the sequence length unchanged.
            # bias=False: the following BatchNorm1d (affine by default)
            # already provides a learnable shift.
            nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=args.kernel_size, stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=out_channels),
            self.activation,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv, batch norm and activation; the sequence length is preserved."""
        return self.layers(x)


class CNNEncoder(nn.Module):
    """1-D convolutional encoder: a kernel-3 stem conv plus ``args.num_layers`` ConvBlocks.

    After every block the sequence is max-pooled by a factor of 2, as long as
    its length is still greater than ``min_seq_len`` (2). When
    ``args.avg_output`` is true, the time axis is averaged away at the end.

    Attributes:
        output_dim: ``args.encoder_dims[-1]``. This matches the true output
            width only when ``num_layers >= len(encoder_dims) - 1``
            (see ConvBlock) — NOTE(review): confirm configs satisfy this.
        avg_output: whether ``forward`` averages over the time axis. It may
            be overridden after construction (MultiEncoder does this).
    """

    def __init__(self, args) -> None:
        super().__init__()
        print("Using CNN encoder wit activation: ", args.activation, 'args avg_output: ', args.avg_output)
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        self.embedding = nn.Sequential(
            nn.Conv1d(in_channels=args.in_channels, kernel_size=3, out_channels=args.encoder_dims[0],
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(args.encoder_dims[0]),
            self.activation,
        )
        # Block i+1 maps encoder_dims[i] -> encoder_dims[i+1] (clamped at the end).
        self.layers = nn.ModuleList([ConvBlock(args, i + 1) for i in range(args.num_layers)])
        self.pool = nn.MaxPool1d(2)
        self.output_dim = args.encoder_dims[-1]
        self.min_seq_len = 2
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of sequences.

        A 2-D input gets a singleton channel axis inserted at dim 1. A 3-D
        input whose last axis has size 1 is permuted to put that axis at
        dim 1. Any other 3-D input is passed straight to Conv1d, which treats
        it as ``(N, C, L)``.

        Returns:
            ``(N, C)`` when ``avg_output`` is true, otherwise ``(N, C, L')``
            where ``L'`` is the pooled length.
        """
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        if len(x.shape) == 3 and x.shape[-1] == 1:
            x = x.permute(0, 2, 1)
        x = self.embedding(x)
        for m in self.layers:
            x = m(x)
            # Stop downsampling once the sequence is short, so very deep
            # stacks do not pool the time axis down to nothing.
            if x.shape[-1] > self.min_seq_len:
                x = self.pool(x)
        if self.avg_output:
            x = x.mean(dim=-1)
        return x


class MultiEncoder(nn.Module):
    """CNN backbone followed by a Conformer encoder with rotary position embeddings.

    The backbone is forced to keep its time axis (``backbone.avg_output =
    False``), so the Conformer receives a sequence. This module's own
    ``avg_output`` (taken from ``args``) decides whether the Conformer output
    is reduced over that axis.

    Attributes:
        output_dim: ``conformer_args.encoder_dim``, the advertised width of
            the Conformer features returned by ``forward``.
    """

    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.backbone.avg_output = False
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        # RotaryEmbedding is built with half the per-head dimension
        # (presumably the number of rotary dims per head — confirm in mhsa_pro).
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        """Run the CNN backbone, then the Conformer.

        Returns:
            Tuple ``(x_enc, backbone_out)``:
            - ``x_enc`` is the Conformer output. When it is 3-D, it is summed
              over dim 1 if ``avg_output`` is true, and otherwise permuted
              back with ``permute(0, 2, 1)``.
            - ``backbone_out`` is the untouched CNN output.
        """
        backbone_out = self.backbone(x)
        # Clone so the Conformer input does not share storage with the
        # backbone output that is returned alongside it.
        if len(backbone_out.shape) == 2:
            x_enc = backbone_out.unsqueeze(1).clone()
        else:
            # (N, C, L) -> (N, L, C); presumably the layout ConformerEncoder
            # expects — confirm against Modules.conformer.
            x_enc = backbone_out.permute(0, 2, 1).clone()
        RoPE = self.pe(x_enc, x_enc.shape[1])
        x_enc = self.encoder(x_enc, RoPE)
        if len(x_enc.shape) == 3:
            if self.avg_output:
                # NOTE(review): despite the flag name this is a sum, not a
                # mean (CNNEncoder uses mean). Left as-is so downstream layers
                # trained on this scale are unaffected.
                x_enc = x_enc.sum(dim=1)
            else:
                x_enc = x_enc.permute(0, 2, 1)
        return x_enc, backbone_out


class DualEncoder(nn.Module):
    """Two parallel encoders over the same input, fused by an MLP regressor.

    ``encoder_x`` is a plain CNNEncoder; ``encoder_f`` is a MultiEncoder
    (CNN + Conformer). Their outputs are concatenated on the last axis and
    mapped to one value per sample.
    """

    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        # Size the regressor from the width each encoder actually reports.
        # MultiEncoder emits Conformer features (conformer_args.encoder_dim),
        # not args_f.encoder_dims[-1]; summing encoder_dims only worked when
        # those two happened to be equal.
        total_output_dim = self.encoder_x.output_dim + self.encoder_f.output_dim
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim // 2),
            nn.BatchNorm1d(total_output_dim // 2),
            nn.SiLU(),
            nn.Linear(total_output_dim // 2, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one regression output per sample, shape ``(B,)``.

        Assumes both encoders run with ``avg_output`` enabled, so each
        produces a 2-D ``(B, features)`` tensor that can be concatenated on
        the last axis.
        """
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        logits = torch.cat([x1, x2], dim=-1)
        # squeeze(-1) removes only the regressor's size-1 output axis; a bare
        # squeeze() would also drop the batch axis when B == 1.
        return self.regressor(logits).squeeze(-1)


class CNNKan(nn.Module):
    """CNN backbone whose output is averaged over dim 1 and fed to a FasterKAN head.

    Args:
        args: CNNEncoder config.
        conformer_args: unused; kept so the constructor signature stays
            compatible with existing callers.
        kan_args: keyword arguments forwarded to ``FasterKAN``.
    """

    def __init__(self, args, conformer_args, kan_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.kan = FasterKAN(**kan_args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` with the CNN, reduce dim 1 by mean, and apply the KAN head."""
        x = self.backbone(x)
        # NOTE(review): with avg_output=False the backbone returns (N, C, L),
        # so this averages over channels and leaves (N, L). With
        # avg_output=True it returns (N, C) and this yields (N,). Confirm
        # which is intended and that the KAN input width (presumably
        # kan_args['layers_hidden'][0]) matches.
        x = x.mean(dim=1)
        return self.kan(x)


class KanEncoder(nn.Module):
    """Two FasterKAN branches (for ``x`` and ``f``) fused by a small FasterKAN head.

    Both branches are built from the same ``args``. Their outputs
    (presumably ``args['layers_hidden'][-1]`` wide each) are concatenated and
    mapped through hidden widths 8, 8 to a single output.
    """

    def __init__(self, args):
        super().__init__()
        self.kan_x = FasterKAN(**args)
        self.kan_f = FasterKAN(**args)
        self.kan_out = FasterKAN(layers_hidden=[args['layers_hidden'][-1] * 2, 8, 8, 1])

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        """Run each branch on its input, concatenate on the last axis, and apply the head."""
        x = self.kan_x(x)
        f = self.kan_f(f)
        out = torch.cat([x, f], dim=-1)
        return self.kan_out(out)