# Author: IlayMalinyak
# Commit message: tested locally
# Commit: a79c5f2
import torch
import torch.nn as nn
from .Modules.conformer import ConformerEncoder, ConformerDecoder
from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
from .kan.fasterkan import FasterKAN
import numpy as np
import xgboost as xgb
import pandas as pd
class Sine(nn.Module):
    """Element-wise sine activation ``sin(w0 * x)`` with a fixed frequency factor ``w0``."""

    def __init__(self, w0=1.0):
        super().__init__()
        # Frequency multiplier applied to the input before taking the sine.
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = x * self.w0
        return scaled.sin()
class MLPEncoder(nn.Module):
    """Fully connected encoder built from Linear -> BatchNorm1d -> SiLU [-> Dropout] stages."""

    def __init__(self, args):
        """
        Build the MLP from a config namespace.

        Args:
            args: Object providing the attributes
                input_dim (int): Dimension of the input features.
                hidden_dims (list of int): Width of each hidden stage, in order.
                    The last entry becomes the encoder's output width.
                dropout (float): Dropout probability; Dropout layers are only
                    added when it is greater than 0.0.
        """
        super(MLPEncoder, self).__init__()
        layers = []
        prev_dim = args.input_dim
        for hidden_dim in args.hidden_dims:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.SiLU())
            if args.dropout > 0.0:
                layers.append(nn.Dropout(args.dropout))
            prev_dim = hidden_dim
        self.model = nn.Sequential(*layers)
        # Width of the last stage actually built. Reading ``prev_dim`` instead of the
        # loop variable keeps this defined when ``hidden_dims`` is empty; the model is
        # then an identity map and its output width is ``input_dim``.
        self.output_dim = prev_dim

    def forward(self, x):
        """Apply the MLP stack; assumes ``x`` is (batch, input_dim) — TODO confirm for other ranks."""
        return self.model(x)
class ConvBlock(nn.Module):
    """One Conv1d -> BatchNorm1d -> activation stage used by CNNEncoder.

    Channel widths come from ``args.encoder_dims``: block ``num_layer`` maps
    ``encoder_dims[num_layer - 1]`` to ``encoder_dims[num_layer]``. Once
    ``num_layer`` runs past the end of the list, both sides use the last width.
    The convolution has stride 1, ``padding='same'`` and no bias.

    Args (attributes of ``args``):
        activation: 'silu', 'sine' (reads ``sine_w0``); anything else -> ReLU.
        encoder_dims: list of channel widths.
        kernel_size: convolution kernel size.
    """

    def __init__(self, args, num_layer) -> None:
        super().__init__()
        # Lazily build the activation; 'sine' is the only one needing extra config.
        activation_factories = {
            'silu': nn.SiLU,
            'sine': lambda: Sine(w0=args.sine_w0),
        }
        self.activation = activation_factories.get(args.activation, nn.ReLU)()

        dims = args.encoder_dims
        if num_layer < len(dims):
            in_channels, out_channels = dims[num_layer - 1], dims[num_layer]
        else:
            in_channels = out_channels = dims[-1]

        conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=args.kernel_size,
            stride=1,
            padding='same',
            bias=False,
        )
        norm = nn.BatchNorm1d(num_features=out_channels)
        self.layers = nn.Sequential(conv, norm, self.activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)
class CNNEncoder(nn.Module):
    """1-D convolutional encoder: a stem convolution followed by ``args.num_layers`` ConvBlocks.

    Input handling: a 2-D tensor gets a singleton channel axis inserted at dim 1.
    A 3-D tensor whose last axis has size 1 is permuted so that axis moves to
    dim 1 (presumably channels-last, single channel — confirm with callers).
    After every ConvBlock the sequence is max-pooled by 2 while its length
    exceeds ``min_seq_len``. With ``avg_output`` the result is averaged over
    the last (sequence) axis.

    Args (attributes of ``args``):
        activation: 'silu', 'sine' (reads ``sine_w0``); anything else -> ReLU.
        in_channels: channel count of the (reshaped) input.
        encoder_dims: channel widths; ``encoder_dims[0]`` is the stem width.
        num_layers: number of ConvBlocks.
        kernel_size: ConvBlock kernel size (read by ConvBlock).
        avg_output: whether ``forward`` averages over the sequence axis.
    """

    def __init__(self, args) -> None:
        super().__init__()
        print("Using CNN encoder wit activation: ", args.activation, 'args avg_output: ', args.avg_output)
        if args.activation == 'silu':
            self.activation = nn.SiLU()
        elif args.activation == 'sine':
            self.activation = Sine(w0=args.sine_w0)
        else:
            self.activation = nn.ReLU()
        self.embedding = nn.Sequential(
            nn.Conv1d(in_channels=args.in_channels, out_channels=args.encoder_dims[0],
                      kernel_size=3, stride=1, padding='same', bias=False),
            nn.BatchNorm1d(args.encoder_dims[0]),
            self.activation,
        )
        # ConvBlock n (1-based) maps encoder_dims[n-1] -> encoder_dims[n], clamping to
        # the last width once n runs past the list.
        self.layers = nn.ModuleList([ConvBlock(args, i + 1) for i in range(args.num_layers)])
        self.pool = nn.MaxPool1d(2)
        # Channel width actually produced by the last stage. The stem emits
        # encoder_dims[0] and ConvBlock n emits encoder_dims[min(n, len - 1)], so when
        # there are fewer layers than widths this is NOT encoder_dims[-1].
        last_stage = min(args.num_layers, len(args.encoder_dims) - 1)
        self.output_dim = args.encoder_dims[last_stage]
        self.min_seq_len = 2
        self.avg_output = args.avg_output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        if len(x.shape) == 3 and x.shape[-1] == 1:
            x = x.permute(0, 2, 1)
        x = self.embedding(x)
        for m in self.layers:
            x = m(x)
            # Downsample, but never below min_seq_len.
            if x.shape[-1] > self.min_seq_len:
                x = self.pool(x)
        if self.avg_output:
            x = x.mean(dim=-1)
        return x
class MultiEncoder(nn.Module):
    """CNN backbone followed by a Conformer encoder driven by rotary position embeddings.

    ``forward`` returns ``(encoded, backbone_out)``: the Conformer output and the
    raw CNN feature map it was computed from.
    """

    def __init__(self, args, conformer_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        # Keep the backbone's sequence axis so it can be fed to the Conformer as tokens.
        self.backbone.avg_output = False
        self.head_size = conformer_args.encoder_dim // conformer_args.num_heads
        # The rotary embedding is sized to half of the per-head dimension.
        self.rotary_ndims = int(self.head_size * 0.5)
        self.pe = RotaryEmbedding(self.rotary_ndims)
        self.encoder = ConformerEncoder(conformer_args)
        self.output_dim = conformer_args.encoder_dim
        self.avg_output = args.avg_output

    def forward(self, x):
        backbone_out = self.backbone(x)

        # Turn the backbone output into a token sequence without aliasing it:
        # a 2-D output becomes a single token; a 3-D (batch, channels, length)
        # map becomes (batch, length, channels).
        if len(backbone_out.shape) == 2:
            tokens = backbone_out.unsqueeze(1)
        else:
            tokens = backbone_out.permute(0, 2, 1)
        tokens = tokens.clone()

        rope = self.pe(tokens, tokens.shape[1])
        encoded = self.encoder(tokens, rope)

        if len(encoded.shape) == 3:
            # NOTE(review): despite the flag name this sums (not averages) over dim 1.
            encoded = encoded.sum(dim=1) if self.avg_output else encoded.permute(0, 2, 1)
        return encoded, backbone_out
class DualEncoder(nn.Module):
    """Scalar regressor over two encodings of the same input.

    ``encoder_x`` is a plain CNNEncoder and ``encoder_f`` a MultiEncoder (CNN +
    Conformer). Their outputs are concatenated on the last axis and passed
    through a small MLP head.
    """

    def __init__(self, args_x, args_f, conformer_args) -> None:
        super().__init__()
        self.encoder_x = CNNEncoder(args_x)
        self.encoder_f = MultiEncoder(args_f, conformer_args)
        # Size the head from each encoder's declared output width. MultiEncoder emits
        # Conformer features (conformer_args.encoder_dim wide), which need not equal
        # args_f.encoder_dims[-1].
        total_output_dim = self.encoder_x.output_dim + self.encoder_f.output_dim
        hidden_dim = total_output_dim // 2
        self.regressor = nn.Sequential(
            nn.Linear(total_output_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one prediction per sample.

        Assumes both encoders produce 2-D (batch, features) outputs, i.e. their
        ``avg_output`` is enabled — TODO confirm against the training config.
        """
        x1 = self.encoder_x(x)
        x2, _ = self.encoder_f(x)
        features = torch.cat([x1, x2], dim=-1)
        # squeeze(-1) drops only the head's singleton output axis; a bare squeeze()
        # would also collapse a batch of one into a 0-d tensor.
        return self.regressor(features).squeeze(-1)
class CNNFeaturesEncoder(nn.Module):
    """Regress from CNN features concatenated with XGBoost leaf indices.

    The backbone output is averaged over its last axis. The hand-crafted
    features ``f`` are scored by ``xgb_model`` with ``pred_leaf=True``, and the
    first ``best_iteration + 1`` leaf-index columns are appended before a
    three-layer MLP head.
    """

    def __init__(self, xgb_model, args, mlp_hidden=64):
        super().__init__()
        self.xgb_model = xgb_model
        # Number of leaf-index columns kept from the XGBoost prediction.
        self.best_xgb_features = xgb_model.best_iteration + 1
        self.backbone = CNNEncoder(args)
        self.total_features = self.best_xgb_features + args.encoder_dims[-1]
        width = mlp_hidden
        self.mlp = nn.Sequential(
            nn.Linear(self.total_features, width), nn.BatchNorm1d(width), nn.SiLU(),
            nn.Linear(width, width), nn.BatchNorm1d(width), nn.SiLU(),
            nn.Linear(width, 1),
        )

    def _create_features_data(self, features):
        """Flatten a batch of per-sample feature dicts into a DataFrame for XGBoost.

        Each ``features[i]`` maps a feature name to a value whose first element is
        read with ``.item()``; columns are named ``frequency_domain_<name>``.
        """
        rows = [
            {f"frequency_domain_{name}": value[0].item() for name, value in sample.items()}
            for sample in features
        ]
        return pd.DataFrame(rows)

    def forward(self, x: torch.Tensor, f) -> torch.Tensor:
        cnn_features = self.backbone(x).mean(dim=-1)
        frame = self._create_features_data(f)
        leaves = self.xgb_model.predict(xgb.DMatrix(frame), pred_leaf=True).astype(np.float32)
        leaves = torch.tensor(leaves, dtype=torch.float32, device=cnn_features.device)
        fused = torch.cat([cnn_features, leaves[:, :self.best_xgb_features]], dim=1)
        return self.mlp(fused)
class CNNKan(nn.Module):
    """CNN backbone whose output is averaged over dim 1 and fed to a FasterKAN head.

    ``conformer_args`` is accepted for signature compatibility but is not used here.
    """

    def __init__(self, args, conformer_args, kan_args):
        super().__init__()
        self.backbone = CNNEncoder(args)
        self.kan = FasterKAN(**kan_args)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): this reduces over dim 1 (the channel axis of a 3-D backbone
        # output), whereas CNNFeaturesEncoder reduces over dim -1 — confirm intent.
        pooled = self.backbone(x).mean(dim=1)
        return self.kan(pooled)
class CNNKanFeaturesEncoder(nn.Module):
    """FasterKAN head over CNN features concatenated with XGBoost leaf indices.

    The first ``best_iteration + 1`` leaf-index columns predicted by ``xgb_model``
    (``pred_leaf=True``) are appended to the pooled backbone features. The KAN's
    first layer is widened by that many inputs.
    """

    def __init__(self, xgb_model, args, kan_args):
        super().__init__()
        self.xgb_model = xgb_model
        # Number of leaf-index columns kept from the XGBoost prediction.
        self.best_xgb_features = xgb_model.best_iteration + 1
        self.backbone = CNNEncoder(args)
        # Widen the first KAN layer by the appended leaf-index columns. Work on copies:
        # an in-place ``+=`` on kan_args['layers_hidden'] mutates the caller's config,
        # so reusing the same dict (e.g. building a second model) keeps growing the width.
        layers_hidden = list(kan_args['layers_hidden'])
        layers_hidden[0] += self.best_xgb_features
        self.kan = FasterKAN(**{**kan_args, 'layers_hidden': layers_hidden})

    def _create_features_data(self, features):
        """Flatten a batch of per-sample feature dicts into a DataFrame for XGBoost.

        Each ``features[i]`` maps a feature name to a value whose first element is
        read with ``.item()``; columns are named after the (stringified) keys.
        """
        rows = [
            {f"{name}": value[0].item() for name, value in sample.items()}
            for sample in features
        ]
        return pd.DataFrame(rows)

    def forward(self, x: torch.Tensor, f) -> torch.Tensor:
        # NOTE(review): reduces over dim 1, unlike CNNFeaturesEncoder (dim -1) — confirm.
        x = self.backbone(x).mean(dim=1)
        frame = self._create_features_data(f)
        dtest = xgb.DMatrix(frame)
        xgb_features = self.xgb_model.predict(dtest, pred_leaf=True).astype(np.float32)
        xgb_features = torch.tensor(xgb_features, dtype=torch.float32, device=x.device)
        x_f = torch.cat([x, xgb_features[:, :self.best_xgb_features]], dim=1)
        return self.kan(x_f)
class KanEncoder(nn.Module):
    """Two parallel FasterKAN branches whose outputs are fused by a third FasterKAN.

    ``args`` is the keyword dict for both branch KANs. The fusion KAN takes the
    concatenated branch outputs (twice ``args['layers_hidden'][-1]`` wide) through
    hidden widths 8, 8 down to a single output.
    """

    def __init__(self, args):
        super().__init__()
        self.kan_x = FasterKAN(**args)
        self.kan_f = FasterKAN(**args)
        fused_width = args['layers_hidden'][-1] * 2
        self.kan_out = FasterKAN(layers_hidden=[fused_width, 8, 8, 1])

    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
        fused = torch.cat([self.kan_x(x), self.kan_f(f)], dim=-1)
        return self.kan_out(fused)
class MultiGraph(nn.Module):
    """CNN encoder plus projection head; holds a graph network that is currently unused.

    ``graph_net`` is stored (and registered as a submodule if it is an nn.Module),
    but ``forward`` ignores the graph input: the prediction comes from the CNN
    branch alone.
    """

    def __init__(self, graph_net, args):
        super().__init__()
        self.graph_net = graph_net
        self.cnn = CNNEncoder(args)
        width = args.encoder_dims[-1]
        half = width // 2
        self.projection = nn.Sequential(
            nn.Linear(width, half),
            nn.BatchNorm1d(half),
            nn.SiLU(),
            nn.Linear(half, 1),
        )

    def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # NOTE(review): the graph branch is disabled — `g` and `self.graph_net` are not
        # used here. Fusing it would also require widening the projection's input.
        return self.projection(self.cnn(x))
class ImplicitEncoder(nn.Module):
    """Encode transformed inputs together with the parameters of the transforming network.

    ``transform_net`` is applied to ``x`` (with axes 1 and 2 swapped around the call).
    Its weights and biases, read from its ``state_dict``, are passed to
    ``encoder_net`` alongside the transformed input.
    """

    def __init__(self, transform_net, encoder_net):
        super().__init__()
        self.transform_net = transform_net
        self.encoder_net = encoder_net

    def get_weights_and_bises(self):
        """Return ``(weights, biases)`` tuples gathered from ``transform_net.state_dict()``.

        Entries are matched by substring ("weight" / "bias" in the key) and keep
        state_dict order. Each weight is transposed with ``permute(1, 0)``, so every
        matched weight must be 2-D (e.g. nn.Linear). It is then reshaped to
        ``(1, in, out, 1)``; each bias becomes ``(1, out, 1)``.

        NOTE(review): state_dict() returns detached tensors, so no gradient reaches
        ``transform_net`` through these parameters — confirm this is intended.
        """
        weights, biases = [], []
        for key, tensor in self.transform_net.state_dict().items():
            if "weight" in key:
                weights.append(tensor.permute(1, 0).unsqueeze(-1).unsqueeze(0))
            if "bias" in key:
                biases.append(tensor.unsqueeze(-1).unsqueeze(0))
        return tuple(weights), tuple(biases)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Move axis 2 to the last position for transform_net, then move it back.
        swapped = x.permute(0, 2, 1)
        transformed = self.transform_net(swapped).permute(0, 2, 1)
        params = self.get_weights_and_bises()
        return self.encoder_net(params, transformed)