# Source: IlayMalinyak/kan @ 49ebc1f (raw file view, 5.26 kB).
# Scraped page-header lines converted to a comment so the module parses.
import torch
import torch.nn as nn
from .Modules.conformer import ConformerEncoder, ConformerDecoder
from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
from .kan.fasterkan import FasterKAN
from kan import KAN
class ConvBlock(nn.Module):
    """One 1-D conv unit: Conv1d -> BatchNorm1d -> activation (length-preserving).

    Channel widths come from ``args.encoder_dims``: block ``num_layer`` maps
    ``encoder_dims[num_layer-1] -> encoder_dims[num_layer]``; once ``num_layer``
    runs past the list, the last width is reused for both input and output.
    """

    def __init__(self, args, num_layer) -> None:
        super().__init__()
        act = nn.SiLU() if args.activation == 'silu' else nn.ReLU()
        self.activation = act
        dims = args.encoder_dims
        if num_layer < len(dims):
            c_in, c_out = dims[num_layer - 1], dims[num_layer]
        else:
            c_in = c_out = dims[-1]
        self.layers = nn.Sequential(
            nn.Conv1d(in_channels=c_in,
                      out_channels=c_out,
                      kernel_size=args.kernel_size,
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(num_features=c_out),
            act,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv/norm/activation to a (batch, channels, length) tensor."""
        return self.layers(x)
class CNNEncoder(nn.Module):
    """CNN feature extractor: embedding conv followed by a stack of ConvBlocks.

    Input is normalized to channels-first (batch, channels, length); after each
    ConvBlock the sequence is halved by max-pooling until it reaches
    ``min_seq_len``. With ``avg_output`` set, features are mean-pooled over the
    length axis, yielding (batch, encoder_dims[-1]).
    """

    def __init__(self, args) -> None:
        super().__init__()
        # Fixed typo: "wit" -> "with" in the startup log line.
        print("Using CNN encoder with activation: ", args.activation,
              'args avg_output: ', args.avg_output)
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        else:
            self.activation = nn.ReLU()
        # Embedding conv lifts the raw input channels to the first encoder width.
        self.embedding = nn.Sequential(
            nn.Conv1d(in_channels=args.in_channels,
                      kernel_size=3,
                      out_channels=args.encoder_dims[0],
                      stride=1, padding='same', bias=False),
            nn.BatchNorm1d(args.encoder_dims[0]),
            self.activation,
        )
        self.layers = nn.ModuleList(
            [ConvBlock(args, i + 1) for i in range(args.num_layers)])
        self.pool = nn.MaxPool1d(2)
        self.output_dim = args.encoder_dims[-1]
        # Stop halving the sequence once it is this short.
        self.min_seq_len = 2
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x``; returns (batch, C) if ``avg_output`` else (batch, C, L)."""
        if len(x.shape) == 2:
            # (batch, length) -> (batch, 1, length)
            x = x.unsqueeze(1)
        if len(x.shape) == 3 and x.shape[-1] == 1:
            # (batch, length, 1) -> (batch, 1, length)
            x = x.permute(0, 2, 1)
        x = self.embedding(x)
        for m in self.layers:
            x = m(x)
            if x.shape[-1] > self.min_seq_len:
                x = self.pool(x)
        if self.avg_output:
            x = x.mean(dim=-1)
        return x
class MultiEncoder(nn.Module):
    """CNN backbone followed by a conformer encoder with rotary position encoding.

    Returns both the conformer-encoded features and the raw backbone features.
    """

    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        # Keep the full (batch, channels, length) map from the backbone.
        self.backbone.avg_output = False
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        # Rotary embedding is applied to half of each attention head.
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        """Return ``(encoded, backbone_features)`` for input ``x``."""
        features = self.backbone(x)
        # Conformer expects (batch, seq, dim); backbone emits (batch, dim, seq).
        if len(features.shape) == 2:
            tokens = features.unsqueeze(1).clone()
        else:
            tokens = features.permute(0, 2, 1).clone()
        rope = self.pe(tokens, tokens.shape[1])
        encoded = self.encoder(tokens, rope)
        if len(encoded.shape) == 3:
            if self.avg_output:
                # NOTE(review): flag is named avg_output but this *sums* over
                # the sequence axis — confirm intent with the training code.
                encoded = encoded.sum(dim=1)
            else:
                encoded = encoded.permute(0, 2, 1)
        return encoded, features
class DualEncoder(nn.Module):
    """Fuses a plain CNN encoder and a conformer-based encoder for regression.

    Both encoders consume the same input; their pooled features are
    concatenated and fed to a small MLP head producing one value per sample.
    Assumes both encoders are configured with ``avg_output`` so each returns
    a 2-D (batch, features) tensor — TODO confirm against the config.
    """

    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        # Bug fix: use the encoders' actual output widths. The original summed
        # args_x.encoder_dims[-1] + args_f.encoder_dims[-1], but encoder_f
        # outputs conformer_args.encoder_dim features — the two differ unless
        # the configs happen to coincide, causing a shape mismatch in the head.
        total_output_dim = self.encoder_x.output_dim + self.encoder_f.output_dim
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, total_output_dim // 2),
            nn.BatchNorm1d(total_output_dim // 2),
            nn.SiLU(),
            nn.Linear(total_output_dim // 2, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (batch,) tensor of per-sample regression outputs."""
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        logits = torch.cat([x1, x2], dim=-1)
        # squeeze(-1) rather than squeeze(): a batch of size one must keep
        # its batch dimension instead of collapsing to a 0-d tensor.
        return self.regressor(logits).squeeze(-1)
class CNNKan(nn.Module):
    """CNN backbone whose pooled features feed a FasterKAN head."""

    def __init__(self, args, conformer_args, kan_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.kan = FasterKAN(**kan_args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` with the CNN, pool over dim 1, apply the KAN head."""
        features = self.backbone(x)
        # NOTE(review): mean over dim=1 — if the backbone already pooled
        # (avg_output=True) this collapses the feature axis instead of the
        # sequence axis; confirm the intended backbone configuration.
        pooled = features.mean(dim=1)
        return self.kan(pooled)
class KanEncoder(nn.Module):
    """Two parallel FasterKAN towers joined by a small FasterKAN output head."""

    def __init__(self, args):
        super().__init__()
        self.kan_x = FasterKAN(**args)
        self.kan_f = FasterKAN(**args)
        # Head consumes the concatenation of both tower outputs.
        head_in = args['layers_hidden'][-1] * 2
        self.kan_out = FasterKAN(layers_hidden=[head_in, 8, 8, 1])

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and ``f`` separately, then fuse through the head."""
        xe = self.kan_x(x)
        fe = self.kan_f(f)
        joint = torch.cat([xe, fe], dim=-1)
        return self.kan_out(joint)