File size: 5,683 Bytes
49ebc1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import *
from torch.autograd import Function
from .feature_extractor import EnhancedFeatureExtractor
from .fasterkan_layers import FasterKANLayer

class FasterKAN(nn.Module):
    """Sequential stack of ``FasterKANLayer`` modules.

    One layer is created for every consecutive pair of sizes in
    ``layers_hidden`` (``layers_hidden[i] -> layers_hidden[i + 1]``). All other
    constructor arguments are handed unchanged to each ``FasterKANLayer``;
    their meaning is defined by that class.
    """

    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -1.2,
        grid_max: float = 1.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,        
        train_inv_denominator: bool = False,
        #use_base_update: bool = True,
        base_activation = None,
        spline_weight_init_scale: float = 1.0,
    ) -> None:
        super().__init__()

        # Settings that are identical for every layer in the stack.
        shared_kwargs = dict(
            grid_min=grid_min,
            grid_max=grid_max,
            num_grids=num_grids,
            exponent=exponent,
            inv_denominator=inv_denominator,
            train_grid=train_grid,
            train_inv_denominator=train_inv_denominator,
            base_activation=base_activation,
            spline_weight_init_scale=spline_weight_init_scale,
        )

        # Pair each size with its successor to get (input, output) widths.
        stack = []
        for fan_in, fan_out in zip(layers_hidden[:-1], layers_hidden[1:]):
            stack.append(FasterKANLayer(fan_in, fan_out, **shared_kwargs))
        self.layers = nn.ModuleList(stack)

    def forward(self, x):
        """Feed ``x`` through every layer in order and return the final output."""
        out = x
        for kan_layer in self.layers:
            out = kan_layer(out)
        return out

class FasterKANvolver(nn.Module):
    """Convolutional feature extractor followed by a stack of FasterKAN layers.

    The input is reshaped to ``view``, passed through an
    ``EnhancedFeatureExtractor``, flattened per sample, and then fed through
    one ``FasterKANLayer`` for every consecutive pair of sizes in
    ``[flat_features] + layers_hidden``.

    Args:
        layers_hidden: Output sizes of the FasterKAN layers. The input size of
            the first layer is prepended internally, so it must not be
            included here.
        grid_min: Forwarded to every ``FasterKANLayer``.
        grid_max: Forwarded to every ``FasterKANLayer``.
        num_grids: Forwarded to every ``FasterKANLayer``.
        exponent: Forwarded to every ``FasterKANLayer``.
        inv_denominator: Forwarded to every ``FasterKANLayer``.
        train_grid: Forwarded to every ``FasterKANLayer``.
        train_inv_denominator: Forwarded to every ``FasterKANLayer``.
        base_activation: Forwarded to every ``FasterKANLayer``.
        spline_weight_init_scale: Forwarded to every ``FasterKANLayer``.
        view: Four-element shape the input is reshaped to before the feature
            extractor. ``view[1]`` is also passed to the extractor as
            ``colors``.
    """

    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -1.2,
        # NOTE(review): FasterKAN defaults grid_max to 1.2; 0.2 here may be
        # intentional or a typo - kept unchanged, confirm with the author.
        grid_max: float = 0.2,
        num_grids: int = 8,
        exponent: int = 2,
        inv_denominator: float = 0.5,
        train_grid: bool = False,
        train_inv_denominator: bool = False,
        #use_base_update: bool = True,
        base_activation = None,
        spline_weight_init_scale: float = 1.0,
        view: Sequence[int] = (-1, 1, 28, 28),
    ) -> None:
        super().__init__()

        # Keep a private copy so instances never share (or mutate) one list.
        self.view = list(view)

        # Convolutional front end; view[1] is handed over as `colors`.
        self.feature_extractor = EnhancedFeatureExtractor(colors=view[1])

        # Per-sample feature count fed to the first FasterKAN layer.
        # NOTE(review): hard-coded; presumably equals the flattened output size
        # of EnhancedFeatureExtractor - confirm before changing `view`.
        flat_features = 256

        # Prepend the extractor's output width so the first layer matches it.
        layers_hidden = [flat_features] + layers_hidden

        # The caller's inv_denominator / train_grid / train_inv_denominator are
        # forwarded as given (they used to be replaced by fixed literals).
        self.faster_kan_layers = nn.ModuleList([
            FasterKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                exponent=exponent,
                inv_denominator=inv_denominator,
                train_grid=train_grid,
                train_inv_denominator=train_inv_denominator,
                #use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``x`` to ``self.view``, extract features, and apply the layers.

        Args:
            x: Tensor whose element count is compatible with ``self.view``
                (with the default view, e.g. ``[batch, 784]`` becomes
                ``[batch, 1, 28, 28]``).

        Returns:
            The output of the last FasterKAN layer.
        """
        # reshape() instead of view(): same result for contiguous tensors, but
        # it also accepts non-contiguous inputs by copying when required.
        x = x.reshape(self.view[0], self.view[1], self.view[2], self.view[3])

        x = self.feature_extractor(x)

        # Flatten everything except the batch dimension.
        x = x.reshape(x.size(0), -1)

        for layer in self.faster_kan_layers:
            x = layer(x)

        return x