IlayMalinyak committed
Commit 49ebc1f · 1 Parent(s): 707b3a3
model/0.0_cache_data ADDED
Binary file (840 Bytes).
 
model/0.0_config.yml ADDED
The diff for this file is too large to render.
 
model/history.txt ADDED
@@ -0,0 +1,2 @@
+### Round 0 ###
+init => 0.0
tasks/audio.py CHANGED
@@ -10,7 +10,7 @@ from torch.utils.data import DataLoader
 from .utils.evaluation import AudioEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 from .utils.data import FFTDataset
-from .utils.models import DualEncoder
+from .utils.models import DualEncoder, CNNKan
 from .utils.train import Trainer
 from .utils.data_utils import collate_fn, Container
 import yaml
@@ -70,13 +70,14 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     model_args = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder'])
     model_args_f = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder_f'])
     conformer_args = Container(**yaml.safe_load(open(args_path, 'r'))['Conformer'])
+    kan_args = Container(**yaml.safe_load(open(args_path, 'r'))['KAN'])
 
     test_dataset = FFTDataset(test_dataset)
     test_dl = DataLoader(test_dataset, batch_size=data_args.batch_size, collate_fn=collate_fn)
 
-    model = DualEncoder(model_args, model_args_f, conformer_args)
+    model = CNNKan(model_args, conformer_args, kan_args.get_dict())
     model = model.to(device)
-    state_dict = torch.load(model_args.checkpoint_path)
+    state_dict = torch.load(data_args.checkpoint_path)
     new_state_dict = OrderedDict()
     for key, value in state_dict.items():
         if key.startswith('module.'):
@@ -95,8 +96,12 @@ async def evaluate_audio(request: AudioEvaluationRequest):
                       accumulation_step=1, max_iter=np.inf,
                       exp_name=f"frugal_cnnencoder_inference")
     predictions, true_labels, acc = trainer.predict(test_dl, device=device)
+    # true_labels = test_dataset["label"]
+
     # Make random predictions (placeholder for actual model inference)
     print("accuracy: ", acc)
+    print("predictions: ", len(predictions))
+    print("true_labels: ", len(true_labels))
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
@@ -128,15 +133,15 @@ async def evaluate_audio(request: AudioEvaluationRequest):
 
     return results
 
-# if __name__ == "__main__":
+if __name__ == "__main__":
     # with open("../logs//token.txt", "r") as f:
     #     api_key = f.read()
    
 # login(api_key)
     # # Create a sample request object
-    # sample_request = AudioEvaluationRequest(
-    #     dataset_name="rfcx/frugalai",  # Replace with actual dataset name
-    #     test_size=0.2,  # Example values
-    #     test_seed=42
-    # )
+    sample_request = AudioEvaluationRequest(
+        dataset_name="rfcx/frugalai",  # Replace with actual dataset name
+        test_size=0.2,  # Example values
+        test_seed=42
+    )
     #
-    # asyncio.run(evaluate_audio(sample_request))
+    asyncio.run(evaluate_audio(sample_request))
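
Both tasks/audio.py and tasks/run.py configure the model by unpacking YAML sections into Container objects and, for the KAN block, converting back to a dict with get_dict(). The Container helper lives in tasks/utils/data_utils.py and is not shown in this commit; a minimal sketch of the behavior the call sites assume (keyword arguments exposed as attributes, plus a get_dict() round-trip) might look like this:

    import yaml

    class Container:  # hypothetical stand-in for tasks.utils.data_utils.Container
        def __init__(self, **kwargs):
            self.__dict__.update(kwargs)

        def get_dict(self):
            return self.__dict__

    kan_args = Container(**yaml.safe_load(open('tasks/utils/config.yaml'))['KAN'])
    print(kan_args.layers_hidden)  # attribute access, e.g. [1125, 32, 8, 8, 1]
    print(kan_args.get_dict())     # plain dict again, ready for FasterKAN(**...)
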
tasks/models/frugal_2025-01-21/CNNEncoder_frugal_2.json ADDED
The diff for this file is too large to render.
 
tasks/models/frugal_2025-01-21/frugal_kan_2.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28e0188edab4879996cc960d2dc79641460b270af9c5ac7d3eacad1f5e96da39
+size 1714830
tasks/run.py ADDED
@@ -0,0 +1,95 @@
+from torch.utils.data import DataLoader
+from .utils.data import FFTDataset, SplitDataset
+from datasets import load_dataset
+from .utils.train import Trainer
+from .utils.models import CNNKan, KanEncoder
+from .utils.data_utils import *
+from huggingface_hub import login
+import yaml
+import datetime
+import json
+import numpy as np
+
+# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+current_date = datetime.date.today().strftime("%Y-%m-%d")
+datetime_dir = f"frugal_{current_date}"
+args_dir = 'tasks/utils/config.yaml'
+data_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Data'])
+exp_num = data_args.exp_num
+model_name = data_args.model_name
+model_args = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder'])
+model_args_f = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder_f'])
+conformer_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Conformer'])
+kan_args = Container(**yaml.safe_load(open(args_dir, 'r'))['KAN'])
+if not os.path.exists(f"{data_args.log_dir}/{datetime_dir}"):
+    os.makedirs(f"{data_args.log_dir}/{datetime_dir}")
+
+with open("../logs//token.txt", "r") as f:
+    api_key = f.read()
+
+# local_rank, world_size, gpus_per_node = setup()
+local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+login(api_key)
+dataset = load_dataset("rfcx/frugalai", streaming=True)
+
+train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
+
+train_dl = DataLoader(train_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)
+
+val_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=False)
+
+val_dl = DataLoader(val_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)
+
+test_ds = FFTDataset(dataset["test"])
+test_dl = DataLoader(test_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)
+
+# for i, batch in enumerate(train_dl):
+#     x, x_f, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
+#     print(x.shape, x_f.shape, y.shape)
+#     if i > 10:
+#         break
+# exit()
+
+# model = DualEncoder(model_args, model_args_f, conformer_args)
+# model = FasterKAN([18000,64,64,16,1])
+model = CNNKan(model_args, conformer_args, kan_args.get_dict())
+# model.kan.speed()
+# model = KanEncoder(kan_args.get_dict())
+model = model.to(local_rank)
+# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+print(f"Number of parameters: {num_params}")
+
+loss_fn = torch.nn.BCEWithLogitsLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+total_steps = int(data_args.num_epochs) * 1000
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
+                                                       T_max=total_steps,
+                                                       eta_min=float((5e-4)/10))
+
+# missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
+# print(f"Missing keys: {missing}")
+# print(f"Unexpected keys: {unexpected}")
+
+trainer = Trainer(model=model, optimizer=optimizer,
+                  criterion=loss_fn, output_dim=model_args.output_dim, scaler=None,
+                  scheduler=None, train_dataloader=train_dl,
+                  val_dataloader=val_dl, device=local_rank,
+                  exp_num=datetime_dir, log_path=data_args.log_dir,
+                  range_update=None,
+                  accumulation_step=1, max_iter=np.inf,
+                  exp_name=f"frugal_kan_{exp_num}")
+fit_res = trainer.fit(num_epochs=100, device=local_rank,
+                      early_stopping=10, only_p=False, best='loss', conf=True)
+output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
+with open(output_filename, "w") as f:
+    json.dump(fit_res, f, indent=2)
+preds, acc = trainer.predict(test_dl, local_rank)
+print(f"Accuracy: {acc}")
+
+
+
+
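
Note that run.py references os, torch, Container, and collate_fn without importing them directly, so it presumably relies on the wildcard from .utils.data_utils import * re-exporting them; if that module does not, the explicit imports sketched below belong at the top of the file. Also, with num_epochs: 10 in the config, total_steps = 10 * 1000 = 10000 and eta_min = 5e-5, but the resulting schedule is never exercised, since the Trainer is constructed with scheduler=None.

    # Explicit imports, assuming data_utils does not already re-export these:
    import os
    import torch
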
tasks/utils/config.yaml CHANGED
@@ -1,34 +1,41 @@
 Data:
   # Basics
-  log_dir: '/data/frugal/logs'
+  log_dir: 'tasks/models'
   # Data
-  dataset: "KeplerDataset"
-  data_dir: '/data/lightPred/data'
+  dataset: "FFTDataset"
+  data_dir: None
   model_name: "CNNEncoder"
-  batch_size: 16
-  num_epochs: 1000
+  batch_size: 32
+  num_epochs: 10
   exp_num: 2
   max_len_spectra: 4096
   max_days_lc: 270
   lc_freq: 0.0208
   create_umap: True
+  checkpoint_path: 'tasks/models/frugal_2025-01-21/frugal_kan_2.pth'
 
 CNNEncoder:
   # Model
-  in_channels: 1
+  in_channels: 2
   num_layers: 4
   stride: 1
-  encoder_dims: [32,64,128,256]
+  encoder_dims: [32,64,128]
   kernel_size: 3
   dropout_p: 0.3
   output_dim: 2
   beta: 1
-  load_checkpoint: True
+  load_checkpoint: False
   checkpoint_num: 1
   activation: "silu"
   sine_w0: 1.0
-  avg_output: True
-  checkpoint_path: 'tasks/models/frugal_2025-01-10/frugal_cnnencoder_2.pth'
+  avg_output: False
+
+KAN:
+  layers_hidden: [1125,32,8,8,1]
+  grid_min: -1.2
+  grid_max: 1.2
+  num_grids: 8
+  exponent: 2
 
 CNNEncoder_f:
   # Model
@@ -50,7 +57,7 @@ CNNEncoder_f:
 Conformer:
   encoder: ["mhsa_pro", "conv"]
   timeshift: false
-  num_layers: 8
+  num_layers: 4
   encoder_dim: 128
   num_heads: 8
   kernel_size: 3
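
The new KAN input width of 1125 appears to follow from the updated data pipeline: clips are padded to max_len=72000 samples at 12 kHz and resampled to 3 kHz, leaving 18000 samples, and, assuming the CNN backbone halves the temporal length once per layer across its four layers, 18000 / 2^4 = 1125, which is exactly where layers_hidden starts. A quick check of that assumption:

    max_len, orig_sr, target_sr, num_layers = 72000, 12000, 3000, 4
    resampled_len = max_len * target_sr // orig_sr  # 18000 samples after resampling
    print(resampled_len // 2 ** num_layers)         # 1125, matching layers_hidden[0]
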
tasks/utils/data.py CHANGED
@@ -1,6 +1,7 @@
 import torch
 from torch.utils.data import IterableDataset
 from torch.fft import fft
+import torch.nn.functional as F
 from itertools import tee
 import random
 import torchaudio.transforms as T
@@ -24,20 +25,25 @@ class SplitDataset(IterableDataset):
 
 
 class FFTDataset(IterableDataset):
-    def __init__(self, original_dataset, orig_sample_rate=12000, target_sample_rate=6000):
+    def __init__(self, original_dataset, max_len=72000, orig_sample_rate=12000, target_sample_rate=3000):
         self.dataset = original_dataset
         self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
+        self.max_len = max_len
 
     def __iter__(self):
         for item in self.dataset:
             # Assuming your audio data is in item['audio']
             # Modify this based on your actual data structure
             audio_data = torch.tensor(item['audio']['array']).float()
-            if len(audio_data) == 0:
-                continue
-            resampled_audio = self.resampler(audio_data)
-            fft_data = fft(resampled_audio)
+            # pad audio
+            # if len(audio_data) == 0:
+            #     continue
+            pad_len = self.max_len - len(audio_data)
+            audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
+            audio_data = self.resampler(audio_data)
+            fft_data = fft(audio_data)
 
             # Update the item with FFT data
             item['audio']['fft'] = fft_data
+            item['audio']['array'] = audio_data
             yield item
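
One subtlety of the new padding logic: pad_len is negative whenever a clip is longer than max_len, and F.pad with a negative amount trims instead of padding, so over-length clips are silently truncated to 72000 samples. A small demonstration of both directions:

    import torch
    import torch.nn.functional as F

    x = torch.ones(5)
    print(F.pad(x, (0, 3)).shape)   # torch.Size([8]): right-padded with zeros
    print(F.pad(x, (0, -2)).shape)  # torch.Size([3]): negative pad trims the tail
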
tasks/utils/kan/__init__.py ADDED
@@ -0,0 +1 @@
+from .fasterkan import FasterKAN, FasterKANLayer, FasterKANvolver
tasks/utils/kan/fasterkan.py ADDED
@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import *
+from torch.autograd import Function
+from .feature_extractor import EnhancedFeatureExtractor
+from .fasterkan_layers import FasterKANLayer
+
+class FasterKAN(nn.Module):
+    def __init__(
+        self,
+        layers_hidden: List[int],
+        grid_min: float = -1.2,
+        grid_max: float = 1.2,
+        num_grids: int = 8,
+        exponent: int = 2,
+        inv_denominator: float = 0.5,
+        train_grid: bool = False,
+        train_inv_denominator: bool = False,
+        #use_base_update: bool = True,
+        base_activation = None,
+        spline_weight_init_scale: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.layers = nn.ModuleList([
+            FasterKANLayer(
+                in_dim, out_dim,
+                grid_min=grid_min,
+                grid_max=grid_max,
+                num_grids=num_grids,
+                exponent=exponent,
+                inv_denominator=inv_denominator,
+                train_grid=train_grid,
+                train_inv_denominator=train_inv_denominator,
+                #use_base_update=use_base_update,
+                base_activation=base_activation,
+                spline_weight_init_scale=spline_weight_init_scale,
+            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
+        ])
+        #print(f"FasterKAN layers_hidden[1:] shape: ", len(layers_hidden[1:]))
+        #print(f"FasterKAN layers_hidden[:-1] shape: ", len(layers_hidden[:-1]))
+        #print("FasterKAN zip shape: \n", *[(in_dim, out_dim) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])])
+
+        #print(f"FasterKAN self.faster_kan_layers shape: \n", len(self.layers))
+        #print(f"FasterKAN self.faster_kan_layers: \n", self.layers)
+
+    def forward(self, x):
+        for layer in self.layers:
+            #print("FasterKAN layer: \n", layer)
+            #print(f"FasterKAN x shape: {x.shape}")
+            x = layer(x)
+        return x
+
+class FasterKANvolver(nn.Module):
+    def __init__(
+        self,
+        layers_hidden: List[int],
+        grid_min: float = -1.2,
+        grid_max: float = 0.2,
+        num_grids: int = 8,
+        exponent: int = 2,
+        inv_denominator: float = 0.5,
+        train_grid: bool = False,
+        train_inv_denominator: bool = False,
+        #use_base_update: bool = True,
+        base_activation = None,
+        spline_weight_init_scale: float = 1.0,
+        view = [-1, 1, 28, 28],
+    ) -> None:
+        super(FasterKANvolver, self).__init__()
+
+        self.view = view
+        # Feature extractor with Convolutional layers
+        self.feature_extractor = EnhancedFeatureExtractor(colors = view[1])
+        """
+        nn.Sequential(
+            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),  # 1 input channel (grayscale), 16 output channels
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2),
+            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(),
+            nn.MaxPool2d(2, 2)
+        )
+        """
+        # Calculate the flattened feature size after convolutional layers
+        flat_features = 256  # XX channels, image size reduced to YxY
+
+        # Update layers_hidden with the correct input size from conv layers
+        layers_hidden = [flat_features] + layers_hidden
+        #print(f"FasterKANvolver layers_hidden shape: \n", layers_hidden)
+        #print(f"FasterKANvolver layers_hidden[1:] shape: ", len(layers_hidden[1:]))
+        #print(f"FasterKANvolver layers_hidden[:-1] shape: ", len(layers_hidden[:-1]))
+        #print("FasterKANvolver zip shape: \n", *[(in_dim, out_dim) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])])
+
+        # Define the FasterKAN layers
+        self.faster_kan_layers = nn.ModuleList([
+            FasterKANLayer(
+                in_dim, out_dim,
+                grid_min=grid_min,
+                grid_max=grid_max,
+                num_grids=num_grids,
+                exponent=exponent,
+                inv_denominator = 0.5,
+                train_grid = False,
+                train_inv_denominator = False,
+                #use_base_update=use_base_update,
+                base_activation=base_activation,
+                spline_weight_init_scale=spline_weight_init_scale,
+            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
+        ])
+        #print(f"FasterKANvolver self.faster_kan_layers shape: \n", len(self.faster_kan_layers))
+        #print(f"FasterKANvolver self.faster_kan_layers: \n", self.faster_kan_layers)
+
+    def forward(self, x):
+        # Reshape input from [batch_size, 784] to [batch_size, 1, 28, 28] for MNIST, [batch_size, 1, 32, 32] for C
+        #print(f"FasterKAN x view shape: {x.shape}")
+        # Handle different input shapes based on the length of view
+        x = x.view(self.view[0], self.view[1], self.view[2], self.view[3])
+        #print(f"FasterKAN x view shape: {x.shape}")
+        # Apply convolutional layers
+        #print(f"FasterKAN x view shape: {x.shape}")
+        x = self.feature_extractor(x)
+        #print(f"FasterKAN x after feature_extractor shape: {x.shape}")
+        x = x.view(x.size(0), -1)  # Flatten the output from the conv layers
+        #print(f"FasterKAN x shape: {x.shape}")
+
+        # Pass through FasterKAN layers
+        for layer in self.faster_kan_layers:
+            #print("FasterKAN layer: \n", layer)
+            #print(f"FasterKAN x shape: {x.shape}")
+            x = layer(x)
+            #print(f"FasterKAN x shape: {x.shape}")
+
+        return x
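
With the KAN section of the config above, FasterKAN is a stack of four FasterKANLayers mapping 1125 input features down to a single logit. A minimal shape check, assuming the repository root is on PYTHONPATH:

    import torch
    from tasks.utils.kan.fasterkan import FasterKAN

    model = FasterKAN(layers_hidden=[1125, 32, 8, 8, 1])
    x = torch.randn(4, 1125)  # a batch of 4 feature vectors
    print(model(x).shape)     # torch.Size([4, 1])
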
tasks/utils/kan/fasterkan_basis.py ADDED
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import *
+from torch.autograd import Function
+
+class RSWAFFunction(Function):
+    @staticmethod
+    def forward(ctx, input, grid, inv_denominator, train_grid, train_inv_denominator):
+        # Compute the forward pass
+        #print('\n')
+        #print(f"Forward pass - grid: {(grid[0].item(), grid[-1].item())}, inv_denominator: {inv_denominator.item()}")
+
+        #print(f"grid.shape: {grid.shape}")
+        #print(f"grid: {(grid[0], grid[-1])}")
+        #print(f"inv_denominator.shape: {inv_denominator.shape}")
+        #print(f"inv_denominator: {inv_denominator}")
+        diff = (input[..., None] - grid)
+        diff_mul = diff.mul(inv_denominator)
+        tanh_diff = torch.tanh(diff)
+        tanh_diff_deriviative = -tanh_diff.mul(tanh_diff) + 1  # sech^2(x) = 1 - tanh^2(x)
+
+        # Save tensors for backward pass
+        ctx.save_for_backward(input, tanh_diff, tanh_diff_deriviative, diff, inv_denominator)
+        ctx.train_grid = train_grid
+        ctx.train_inv_denominator = train_inv_denominator
+
+        return tanh_diff_deriviative
+
+    ##### SOS NOT SURE HOW grad_inv_denominator, grad_grid ARE CALCULATED CORRECTLY YET
+    ##### MUST CHECK https://github.com/pytorch/pytorch/issues/74802
+    ##### MUST CHECK https://www.changjiangcai.com/studynotes/2020-10-18-Custom-Function-Extending-PyTorch/
+    ##### MUST CHECK https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html
+    ##### MUST CHECK https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
+    ##### MUST CHECK https://gist.github.com/Hanrui-Wang/bf225dc0ccb91cdce160539c0acc853a
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Retrieve saved tensors
+        input, tanh_diff, tanh_diff_deriviative, diff, inv_denominator = ctx.saved_tensors
+        grad_grid = None
+        grad_inv_denominator = None
+
+        #print(f"tanh_diff_deriviative shape: {tanh_diff_deriviative.shape}")
+        #print(f"tanh_diff shape: {tanh_diff.shape}")
+        #print(f"grad_output shape: {grad_output.shape}")
+
+        # Compute the backward pass for the input
+        grad_input = -2 * tanh_diff * tanh_diff_deriviative * grad_output
+        #print(f"Backward pass 1 - grad_input: {(grad_input.min().item(), grad_input.max().item())}")
+        #print(f"grad_input shape: {grad_input.shape}")
+        #print(f"grad_input.sum(dim=-1): {grad_input.sum(dim=-1).shape}")
+        grad_input = grad_input.sum(dim=-1).mul(inv_denominator)
+        #print(f"Backward pass 2 - grad_input: {(grad_input.min().item(), grad_input.max().item())}")
+        #print(f"grad_input: {grad_input}")
+        #print(f"grad_input shape: {grad_input.shape}")
+
+        # Compute the backward pass for grid
+        if ctx.train_grid:
+            #print('\n')
+            #print(f"grad_grid shape: {grad_grid.shape}")
+            grad_grid = -inv_denominator * grad_output.sum(dim=0).sum(dim=0)  #-(inv_denominator * grad_output * tanh_diff_deriviative).sum(dim=0)
+            #print(f"Backward pass - grad_grid: {(grad_grid[0].item(), grad_grid[-1].item())}")
+            #print(f"grad_grid.shape: {grad_grid.shape}")
+            #print(f"grad_grid: {(grad_grid[0], grad_grid[-1])}")
+            #print(f"inv_denominator shape: {inv_denominator.shape}")
+            #print(f"grad_grid shape: {grad_grid.shape}")
+
+        # Compute the backward pass for inv_denominator
+        if ctx.train_inv_denominator:
+            grad_inv_denominator = (grad_output * diff).sum()  #(grad_output * diff * tanh_diff_deriviative).sum()
+            #print(f"Backward pass - grad_inv_denominator: {grad_inv_denominator.item()}")
+            #print(f"diff shape: {diff.shape}")
+
+            #print(f"grad_inv_denominator shape: {grad_inv_denominator.shape}")
+            #print(f"grad_inv_denominator : {grad_inv_denominator}")
+
+        return grad_input, grad_grid, grad_inv_denominator, None, None  # same number as tensors or parameters
+
+
+class ReflectionalSwitchFunction(nn.Module):
+    def __init__(
+        self,
+        grid_min: float = -1.2,
+        grid_max: float = 0.2,
+        num_grids: int = 8,
+        exponent: int = 2,
+        inv_denominator: float = 0.5,
+        train_grid: bool = False,
+        train_inv_denominator: bool = False,
+    ):
+        super().__init__()
+        grid = torch.linspace(grid_min, grid_max, num_grids)
+        self.train_grid = torch.tensor(train_grid, dtype=torch.bool)
+        self.train_inv_denominator = torch.tensor(train_inv_denominator, dtype=torch.bool)
+        self.grid = torch.nn.Parameter(grid, requires_grad=train_grid)
+        #print(f"grid initial shape: {self.grid.shape}")
+        self.inv_denominator = torch.nn.Parameter(torch.tensor(inv_denominator, dtype=torch.float32), requires_grad=train_inv_denominator)  # Cache the inverse of the denominator
+
+    def forward(self, x):
+        return RSWAFFunction.apply(x, self.grid, self.inv_denominator, self.train_grid, self.train_inv_denominator)
+
+
+class SplineLinear(nn.Linear):
+    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
+        self.init_scale = init_scale
+        super().__init__(in_features, out_features, bias=False, **kw)
+
+    def reset_parameters(self) -> None:
+        nn.init.xavier_uniform_(self.weight)  # Using Xavier Uniform initialization
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import *
6
+ from torch.autograd import Function
7
+ from .fasterkan_basis import ReflectionalSwitchFunction, SplineLinear
8
+
9
+ class FasterKANLayer(nn.Module):
10
+ def __init__(
11
+ self,
12
+ input_dim: int,
13
+ output_dim: int,
14
+ grid_min: float = -1.2,
15
+ grid_max: float = 0.2,
16
+ num_grids: int = 8,
17
+ exponent: int = 2,
18
+ inv_denominator: float = 0.5,
19
+ train_grid: bool = False,
20
+ train_inv_denominator: bool = False,
21
+ #use_base_update: bool = True,
22
+ base_activation = F.silu,
23
+ spline_weight_init_scale: float = 0.667,
24
+ ) -> None:
25
+ super().__init__()
26
+ self.layernorm = nn.LayerNorm(input_dim)
27
+ self.rbf = ReflectionalSwitchFunction(grid_min, grid_max, num_grids, exponent, inv_denominator, train_grid, train_inv_denominator)
28
+ self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
29
+ #self.use_base_update = use_base_update
30
+ #if use_base_update:
31
+ # self.base_activation = base_activation
32
+ # self.base_linear = nn.Linear(input_dim, output_dim)
33
+
34
+ def forward(self, x):
35
+ #print("Shape before LayerNorm:", x.shape) # Debugging line to check the input shape
36
+ x = self.layernorm(x)
37
+ #print("Shape After LayerNorm:", x.shape)
38
+ spline_basis = self.rbf(x).view(x.shape[0], -1)
39
+ #print("spline_basis:", spline_basis.shape)
40
+
41
+ #print("-------------------------")
42
+ #ret = 0
43
+ ret = self.spline_linear(spline_basis)
44
+ #print("spline_basis.shape[:-2]:", spline_basis.shape[:-2])
45
+ #print("*spline_basis.shape[:-2]:", *spline_basis.shape[:-2])
46
+ #print("spline_basis.view(*spline_basis.shape[:-2], -1):", spline_basis.view(*spline_basis.shape[:-2], -1).shape)
47
+ #print("ret:", ret.shape)
48
+ #print("-------------------------")
49
+ #if self.use_base_update:
50
+ #base = self.base_linear(self.base_activation(x))
51
+ #print("self.base_activation(x):", self.base_activation(x).shape)
52
+ #print("base:", base.shape)
53
+ #print("@@@@@@@@@")
54
+ #ret += base
55
+ return ret
56
+
57
+
58
+ #spline_basis = spline_basis.reshape(x.shape[0], -1) # Reshape to [batch_size, input_dim * num_grids]
59
+ #print("spline_basis:", spline_basis.shape)
60
+
61
+ #spline_weight = self.spline_weight.view(-1, self.spline_weight.shape[0]) # Reshape to [input_dim * num_grids, output_dim]
62
+ #print("spline_weight:", spline_weight.shape)
63
+
64
+ #spline = torch.matmul(spline_basis, spline_weight) # Resulting shape: [batch_size, output_dim]
65
+
66
+ #print("-------------------------")
67
+ #print("Base shape:", base.shape)
68
+ #print("Spline shape:", spline.shape)
69
+ #print("@@@@@@@@@")
70
+
71
+
72
+ class FasterKAN(nn.Module):
73
+ def __init__(
74
+ self,
75
+ layers_hidden: List[int],
76
+ grid_min: float = -1.2,
77
+ grid_max: float = 0.2,
78
+ num_grids: int = 8,
79
+ exponent: int = 2,
80
+ inv_denominator: float = 0.5,
81
+ train_grid: bool = False,
82
+ train_inv_denominator: bool = False,
83
+ #use_base_update: bool = True,
84
+ base_activation = None,
85
+ spline_weight_init_scale: float = 1.0,
86
+ ) -> None:
87
+ super().__init__()
88
+ self.layers = nn.ModuleList([
89
+ FasterKANLayer(
90
+ in_dim, out_dim,
91
+ grid_min=grid_min,
92
+ grid_max=grid_max,
93
+ num_grids=num_grids,
94
+ exponent = exponent,
95
+ inv_denominator = inv_denominator,
96
+ train_grid = train_grid ,
97
+ train_inv_denominator = train_inv_denominator,
98
+ #use_base_update=use_base_update,
99
+ base_activation=base_activation,
100
+ spline_weight_init_scale=spline_weight_init_scale,
101
+ ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
102
+ ])
103
+ #print(f"FasterKAN layers_hidden[1:] shape: ", len(layers_hidden[1:]))
104
+ #print(f"FasterKAN layers_hidden[:-1] shape: ", len(layers_hidden[:-1]))
105
+ #print("FasterKAN zip shape: \n", *[(in_dim, out_dim) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])])
106
+
107
+ #print(f"FasterKAN self.faster_kan_layers shape: \n", len(self.layers))
108
+ #print(f"FasterKAN self.faster_kan_layers: \n", self.layers)
109
+
110
+ def forward(self, x):
111
+ for layer in self.layers:
112
+ #print("FasterKAN layer: \n", layer)
113
+ #print(f"FasterKAN x shape: {x.shape}")
114
+ x = layer(x)
115
+ return x
116
+
117
+
118
+
119
+ class BasicResBlock(nn.Module):
120
+ def __init__(self, in_channels, out_channels, stride=1):
121
+ super(BasicResBlock, self).__init__()
122
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
123
+ self.bn1 = nn.BatchNorm2d(out_channels)
124
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
125
+ self.bn2 = nn.BatchNorm2d(out_channels)
126
+
127
+ self.downsample = nn.Sequential()
128
+ if stride != 1 or in_channels != out_channels:
129
+ self.downsample = nn.Sequential(
130
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
131
+ nn.BatchNorm2d(out_channels)
132
+ )
133
+
134
+ def forward(self, x):
135
+ identity = self.downsample(x)
136
+
137
+ out = F.relu(self.bn1(self.conv1(x)))
138
+ out = self.bn2(self.conv2(out))
139
+ out += identity
140
+ out = F.relu(out)
141
+
142
+ return out
143
+
144
+ class SEBlock(nn.Module):
145
+ def __init__(self, channel, reduction=16):
146
+ super(SEBlock, self).__init__()
147
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
148
+ self.fc = nn.Sequential(
149
+ nn.Linear(channel, channel // reduction, bias=False),
150
+ nn.ReLU(inplace=True),
151
+ nn.Linear(channel // reduction, channel, bias=False),
152
+ nn.Sigmoid()
153
+ )
154
+
155
+ def forward(self, x):
156
+ b, c, _, _ = x.size()
157
+ y = self.avg_pool(x).view(b, c)
158
+ y = self.fc(y).view(b, c, 1, 1)
159
+ return x * y.expand_as(x)
160
+
161
+
162
+ class DepthwiseSeparableConv(nn.Module):
163
+ def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
164
+ super(DepthwiseSeparableConv, self).__init__()
165
+ self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
166
+ stride=stride, padding=padding, groups=in_channels)
167
+ self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
168
+
169
+ def forward(self, x):
170
+ x = self.depthwise(x)
171
+ x = self.pointwise(x)
172
+ return x
173
+
174
+ class SelfAttention(nn.Module):
175
+ def __init__(self, in_channels):
176
+ super(SelfAttention, self).__init__()
177
+ self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
178
+ self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
179
+ self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
180
+ self.gamma = nn.Parameter(torch.zeros(1))
181
+
182
+ def forward(self, x):
183
+ batch_size, C, width, height = x.size()
184
+ proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
185
+ proj_key = self.key_conv(x).view(batch_size, -1, width * height)
186
+ energy = torch.bmm(proj_query, proj_key)
187
+ attention = F.softmax(energy, dim=-1)
188
+ proj_value = self.value_conv(x).view(batch_size, -1, width * height)
189
+ out = torch.bmm(proj_value, attention.permute(0, 2, 1))
190
+ out = out.view(batch_size, C, width, height)
191
+ out = self.gamma * out + x
192
+ return out
193
+
194
+ class EnhancedFeatureExtractor(nn.Module):
195
+ def __init__(self):
196
+ super(EnhancedFeatureExtractor, self).__init__()
197
+ self.initial_layers = nn.Sequential(
198
+ nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), # Increased number of filters
199
+ nn.ReLU(),
200
+ nn.BatchNorm2d(32), # Added Batch Normalization
201
+ nn.MaxPool2d(2, 2),
202
+ nn.Dropout(0.25), # Added Dropout
203
+ BasicResBlock(32, 64),
204
+ SEBlock(64, reduction=16), # Squeeze-and-Excitation block
205
+ nn.MaxPool2d(2, 2),
206
+ nn.Dropout(0.25), # Added Dropout
207
+ DepthwiseSeparableConv(64, 128, kernel_size=3), # Increased number of filters
208
+ nn.ReLU(),
209
+ BasicResBlock(128, 256),
210
+ SEBlock(256, reduction=16),
211
+ nn.MaxPool2d(2, 2),
212
+ nn.Dropout(0.25), # Added Dropout
213
+ SelfAttention(256), # Added Self-Attention layer
214
+ )
215
+ self.global_avg_pool = nn.AdaptiveAvgPool2d(1) # Global Average Pooling
216
+
217
+ def forward(self, x):
218
+ x = self.initial_layers(x)
219
+ x = self.global_avg_pool(x)
220
+ x = x.view(x.size(0), -1) # Flatten the output for fully connected layers
221
+ return x
222
+
223
+ class FasterKANvolver(nn.Module):
224
+ def __init__(
225
+ self,
226
+ layers_hidden: List[int],
227
+ grid_min: float = -1.2,
228
+ grid_max: float = 0.2,
229
+ num_grids: int = 8,
230
+ exponent: int = 2,
231
+ inv_denominator: float = 0.5,
232
+ train_grid: bool = False,
233
+ train_inv_denominator: bool = False,
234
+ #use_base_update: bool = True,
235
+ base_activation = None,
236
+ spline_weight_init_scale: float = 1.0,
237
+ ) -> None:
238
+ super(FasterKANvolver, self).__init__()
239
+
240
+ # Feature extractor with Convolutional layers
241
+ self.feature_extractor = EnhancedFeatureExtractor()
242
+ """
243
+ nn.Sequential(
244
+ nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1), # 1 input channel (grayscale), 16 output channels
245
+ nn.ReLU(),
246
+ nn.MaxPool2d(2, 2),
247
+ nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
248
+ nn.ReLU(),
249
+ nn.MaxPool2d(2, 2)
250
+ )
251
+ """
252
+
253
+ # Calculate the flattened feature size after convolutional layers
254
+ flat_features = 256 # XX channels, image size reduced to YxY
255
+
256
+ # Update layers_hidden with the correct input size from conv layers
257
+ layers_hidden = [flat_features] + layers_hidden
258
+ #print(f"FasterKANvolver layers_hidden shape: \n", layers_hidden)
259
+ #print(f"FasterKANvolver layers_hidden[1:] shape: ", len(layers_hidden[1:]))
260
+ #print(f"FasterKANvolver layers_hidden[:-1] shape: ", len(layers_hidden[:-1]))
261
+ #print("FasterKANvolver zip shape: \n", *[(in_dim, out_dim) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])])
262
+
263
+ # Define the FasterKAN layers
264
+ self.faster_kan_layers = nn.ModuleList([
265
+ FasterKANLayer(
266
+ in_dim, out_dim,
267
+ grid_min=grid_min,
268
+ grid_max=grid_max,
269
+ num_grids=num_grids,
270
+ exponent=exponent,
271
+ inv_denominator = 0.5,
272
+ train_grid = False,
273
+ train_inv_denominator = False,
274
+ #use_base_update=use_base_update,
275
+ base_activation=base_activation,
276
+ spline_weight_init_scale=spline_weight_init_scale,
277
+ ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
278
+ ])
279
+ #print(f"FasterKANvolver self.faster_kan_layers shape: \n", len(self.faster_kan_layers))
280
+ #print(f"FasterKANvolver self.faster_kan_layers: \n", self.faster_kan_layers)
281
+
282
+ def forward(self, x):
283
+ # Reshape input from [batch_size, 784] to [batch_size, 1, 28, 28] for MNIST [batch_size, 1, 32, 32] for C
284
+ #print(f"FasterKAN x view shape: {x.shape}")
285
+ x = x.view(-1, 3, 32,32)
286
+ #print(f"FasterKAN x view shape: {x.shape}")
287
+ # Apply convolutional layers
288
+ #print(f"FasterKAN x view shape: {x.shape}")
289
+ x = self.feature_extractor(x)
290
+ #print(f"FasterKAN x after feature_extractor shape: {x.shape}")
291
+ x = x.view(x.size(0), -1) # Flatten the output from the conv layers
292
+ #rint(f"FasterKAN x shape: {x.shape}")
293
+
294
+ # Pass through FasterKAN layers
295
+ for layer in self.faster_kan_layers:
296
+ #print("FasterKAN layer: \n", layer)
297
+ #print(f"FasterKAN x shape: {x.shape}")
298
+ x = layer(x)
299
+ #print(f"FasterKAN x shape: {x.shape}")
300
+
301
+ return x
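
Inside FasterKANLayer, an input of shape (batch, input_dim) is layer-normalized, expanded by the RSWAF basis to (batch, input_dim, num_grids), flattened, and projected by the bias-free SplineLinear. Note also that this file re-defines FasterKAN and FasterKANvolver, which already exist in fasterkan.py; only the fasterkan.py versions are exported by the package __init__. A shape sketch of a single layer:

    import torch
    from tasks.utils.kan.fasterkan_layers import FasterKANLayer

    layer = FasterKANLayer(input_dim=1125, output_dim=32, num_grids=8)
    x = torch.randn(4, 1125)
    print(layer(x).shape)  # torch.Size([4, 32]); internally 1125 * 8 = 9000 spline features
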
tasks/utils/kan/feature_extractor.py ADDED
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import *
+from torch.autograd import Function
+
+
+class BasicResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1):
+        super(BasicResBlock, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.downsample = nn.Sequential()
+        if stride != 1 or in_channels != out_channels:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_channels)
+            )
+
+    def forward(self, x):
+        identity = self.downsample(x)
+
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += identity
+        out = F.relu(out)
+
+        return out
+
+class SEBlock(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SEBlock, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+class DepthwiseSeparableConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
+        super(DepthwiseSeparableConv, self).__init__()
+        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
+                                   stride=stride, padding=padding, groups=in_channels)
+        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        x = self.depthwise(x)
+        x = self.pointwise(x)
+        return x
+
+class SelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super(SelfAttention, self).__init__()
+        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
+        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
+        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        batch_size, C, width, height = x.size()
+        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
+        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
+        energy = torch.bmm(proj_query, proj_key)
+        attention = F.softmax(energy, dim=-1)
+        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(batch_size, C, width, height)
+        out = self.gamma * out + x
+        return out
+
+class EnhancedFeatureExtractor(nn.Module):
+    def __init__(self,
+                 colors = 3):
+        super(EnhancedFeatureExtractor, self).__init__()
+        self.initial_layers = nn.Sequential(
+            nn.Conv2d(colors, 32, kernel_size=3, stride=1, padding=1),  # Increased number of filters
+            nn.ReLU(),
+            nn.BatchNorm2d(32),  # Added Batch Normalization
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),  # Added Dropout
+            BasicResBlock(32, 64),
+            SEBlock(64, reduction=16),  # Squeeze-and-Excitation block
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),  # Added Dropout
+            DepthwiseSeparableConv(64, 128, kernel_size=3),  # Increased number of filters
+            nn.ReLU(),
+            BasicResBlock(128, 256),
+            SEBlock(256, reduction=16),
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),  # Added Dropout
+            SelfAttention(256),  # Added Self-Attention layer
+        )
+        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)  # Global Average Pooling
+
+    def forward(self, x):
+        x = self.initial_layers(x)
+        x = self.global_avg_pool(x)
+        x = x.view(x.size(0), -1)  # Flatten the output for fully connected layers
+        return x
@@ -2,6 +2,8 @@ import torch
2
  import torch.nn as nn
3
  from .Modules.conformer import ConformerEncoder, ConformerDecoder
4
  from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
 
 
5
 
6
  class ConvBlock(nn.Module):
7
  def __init__(self, args, num_layer) -> None:
@@ -111,4 +113,29 @@ class DualEncoder(nn.Module):
111
  x1 = self.encoder_x(x)
112
  x2, _ = self.encoder_f(x)
113
  logits = torch.cat([x1, x2], dim=-1)
114
- return self.regressor(logits).squeeze()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch.nn as nn
3
  from .Modules.conformer import ConformerEncoder, ConformerDecoder
4
  from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
5
+ from .kan.fasterkan import FasterKAN
6
+ from kan import KAN
7
 
8
  class ConvBlock(nn.Module):
9
  def __init__(self, args, num_layer) -> None:
 
113
  x1 = self.encoder_x(x)
114
  x2, _ = self.encoder_f(x)
115
  logits = torch.cat([x1, x2], dim=-1)
116
+ return self.regressor(logits).squeeze()
117
+
118
+ class CNNKan(nn.Module):
119
+ def __init__(self, args, conformer_args, kan_args):
120
+ super().__init__()
121
+ self.backbone = CNNEncoder(args)
122
+ # self.kan = KAN(width=kan_args['layers_hidden'])
123
+ self.kan = FasterKAN(**kan_args)
124
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
125
+ x = self.backbone(x)
126
+ x = x.mean(dim=1)
127
+ return self.kan(x)
128
+
129
+ class KanEncoder(nn.Module):
130
+ def __init__(self, args):
131
+ super().__init__()
132
+ self.kan_x = FasterKAN(**args)
133
+ self.kan_f = FasterKAN(**args)
134
+ self.kan_out = FasterKAN(layers_hidden=[args['layers_hidden'][-1]*2, 8,8,1])
135
+
136
+ def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
137
+ x = self.kan_x(x)
138
+ f = self.kan_f(f)
139
+ out = torch.cat([x, f], dim=-1)
140
+ return self.kan_out(out)
141
+
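
CNNKan runs the two-channel input through CNNEncoder and averages over dim=1 before the KAN head. The CNNEncoder internals are not shown in this diff; for the 1125-wide layers_hidden to line up, the backbone presumably emits (batch, channels, 1125) with avg_output: False, so the channel mean leaves (batch, 1125). A sketch of that assumption with a stand-in feature tensor:

    import torch

    features = torch.randn(4, 128, 1125)  # hypothetical CNNEncoder output (batch, channels, length)
    pooled = features.mean(dim=1)         # (4, 1125), matching layers_hidden[0] = 1125
    print(pooled.shape)

Separately, from kan import KAN makes pykan a hard dependency of this module even though KAN is only referenced in a commented-out line.
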
tasks/utils/train.py CHANGED
@@ -74,8 +74,8 @@ class Trainer(object):
         lrs = []
         # self.optim_params['lr_history'] = []
         epochs_without_improvement = 0
-        main_proccess = (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) or self.device == 'cpu'
-
+        # main_proccess = (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) or self.device == 'cpu'
+        main_proccess = True  # change in a ddp setting
         print(f"Starting training for {num_epochs} epochs")
         print("is main process: ", main_proccess, flush=True)
         global_time = time.time()
@@ -221,7 +221,8 @@ class Trainer(object):
         x = x.to(device).float()
         fft = fft.to(device).float()
         y = y.to(device).float()
-        y_pred = self.model(fft)
+        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
+        y_pred = self.model(x_fft).squeeze()
         loss = self.criterion(y_pred, y)
         loss.backward()
         self.optimizer.step()
@@ -230,7 +231,7 @@ class Trainer(object):
         # get predicted classes
         probs = torch.sigmoid(y_pred)
         cls_pred = (probs > 0.5).float()
-        acc = (cls_pred == y).sum()
+        acc = (cls_pred == y).sum()
         return loss, acc, y
 
     def eval_epoch(self, device, epoch):
@@ -257,10 +258,11 @@ class Trainer(object):
             x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
             x = x.to(device).float()
             fft = fft.to(device).float()
+            x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
             y = y.to(device).float()
             with torch.no_grad():
-                y_pred = self.model(fft)
-                loss = self.criterion(y_pred, y)
+                y_pred = self.model(x_fft).squeeze()
+                loss = self.criterion(y_pred.squeeze(), y)
                 probs = torch.sigmoid(y_pred)
                 cls_pred = (probs > 0.5).float()
                 acc = (cls_pred == y).sum()
@@ -280,15 +282,16 @@ class Trainer(object):
             x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
             x = x.to(device).float()
             fft = fft.to(device).float()
+            x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
            
 y = y.to(device).float()
             with torch.no_grad():
-                y_pred = self.model(fft)
+                y_pred = self.model(x_fft).squeeze()
                 loss = self.criterion(y_pred, y)
                 probs = torch.sigmoid(y_pred)
                 cls_pred = (probs > 0.5).float()
                 acc = (cls_pred == y).sum()
-            predictions.append(cls_pred.cpu().numpy())
-            true_labels.append(y.cpu().numpy())
+            predictions.extend(cls_pred.cpu().numpy())
+            true_labels.extend(y.cpu().numpy())
             all_accs += acc
             total += len(y)
             pbar.set_description("acc: {:.4f}".format(acc))
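
The training, eval, and predict steps now stack the padded waveform and its spectrum into a two-channel tensor, matching in_channels: 2 in the config. Since torch.fft.fft returns a complex tensor, the earlier fft.to(device).float() cast keeps only the real part (PyTorch emits a warning when discarding the imaginary component). The stacking, sketched with assumed shapes:

    import torch

    x = torch.randn(32, 18000)   # batch of resampled waveforms
    fft = torch.fft.fft(x).real  # real part, as the .float() cast effectively keeps
    x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
    print(x_fft.shape)           # torch.Size([32, 2, 18000])
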