submission-template / tasks /utils /kan /fasterkan_layers.py
IlayMalinyak
kan
49ebc1f
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import *
from torch.autograd import Function
from .fasterkan_basis import ReflectionalSwitchFunction, SplineLinear
class FasterKANLayer(nn.Module):
    """One FasterKAN layer: LayerNorm, basis expansion, then a linear projection.

    The input is normalised over its last dimension, passed through a
    ``ReflectionalSwitchFunction``, flattened to ``(x.shape[0], -1)`` and fed
    to a ``SplineLinear`` built with ``(input_dim * num_grids, output_dim)``.

    Args:
        input_dim: Normalised shape handed to ``nn.LayerNorm``.
        output_dim: Second constructor argument of ``SplineLinear``.
        grid_min: Forwarded to ``ReflectionalSwitchFunction``.
        grid_max: Forwarded to ``ReflectionalSwitchFunction``.
        num_grids: Forwarded to ``ReflectionalSwitchFunction``; also scales
            the first ``SplineLinear`` argument.
        exponent: Forwarded to ``ReflectionalSwitchFunction``.
        inv_denominator: Forwarded to ``ReflectionalSwitchFunction``.
        train_grid: Forwarded to ``ReflectionalSwitchFunction``.
        train_inv_denominator: Forwarded to ``ReflectionalSwitchFunction``.
        base_activation: Accepted for interface compatibility only; this
            layer does not use it (there is no base branch).
        spline_weight_init_scale: Third constructor argument of
            ``SplineLinear``.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        base_activation = F.silu,
        spline_weight_init_scale: float = 0.667,
    ) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = ReflectionalSwitchFunction(
            grid_min,
            grid_max,
            num_grids,
            exponent,
            inv_denominator,
            train_grid,
            train_inv_denominator,
        )
        expanded_width = input_dim * num_grids
        self.spline_linear = SplineLinear(
            expanded_width, output_dim, spline_weight_init_scale
        )

    def forward(self, x):
        """Normalise ``x``, expand it on the basis and project the result."""
        normalized = self.layernorm(x)
        leading = normalized.shape[0]
        # Collapse everything after the first dimension before projecting.
        flat_basis = self.rbf(normalized).view(leading, -1)
        return self.spline_linear(flat_basis)
class FasterKAN(nn.Module):
    """Stack of ``FasterKANLayer`` modules applied one after another.

    Consecutive entries of ``layers_hidden`` give the ``(input_dim,
    output_dim)`` of each layer, so ``len(layers_hidden) - 1`` layers are
    created. Every remaining argument is passed unchanged, by keyword, to
    each ``FasterKANLayer``.
    """

    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        base_activation = None,
        spline_weight_init_scale: float = 1.0,
    ) -> None:
        super().__init__()
        # Settings shared by every layer in the stack.
        layer_kwargs = dict(
            grid_min=grid_min,
            grid_max=grid_max,
            num_grids=num_grids,
            exponent=exponent,
            inv_denominator=inv_denominator,
            train_grid=train_grid,
            train_inv_denominator=train_inv_denominator,
            base_activation=base_activation,
            spline_weight_init_scale=spline_weight_init_scale,
        )
        width_pairs = zip(layers_hidden[:-1], layers_hidden[1:])
        self.layers = nn.ModuleList(
            [FasterKANLayer(n_in, n_out, **layer_kwargs) for n_in, n_out in width_pairs]
        )

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the result."""
        out = x
        for kan_layer in self.layers:
            out = kan_layer(out)
        return out
class BasicResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a shortcut.

    The shortcut is the identity unless ``stride != 1`` or the channel count
    changes, in which case it is a bias-free 1x1 convolution (with the same
    stride) followed by batch norm.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution and of the shortcut projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        shape_preserved = stride == 1 and in_channels == out_channels
        if shape_preserved:
            # An empty Sequential simply returns its input.
            shortcut_layers = []
        else:
            shortcut_layers = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        self.downsample = nn.Sequential(*shortcut_layers)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        shortcut = self.downsample(x)
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = F.relu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        residual += shortcut
        return F.relu(residual)
class SEBlock(nn.Module):
    """Squeeze-and-excitation gate that rescales each channel of a 4-D input.

    Each channel is averaged to a single value, passed through a bias-free
    two-layer bottleneck (``channel -> channel // reduction -> channel``)
    ending in a sigmoid, and the resulting per-channel factor multiplies the
    input.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Divisor used for the bottleneck width.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        bottleneck = channel // reduction
        self.fc = nn.Sequential(
            nn.Linear(channel, bottleneck, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate."""
        batch, channels, _rows, _cols = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gate.expand_as(x)
class DepthwiseSeparableConv(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: Input channels; also the group count of the depthwise
            convolution, so each channel is filtered on its own.
        out_channels: Channels produced by the pointwise convolution.
        kernel_size: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution.
        padding: Padding of the depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        depthwise_cfg = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
        )
        self.depthwise = nn.Conv2d(in_channels, in_channels, **depthwise_cfg)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the depthwise filter, then mix channels with the 1x1 conv."""
        return self.pointwise(self.depthwise(x))
class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a 4-D feature map.

    Queries and keys are 1x1 convolutions down to ``in_channels // 8``
    channels; values keep ``in_channels``. The attended values are scaled by
    the learnable scalar ``gamma`` (initialised to zero) and added to the
    input, so a freshly constructed module returns its input unchanged.

    Args:
        in_channels: Number of channels of the input feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        self.query_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels, reduced, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        batch_size, channels, dim_a, dim_b = x.size()
        positions = dim_a * dim_b

        # (batch, positions, reduced) and (batch, reduced, positions)
        queries = self.query_conv(x).view(batch_size, -1, positions).transpose(1, 2)
        keys = self.key_conv(x).view(batch_size, -1, positions)

        # Pairwise position scores, normalised over the key axis.
        scores = torch.bmm(queries, keys)
        weights = F.softmax(scores, dim=-1)

        values = self.value_conv(x).view(batch_size, -1, positions)
        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(batch_size, channels, dim_a, dim_b)

        return self.gamma * attended + x
class EnhancedFeatureExtractor(nn.Module):
    """Convolutional front end mapping 3-channel images to 256 features.

    The pipeline has three stages, each ending in 2x2 max pooling and
    dropout (p=0.25), followed by self-attention and global average pooling.
    The output is flattened to ``(batch, 256)``.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: plain convolution stem, 3 -> 32 channels.
        stem = [
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),
        ]
        # Stage 2: residual block with channel gating, 32 -> 64 channels.
        middle = [
            BasicResBlock(32, 64),
            SEBlock(64, reduction=16),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),
        ]
        # Stage 3: separable conv widening to 128, residual block to 256.
        deep = [
            DepthwiseSeparableConv(64, 128, kernel_size=3),
            nn.ReLU(),
            BasicResBlock(128, 256),
            SEBlock(256, reduction=16),
            nn.MaxPool2d(2, 2),
            nn.Dropout(0.25),
        ]
        attention = [SelfAttention(256)]

        # Order matters: Sequential indices become part of the state_dict keys.
        self.initial_layers = nn.Sequential(*stem, *middle, *deep, *attention)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """Return one flattened feature vector per input sample."""
        features = self.initial_layers(x)
        pooled = self.global_avg_pool(features)
        # Drop the trailing 1x1 spatial dimensions left by global pooling.
        return pooled.view(pooled.size(0), -1)
class FasterKANvolver(nn.Module):
    """Convolutional feature extractor followed by a stack of FasterKAN layers.

    An ``EnhancedFeatureExtractor`` turns each image into a 256-element
    vector; that width is prepended to ``layers_hidden`` and one
    ``FasterKANLayer`` is built for every consecutive pair of widths.

    Args:
        layers_hidden: Output widths of the FasterKAN layers, in order. The
            extractor's 256-feature output is prepended automatically, so the
            first entry is the output size of the first KAN layer.
        grid_min: Forwarded to every ``FasterKANLayer``.
        grid_max: Forwarded to every ``FasterKANLayer``.
        num_grids: Forwarded to every ``FasterKANLayer``.
        exponent: Forwarded to every ``FasterKANLayer``.
        inv_denominator: Forwarded to every ``FasterKANLayer``.
        train_grid: Forwarded to every ``FasterKANLayer``.
        train_inv_denominator: Forwarded to every ``FasterKANLayer``.
        base_activation: Forwarded to every ``FasterKANLayer``.
        spline_weight_init_scale: Forwarded to every ``FasterKANLayer``.

    Note:
        ``inv_denominator``, ``train_grid`` and ``train_inv_denominator`` used
        to be accepted but silently replaced by ``0.5`` / ``False`` /
        ``False``. They are now honoured; since those are also the defaults,
        callers relying on the defaults see no change.
    """

    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -1.2,
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        base_activation = None,
        spline_weight_init_scale: float = 1.0,
    ) -> None:
        super(FasterKANvolver, self).__init__()

        # Feature extractor with convolutional layers.
        self.feature_extractor = EnhancedFeatureExtractor()

        # EnhancedFeatureExtractor ends in global average pooling over 256
        # channels, so each sample is flattened to 256 features.
        flat_features = 256

        # Prepend the extractor width; unpacking accepts any sequence
        # (list or tuple) and leaves the caller's object untouched.
        layers_hidden = [flat_features, *layers_hidden]

        # Define the FasterKAN layers.
        self.faster_kan_layers = nn.ModuleList([
            FasterKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                exponent=exponent,
                inv_denominator=inv_denominator,
                train_grid=train_grid,
                train_inv_denominator=train_inv_denominator,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x):
        """Run the extractor and then every FasterKAN layer on ``x``.

        ``x`` is viewed as ``(-1, 3, 32, 32)``, so its element count must be
        a multiple of ``3 * 32 * 32``; otherwise ``view`` raises.
        """
        # Interpret the input as a batch of 3-channel 32x32 images.
        x = x.view(-1, 3, 32, 32)

        # Apply convolutional layers.
        x = self.feature_extractor(x)

        # The extractor already returns a 2-D tensor; this keeps the KAN
        # input 2-D even if the extractor is swapped for one that does not.
        x = x.view(x.size(0), -1)

        # Pass through FasterKAN layers.
        for layer in self.faster_kan_layers:
            x = layer(x)

        return x