import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import List
from torch.autograd import Function
from .feature_extractor import EnhancedFeatureExtractor
from .fasterkan_layers import FasterKANLayer
class FasterKAN(nn.Module):
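    """A stack of FasterKANLayer modules.

    `layers_hidden` lists the width of every layer; for example, [784, 64, 10]
    builds two FasterKANLayer blocks mapping 784 -> 64 -> 10. The remaining
    arguments are forwarded unchanged to every FasterKANLayer.
    """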
def __init__(
self,
layers_hidden: List[int],
grid_min: float = -1.2,
grid_max: float = 1.2,
num_grids: int = 8,
exponent: int = 2,
inv_denominator: float = 0.5,
train_grid: bool = False,
train_inv_denominator: bool = False,
#use_base_update: bool = True,
base_activation = None,
spline_weight_init_scale: float = 1.0,
) -> None:
super().__init__()
self.layers = nn.ModuleList([
FasterKANLayer(
in_dim, out_dim,
grid_min=grid_min,
grid_max=grid_max,
num_grids=num_grids,
                exponent=exponent,
                inv_denominator=inv_denominator,
                train_grid=train_grid,
                train_inv_denominator=train_inv_denominator,
#use_base_update=use_base_update,
base_activation=base_activation,
spline_weight_init_scale=spline_weight_init_scale,
) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
])
    def forward(self, x):
        # Apply each FasterKANLayer in sequence.
        for layer in self.layers:
            x = layer(x)
        return x
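
# Minimal usage sketch for FasterKAN (the layer widths below are illustrative,
# not values prescribed by this module):
#
#   model = FasterKAN([784, 64, 10])
#   out = model(torch.randn(8, 784))  # -> shape [8, 10]
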
class FasterKANvolver(nn.Module):
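    """A convolutional front end followed by a stack of FasterKANLayer modules.

    Flat inputs are reshaped according to `view` (e.g. [-1, 1, 28, 28] for
    MNIST), passed through EnhancedFeatureExtractor, flattened, and fed to the
    FasterKAN layers. The flattened feature size is prepended to
    `layers_hidden` automatically, so `layers_hidden` only lists the hidden
    and output widths of the KAN head.
    """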
def __init__(
self,
layers_hidden: List[int],
grid_min: float = -1.2,
grid_max: float = 0.2,
num_grids: int = 8,
exponent: int = 2,
inv_denominator: float = 0.5,
train_grid: bool = False,
train_inv_denominator: bool = False,
#use_base_update: bool = True,
base_activation = None,
spline_weight_init_scale: float = 1.0,
view = [-1, 1, 28, 28],
) -> None:
        super().__init__()
self.view = view
# Feature extractor with Convolutional layers
        self.feature_extractor = EnhancedFeatureExtractor(colors=view[1])
        # A plain convolutional extractor, kept commented out for reference:
        # nn.Sequential(
        #     nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),  # 1 input channel (grayscale), 16 output channels
        #     nn.ReLU(),
        #     nn.MaxPool2d(2, 2),
        #     nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        #     nn.ReLU(),
        #     nn.MaxPool2d(2, 2),
        # )
        # Flattened feature size produced by the convolutional feature extractor.
        flat_features = 256
# Update layers_hidden with the correct input size from conv layers
layers_hidden = [flat_features] + layers_hidden
# Define the FasterKAN layers
self.faster_kan_layers = nn.ModuleList([
FasterKANLayer(
in_dim, out_dim,
grid_min=grid_min,
grid_max=grid_max,
num_grids=num_grids,
exponent=exponent,
                inv_denominator=inv_denominator,
                train_grid=train_grid,
                train_inv_denominator=train_inv_denominator,
#use_base_update=use_base_update,
base_activation=base_activation,
spline_weight_init_scale=spline_weight_init_scale,
) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
])
def forward(self, x):
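        """Reshape `x` to `self.view`, extract convolutional features, flatten, and apply the FasterKAN layers."""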
        # Reshape the flat input into an image tensor according to self.view
        # (e.g. [batch_size, 1, 28, 28] for MNIST).
        x = x.view(self.view[0], self.view[1], self.view[2], self.view[3])
        # Extract convolutional features.
        x = self.feature_extractor(x)
        # Flatten the feature maps for the KAN head.
        x = x.view(x.size(0), -1)
        # Pass through the FasterKAN layers.
        for layer in self.faster_kan_layers:
            x = layer(x)
        return x
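
# Minimal usage sketch for FasterKANvolver (a sketch assuming, as hard-coded
# above, that EnhancedFeatureExtractor flattens to 256 features; the widths
# below are illustrative):
#
#   model = FasterKANvolver([64, 10], view=[-1, 1, 28, 28])
#   logits = model(torch.randn(8, 784))  # -> shape [8, 10]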