IlayMalinyak committed 49ebc1f (parent: 707b3a3)

kan
Changed files:
- model/0.0_cache_data (+0 -0)
- model/0.0_config.yml (+0 -0)
- model/history.txt (+2 -0)
- tasks/audio.py (+15 -10)
- tasks/models/frugal_2025-01-21/CNNEncoder_frugal_2.json (+0 -0)
- tasks/models/frugal_2025-01-21/frugal_kan_2.pth (+3 -0)
- tasks/run.py (+95 -0)
- tasks/utils/config.yaml (+18 -11)
- tasks/utils/data.py (+11 -5)
- tasks/utils/kan/__init__.py (+1 -0)
- tasks/utils/kan/fasterkan.py (+135 -0)
- tasks/utils/kan/fasterkan_basis.py (+112 -0)
- tasks/utils/kan/fasterkan_layers.py (+301 -0)
- tasks/utils/kan/feature_extractor.py (+112 -0)
- tasks/utils/models.py (+28 -1)
- tasks/utils/train.py (+12 -9)
model/0.0_cache_data
ADDED: binary file (840 Bytes)
model/0.0_config.yml
ADDED: diff too large to render (see raw diff)
model/history.txt
ADDED
@@ -0,0 +1,2 @@
+### Round 0 ###
+init => 0.0
tasks/audio.py
CHANGED
@@ -10,7 +10,7 @@ from torch.utils.data import DataLoader
 from .utils.evaluation import AudioEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 from .utils.data import FFTDataset
-from .utils.models import DualEncoder
+from .utils.models import DualEncoder, CNNKan
 from .utils.train import Trainer
 from .utils.data_utils import collate_fn, Container
 import yaml
@@ -70,13 +70,14 @@ async def evaluate_audio(request: AudioEvaluationRequest):
     model_args = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder'])
     model_args_f = Container(**yaml.safe_load(open(args_path, 'r'))['CNNEncoder_f'])
     conformer_args = Container(**yaml.safe_load(open(args_path, 'r'))['Conformer'])
+    kan_args = Container(**yaml.safe_load(open(args_path, 'r'))['KAN'])

     test_dataset = FFTDataset(test_dataset)
     test_dl = DataLoader(test_dataset, batch_size=data_args.batch_size, collate_fn=collate_fn)

-    model =
+    model = CNNKan(model_args, conformer_args, kan_args.get_dict())
     model = model.to(device)
-    state_dict = torch.load(
+    state_dict = torch.load(data_args.checkpoint_path)
     new_state_dict = OrderedDict()
     for key, value in state_dict.items():
         if key.startswith('module.'):
@@ -95,8 +96,12 @@ async def evaluate_audio(request: AudioEvaluationRequest):
                       accumulation_step=1, max_iter=np.inf,
                       exp_name=f"frugal_cnnencoder_inference")
     predictions, true_labels, acc = trainer.predict(test_dl, device=device)
+    # true_labels = test_dataset["label"]
+
     # Make random predictions (placeholder for actual model inference)
     print("accuracy: ", acc)
+    print("predictions: ", len(predictions))
+    print("true_labels: ", len(true_labels))

     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
@@ -128,15 +133,15 @@ async def evaluate_audio(request: AudioEvaluationRequest):

     return results

-
+if __name__ == "__main__":
    # with open("../logs//token.txt", "r") as f:
    #     api_key = f.read()
    # login(api_key)
    # # Create a sample request object
-
-
-
-
-
+    sample_request = AudioEvaluationRequest(
+        dataset_name="rfcx/frugalai",  # Replace with actual dataset name
+        test_size=0.2,  # Example values
+        test_seed=42
+    )
    #
-
+    asyncio.run(evaluate_audio(sample_request))
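The checkpoint-loading loop in this hunk is the standard fix for weights saved from a DistributedDataParallel model, whose keys all carry a 'module.' prefix. A minimal standalone sketch of the pattern (the file name is hypothetical):

from collections import OrderedDict

import torch

# Weights saved from a DDP-wrapped model carry a 'module.' prefix on every key;
# stripping it lets the plain (unwrapped) model load them.
state_dict = torch.load("checkpoint.pth", map_location="cpu")
new_state_dict = OrderedDict()
for key, value in state_dict.items():
    # 'module.backbone.conv1.weight' -> 'backbone.conv1.weight'
    new_key = key[len("module."):] if key.startswith("module.") else key
    new_state_dict[new_key] = value
# model.load_state_dict(new_state_dict)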
tasks/models/frugal_2025-01-21/CNNEncoder_frugal_2.json
ADDED: diff too large to render (see raw diff)
tasks/models/frugal_2025-01-21/frugal_kan_2.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:28e0188edab4879996cc960d2dc79641460b270af9c5ac7d3eacad1f5e96da39
+size 1714830
tasks/run.py
ADDED
@@ -0,0 +1,95 @@
+import os
+import torch
+from torch.utils.data import DataLoader
+from .utils.data import FFTDataset, SplitDataset
+from datasets import load_dataset
+from .utils.train import Trainer
+from .utils.models import CNNKan, KanEncoder
+from .utils.data_utils import *
+from huggingface_hub import login
+import yaml
+import datetime
+import json
+import numpy as np
+
+# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+current_date = datetime.date.today().strftime("%Y-%m-%d")
+datetime_dir = f"frugal_{current_date}"
+args_dir = 'tasks/utils/config.yaml'
+data_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Data'])
+exp_num = data_args.exp_num
+model_name = data_args.model_name
+model_args = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder'])
+model_args_f = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder_f'])
+conformer_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Conformer'])
+kan_args = Container(**yaml.safe_load(open(args_dir, 'r'))['KAN'])
+if not os.path.exists(f"{data_args.log_dir}/{datetime_dir}"):
+    os.makedirs(f"{data_args.log_dir}/{datetime_dir}")
+
+with open("../logs//token.txt", "r") as f:
+    api_key = f.read()
+
+# local_rank, world_size, gpus_per_node = setup()
+local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+login(api_key)
+dataset = load_dataset("rfcx/frugalai", streaming=True)
+
+train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
+train_dl = DataLoader(train_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)
+
+val_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=False)
+val_dl = DataLoader(val_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)
+
+test_ds = FFTDataset(dataset["test"])
+test_dl = DataLoader(test_ds, batch_size=data_args.batch_size, collate_fn=collate_fn)
+
+# for i, batch in enumerate(train_dl):
+#     x, x_f, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
+#     print(x.shape, x_f.shape, y.shape)
+#     if i > 10:
+#         break
+# exit()
+
+# model = DualEncoder(model_args, model_args_f, conformer_args)
+# model = FasterKAN([18000,64,64,16,1])
+model = CNNKan(model_args, conformer_args, kan_args.get_dict())
+# model.kan.speed()
+# model = KanEncoder(kan_args.get_dict())
+model = model.to(local_rank)
+# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+print(f"Number of parameters: {num_params}")
+
+loss_fn = torch.nn.BCEWithLogitsLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+total_steps = int(data_args.num_epochs) * 1000
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
+                                                       T_max=total_steps,
+                                                       eta_min=float((5e-4)/10))
+
+# missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
+# print(f"Missing keys: {missing}")
+# print(f"Unexpected keys: {unexpected}")
+
+trainer = Trainer(model=model, optimizer=optimizer,
+                  criterion=loss_fn, output_dim=model_args.output_dim, scaler=None,
+                  scheduler=None, train_dataloader=train_dl,
+                  val_dataloader=val_dl, device=local_rank,
+                  exp_num=datetime_dir, log_path=data_args.log_dir,
+                  range_update=None,
+                  accumulation_step=1, max_iter=np.inf,
+                  exp_name=f"frugal_kan_{exp_num}")
+fit_res = trainer.fit(num_epochs=100, device=local_rank,
+                      early_stopping=10, only_p=False, best='loss', conf=True)
+output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
+with open(output_filename, "w") as f:
+    json.dump(fit_res, f, indent=2)
+preds, acc = trainer.predict(test_dl, local_rank)
+print(f"Accuracy: {acc}")
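Since the streamed dataset has no known length, total_steps here effectively assumes about 1,000 optimizer steps per epoch. A small sketch of the cosine schedule this sets up (the per-epoch step count is an assumption, not something the dataset reports):

import math

# Cosine annealing from lr_max down to eta_min over T_max steps:
# lr(t) = eta_min + 0.5 * (lr_max - eta_min) * (1 + cos(pi * t / T_max))
lr_max, eta_min = 1e-3, 5e-4 / 10
T_max = 10 * 1000  # num_epochs * assumed ~1000 steps per epoch

def lr_at(t: int) -> float:
    return eta_min + 0.5 * (lr_max - eta_min) * (1 + math.cos(math.pi * t / T_max))

print(lr_at(0))           # 0.001, the starting learning rate
print(lr_at(T_max // 2))  # midpoint, halfway between lr_max and eta_min
print(lr_at(T_max))       # 5e-05, the floor (eta_min)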
tasks/utils/config.yaml
CHANGED
@@ -1,34 +1,41 @@
 Data:
   # Basics
-  log_dir: '/
+  log_dir: 'tasks/models'
   # Data
-  dataset: "
+  dataset: "FFTDataset"
-  data_dir:
+  data_dir: None
   model_name: "CNNEncoder"
-  batch_size:
+  batch_size: 32
-  num_epochs:
+  num_epochs: 10
   exp_num: 2
   max_len_spectra: 4096
   max_days_lc: 270
   lc_freq: 0.0208
   create_umap: True
+  checkpoint_path: 'tasks/models/frugal_2025-01-21/frugal_kan_2.pth'

 CNNEncoder:
   # Model
-  in_channels:
+  in_channels: 2
   num_layers: 4
   stride: 1
-  encoder_dims: [32,64,128
+  encoder_dims: [32,64,128]
   kernel_size: 3
   dropout_p: 0.3
   output_dim: 2
   beta: 1
-  load_checkpoint:
+  load_checkpoint: False
   checkpoint_num: 1
   activation: "silu"
   sine_w0: 1.0
-  avg_output:
+  avg_output: False
+
+KAN:
+  layers_hidden: [1125,32,8,8,1]
+  grid_min: -1.2
+  grid_max: 1.2
+  num_grids: 8
+  exponent: 2

 CNNEncoder_f:
   # Model
@@ -50,7 +57,7 @@ CNNEncoder_f:
 Conformer:
   encoder: ["mhsa_pro", "conv"]
   timeshift: false
-  num_layers:
+  num_layers: 4
   encoder_dim: 128
   num_heads: 8
   kernel_size: 3
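The new KAN block mirrors FasterKAN's constructor keywords one-to-one, which is why CNNKan can splat kan_args.get_dict() straight into FasterKAN(**kan_args). A minimal sketch with plain PyYAML (get_dict is this repo's Container helper; a raw dict behaves the same way):

import yaml

# Load just the KAN section; its keys match FasterKAN's keyword arguments.
with open("tasks/utils/config.yaml") as f:
    kan_cfg = yaml.safe_load(f)["KAN"]

# kan_cfg == {'layers_hidden': [1125, 32, 8, 8, 1], 'grid_min': -1.2,
#             'grid_max': 1.2, 'num_grids': 8, 'exponent': 2}
# model = FasterKAN(**kan_cfg)   # equivalent to FasterKAN(**kan_args.get_dict())
print(kan_cfg["layers_hidden"])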
tasks/utils/data.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 from torch.utils.data import IterableDataset
 from torch.fft import fft
+import torch.nn.functional as F
 from itertools import tee
 import random
 import torchaudio.transforms as T
@@ -24,20 +25,25 @@ class SplitDataset(IterableDataset):


 class FFTDataset(IterableDataset):
-    def __init__(self, original_dataset, orig_sample_rate=12000, target_sample_rate=
+    def __init__(self, original_dataset, max_len=72000, orig_sample_rate=12000, target_sample_rate=3000):
         self.dataset = original_dataset
         self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
+        self.max_len = max_len

     def __iter__(self):
         for item in self.dataset:
             # Assuming your audio data is in item['audio']
             # Modify this based on your actual data structure
             audio_data = torch.tensor(item['audio']['array']).float()
-
-
-
-
+            # pad audio
+            # if len(audio_data) == 0:
+            #     continue
+            pad_len = self.max_len - len(audio_data)
+            audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
+            audio_data = self.resampler(audio_data)
+            fft_data = fft(audio_data)

             # Update the item with FFT data
             item['audio']['fft'] = fft_data
+            item['audio']['array'] = audio_data
             yield item
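The constants line up: padding to max_len=72000 samples is 6 s at 12 kHz, resampling to 3 kHz leaves 18000 samples, and fft returns the same number of complex bins, matching the FasterKAN([18000, ...]) variant commented out in run.py. A quick standalone check:

import torch
import torch.nn.functional as F
import torchaudio.transforms as T
from torch.fft import fft

max_len, orig_sr, target_sr = 72000, 12000, 3000
resampler = T.Resample(orig_freq=orig_sr, new_freq=target_sr)

audio = torch.randn(50000)                       # a clip shorter than max_len
audio = F.pad(audio, (0, max_len - len(audio)))  # right-pad with zeros -> 72000
audio = resampler(audio)                         # 72000 * 3000/12000 -> 18000
spec = fft(audio)                                # complex tensor, same length

print(audio.shape, spec.shape)  # torch.Size([18000]) torch.Size([18000])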
tasks/utils/kan/__init__.py
ADDED
@@ -0,0 +1 @@
+from .fasterkan import FasterKAN, FasterKANLayer, FasterKANvolver
tasks/utils/kan/fasterkan.py
ADDED
@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import *
+from torch.autograd import Function
+from .feature_extractor import EnhancedFeatureExtractor
+from .fasterkan_layers import FasterKANLayer
+
+class FasterKAN(nn.Module):
+    def __init__(
+        self,
+        layers_hidden: List[int],
+        grid_min: float = -1.2,
+        grid_max: float = 1.2,
+        num_grids: int = 8,
+        exponent: int = 2,
+        inv_denominator: float = 0.5,
+        train_grid: bool = False,
+        train_inv_denominator: bool = False,
+        # use_base_update: bool = True,
+        base_activation = None,
+        spline_weight_init_scale: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.layers = nn.ModuleList([
+            FasterKANLayer(
+                in_dim, out_dim,
+                grid_min=grid_min,
+                grid_max=grid_max,
+                num_grids=num_grids,
+                exponent=exponent,
+                inv_denominator=inv_denominator,
+                train_grid=train_grid,
+                train_inv_denominator=train_inv_denominator,
+                # use_base_update=use_base_update,
+                base_activation=base_activation,
+                spline_weight_init_scale=spline_weight_init_scale,
+            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
+        ])
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+class FasterKANvolver(nn.Module):
+    def __init__(
+        self,
+        layers_hidden: List[int],
+        grid_min: float = -1.2,
+        grid_max: float = 0.2,
+        num_grids: int = 8,
+        exponent: int = 2,
+        inv_denominator: float = 0.5,
+        train_grid: bool = False,
+        train_inv_denominator: bool = False,
+        # use_base_update: bool = True,
+        base_activation = None,
+        spline_weight_init_scale: float = 1.0,
+        view = [-1, 1, 28, 28],
+    ) -> None:
+        super(FasterKANvolver, self).__init__()
+
+        self.view = view
+        # Feature extractor with convolutional layers
+        self.feature_extractor = EnhancedFeatureExtractor(colors=view[1])
+
+        # Flattened feature size after the conv stack (global average pool over 256 channels)
+        flat_features = 256
+
+        # Prepend the conv feature size to the KAN layer widths
+        layers_hidden = [flat_features] + layers_hidden
+
+        # Define the FasterKAN layers
+        self.faster_kan_layers = nn.ModuleList([
+            FasterKANLayer(
+                in_dim, out_dim,
+                grid_min=grid_min,
+                grid_max=grid_max,
+                num_grids=num_grids,
+                exponent=exponent,
+                inv_denominator=0.5,
+                train_grid=False,
+                train_inv_denominator=False,
+                # use_base_update=use_base_update,
+                base_activation=base_activation,
+                spline_weight_init_scale=spline_weight_init_scale,
+            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
+        ])
+
+    def forward(self, x):
+        # Reshape flat input to an image batch, e.g. [batch, 1, 28, 28] for MNIST
+        x = x.view(self.view[0], self.view[1], self.view[2], self.view[3])
+        # Convolutional feature extraction, then flatten for the KAN layers
+        x = self.feature_extractor(x)
+        x = x.view(x.size(0), -1)
+
+        # Pass through FasterKAN layers
+        for layer in self.faster_kan_layers:
+            x = layer(x)
+
+        return x
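FasterKAN simply chains FasterKANLayer blocks, so the network maps [batch, layers_hidden[0]] to [batch, layers_hidden[-1]]. A usage sketch with the widths from config.yaml (assumes the repo root is on PYTHONPATH):

import torch
from tasks.utils.kan import FasterKAN

# Widths from config.yaml's KAN section: 1125 input features down to 1 logit.
model = FasterKAN(layers_hidden=[1125, 32, 8, 8, 1],
                  grid_min=-1.2, grid_max=1.2, num_grids=8, exponent=2)

x = torch.randn(4, 1125)  # a batch of 4 feature vectors
out = model(x)
print(out.shape)          # torch.Size([4, 1])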
tasks/utils/kan/fasterkan_basis.py
ADDED
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import *
+from torch.autograd import Function
+
+class RSWAFFunction(Function):
+    @staticmethod
+    def forward(ctx, input, grid, inv_denominator, train_grid, train_inv_denominator):
+        # Forward pass: sech^2 of the distance between each input and each grid point
+        diff = (input[..., None] - grid)
+        diff_mul = diff.mul(inv_denominator)  # scaled distances (computed but not used below)
+        tanh_diff = torch.tanh(diff)
+        tanh_diff_derivative = -tanh_diff.mul(tanh_diff) + 1  # sech^2(x) = 1 - tanh^2(x)
+
+        # Save tensors for the backward pass
+        ctx.save_for_backward(input, tanh_diff, tanh_diff_derivative, diff, inv_denominator)
+        ctx.train_grid = train_grid
+        ctx.train_inv_denominator = train_inv_denominator
+
+        return tanh_diff_derivative
+
+    ##### SOS: NOT SURE YET HOW grad_inv_denominator AND grad_grid ARE CALCULATED CORRECTLY.
+    ##### MUST CHECK https://github.com/pytorch/pytorch/issues/74802
+    ##### MUST CHECK https://www.changjiangcai.com/studynotes/2020-10-18-Custom-Function-Extending-PyTorch/
+    ##### MUST CHECK https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html
+    ##### MUST CHECK https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html
+    ##### MUST CHECK https://gist.github.com/Hanrui-Wang/bf225dc0ccb91cdce160539c0acc853a
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # Retrieve saved tensors
+        input, tanh_diff, tanh_diff_derivative, diff, inv_denominator = ctx.saved_tensors
+        grad_grid = None
+        grad_inv_denominator = None
+
+        # Backward pass for the input: d/dx sech^2(x) = -2 tanh(x) sech^2(x)
+        grad_input = -2 * tanh_diff * tanh_diff_derivative * grad_output
+        grad_input = grad_input.sum(dim=-1).mul(inv_denominator)
+
+        # Backward pass for the grid
+        if ctx.train_grid:
+            grad_grid = -inv_denominator * grad_output.sum(dim=0).sum(dim=0)
+
+        # Backward pass for inv_denominator
+        if ctx.train_inv_denominator:
+            grad_inv_denominator = (grad_output * diff).sum()
+
+        return grad_input, grad_grid, grad_inv_denominator, None, None  # one gradient per forward argument
+
+
+class ReflectionalSwitchFunction(nn.Module):
+    def __init__(
+        self,
+        grid_min: float = -1.2,
+        grid_max: float = 0.2,
+        num_grids: int = 8,
+        exponent: int = 2,
+        inv_denominator: float = 0.5,
+        train_grid: bool = False,
+        train_inv_denominator: bool = False,
+    ):
+        super().__init__()
+        grid = torch.linspace(grid_min, grid_max, num_grids)
+        self.train_grid = torch.tensor(train_grid, dtype=torch.bool)
+        self.train_inv_denominator = torch.tensor(train_inv_denominator, dtype=torch.bool)
+        self.grid = torch.nn.Parameter(grid, requires_grad=train_grid)
+        # Cache the inverse of the denominator
+        self.inv_denominator = torch.nn.Parameter(torch.tensor(inv_denominator, dtype=torch.float32), requires_grad=train_inv_denominator)
+
+    def forward(self, x):
+        return RSWAFFunction.apply(x, self.grid, self.inv_denominator, self.train_grid, self.train_inv_denominator)
+
+
+class SplineLinear(nn.Linear):
+    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
+        self.init_scale = init_scale
+        super().__init__(in_features, out_features, bias=False, **kw)
+
+    def reset_parameters(self) -> None:
+        nn.init.xavier_uniform_(self.weight)  # Xavier uniform initialization
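The RSWAF basis returned by the forward pass is the derivative of tanh at the distance to each grid point, via the identity sech²(x) = 1 − tanh²(x), so each input feature yields num_grids bump-shaped activations. A quick check of that identity against autograd:

import torch

x = torch.linspace(-2, 2, 5, requires_grad=True)

# RSWAF's forward value is 1 - tanh(x)^2, i.e. sech^2(x) = d/dx tanh(x)
basis = 1 - torch.tanh(x) ** 2

# Compare with autograd's derivative of tanh
(autograd_deriv,) = torch.autograd.grad(torch.tanh(x).sum(), x)
print(torch.allclose(basis, autograd_deriv))  # True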
tasks/utils/kan/fasterkan_layers.py
ADDED
@@ -0,0 +1,301 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import *
+from torch.autograd import Function
+from .fasterkan_basis import ReflectionalSwitchFunction, SplineLinear
+
+class FasterKANLayer(nn.Module):
+    def __init__(
+        self,
+        input_dim: int,
+        output_dim: int,
+        grid_min: float = -1.2,
+        grid_max: float = 0.2,
+        num_grids: int = 8,
+        exponent: int = 2,
+        inv_denominator: float = 0.5,
+        train_grid: bool = False,
+        train_inv_denominator: bool = False,
+        # use_base_update: bool = True,
+        base_activation = F.silu,
+        spline_weight_init_scale: float = 0.667,
+    ) -> None:
+        super().__init__()
+        self.layernorm = nn.LayerNorm(input_dim)
+        self.rbf = ReflectionalSwitchFunction(grid_min, grid_max, num_grids, exponent, inv_denominator, train_grid, train_inv_denominator)
+        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
+        # self.use_base_update = use_base_update
+        # if use_base_update:
+        #     self.base_activation = base_activation
+        #     self.base_linear = nn.Linear(input_dim, output_dim)
+
+    def forward(self, x):
+        x = self.layernorm(x)
+        # Expand each feature into num_grids basis responses, then flatten
+        spline_basis = self.rbf(x).view(x.shape[0], -1)
+        ret = self.spline_linear(spline_basis)
+        # if self.use_base_update:
+        #     ret += self.base_linear(self.base_activation(x))
+        return ret
+
+
+class FasterKAN(nn.Module):
+    def __init__(
+        self,
+        layers_hidden: List[int],
+        grid_min: float = -1.2,
+        grid_max: float = 0.2,
+        num_grids: int = 8,
+        exponent: int = 2,
+        inv_denominator: float = 0.5,
+        train_grid: bool = False,
+        train_inv_denominator: bool = False,
+        # use_base_update: bool = True,
+        base_activation = None,
+        spline_weight_init_scale: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.layers = nn.ModuleList([
+            FasterKANLayer(
+                in_dim, out_dim,
+                grid_min=grid_min,
+                grid_max=grid_max,
+                num_grids=num_grids,
+                exponent=exponent,
+                inv_denominator=inv_denominator,
+                train_grid=train_grid,
+                train_inv_denominator=train_inv_denominator,
+                # use_base_update=use_base_update,
+                base_activation=base_activation,
+                spline_weight_init_scale=spline_weight_init_scale,
+            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
+        ])
+
+    def forward(self, x):
+        for layer in self.layers:
+            x = layer(x)
+        return x
+
+
+class BasicResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1):
+        super(BasicResBlock, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.downsample = nn.Sequential()
+        if stride != 1 or in_channels != out_channels:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_channels)
+            )
+
+    def forward(self, x):
+        identity = self.downsample(x)
+
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += identity
+        out = F.relu(out)
+
+        return out
+
+class SEBlock(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SEBlock, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+class DepthwiseSeparableConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
+        super(DepthwiseSeparableConv, self).__init__()
+        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
+                                   stride=stride, padding=padding, groups=in_channels)
+        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        x = self.depthwise(x)
+        x = self.pointwise(x)
+        return x
+
+class SelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super(SelfAttention, self).__init__()
+        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
+        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
+        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        batch_size, C, width, height = x.size()
+        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
+        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
+        energy = torch.bmm(proj_query, proj_key)
+        attention = F.softmax(energy, dim=-1)
+        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(batch_size, C, width, height)
+        out = self.gamma * out + x
+        return out
+
+class EnhancedFeatureExtractor(nn.Module):
+    def __init__(self):
+        super(EnhancedFeatureExtractor, self).__init__()
+        self.initial_layers = nn.Sequential(
+            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),  # Increased number of filters
+            nn.ReLU(),
+            nn.BatchNorm2d(32),  # Batch normalization
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),
+            BasicResBlock(32, 64),
+            SEBlock(64, reduction=16),  # Squeeze-and-Excitation block
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),
+            DepthwiseSeparableConv(64, 128, kernel_size=3),
+            nn.ReLU(),
+            BasicResBlock(128, 256),
+            SEBlock(256, reduction=16),
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),
+            SelfAttention(256),  # Self-attention layer
+        )
+        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)  # Global average pooling
+
+    def forward(self, x):
+        x = self.initial_layers(x)
+        x = self.global_avg_pool(x)
+        x = x.view(x.size(0), -1)  # Flatten for fully connected layers
+        return x
+
+class FasterKANvolver(nn.Module):
+    def __init__(
+        self,
+        layers_hidden: List[int],
+        grid_min: float = -1.2,
+        grid_max: float = 0.2,
+        num_grids: int = 8,
+        exponent: int = 2,
+        inv_denominator: float = 0.5,
+        train_grid: bool = False,
+        train_inv_denominator: bool = False,
+        # use_base_update: bool = True,
+        base_activation = None,
+        spline_weight_init_scale: float = 1.0,
+    ) -> None:
+        super(FasterKANvolver, self).__init__()
+
+        # Feature extractor with convolutional layers
+        self.feature_extractor = EnhancedFeatureExtractor()
+
+        # Flattened feature size after the conv stack (global average pool over 256 channels)
+        flat_features = 256
+
+        # Prepend the conv feature size to the KAN layer widths
+        layers_hidden = [flat_features] + layers_hidden
+
+        # Define the FasterKAN layers
+        self.faster_kan_layers = nn.ModuleList([
+            FasterKANLayer(
+                in_dim, out_dim,
+                grid_min=grid_min,
+                grid_max=grid_max,
+                num_grids=num_grids,
+                exponent=exponent,
+                inv_denominator=0.5,
+                train_grid=False,
+                train_inv_denominator=False,
+                # use_base_update=use_base_update,
+                base_activation=base_activation,
+                spline_weight_init_scale=spline_weight_init_scale,
+            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
+        ])
+
+    def forward(self, x):
+        # Reshape flat input to an image batch (hard-coded for 3x32x32, e.g. CIFAR)
+        x = x.view(-1, 3, 32, 32)
+        # Convolutional feature extraction, then flatten for the KAN layers
+        x = self.feature_extractor(x)
+        x = x.view(x.size(0), -1)
+
+        # Pass through FasterKAN layers
+        for layer in self.faster_kan_layers:
+            x = layer(x)
+
+        return x
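Each FasterKANLayer therefore expands a [batch, D] input into a [batch, D, num_grids] basis tensor, flattens it, and projects it with one bias-free linear map. A shape walk-through (assumes the repo root is on PYTHONPATH):

import torch
from tasks.utils.kan.fasterkan_layers import FasterKANLayer

layer = FasterKANLayer(input_dim=16, output_dim=4, num_grids=8)
x = torch.randn(32, 16)  # [batch, input_dim]

# Inside forward: rbf(x) has shape [32, 16, 8]; flattened -> [32, 128];
# SplineLinear(128 -> 4) produces the output.
out = layer(x)
print(out.shape)         # torch.Size([32, 4])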
tasks/utils/kan/feature_extractor.py
ADDED
@@ -0,0 +1,112 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import *
+from torch.autograd import Function
+
+
+class BasicResBlock(nn.Module):
+    def __init__(self, in_channels, out_channels, stride=1):
+        super(BasicResBlock, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(out_channels)
+        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(out_channels)
+
+        self.downsample = nn.Sequential()
+        if stride != 1 or in_channels != out_channels:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(out_channels)
+            )
+
+    def forward(self, x):
+        identity = self.downsample(x)
+
+        out = F.relu(self.bn1(self.conv1(x)))
+        out = self.bn2(self.conv2(out))
+        out += identity
+        out = F.relu(out)
+
+        return out
+
+class SEBlock(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SEBlock, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+class DepthwiseSeparableConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
+        super(DepthwiseSeparableConv, self).__init__()
+        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size,
+                                   stride=stride, padding=padding, groups=in_channels)
+        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        x = self.depthwise(x)
+        x = self.pointwise(x)
+        return x
+
+class SelfAttention(nn.Module):
+    def __init__(self, in_channels):
+        super(SelfAttention, self).__init__()
+        self.query_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
+        self.key_conv = nn.Conv2d(in_channels, in_channels // 8, kernel_size=1)
+        self.value_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
+        self.gamma = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        batch_size, C, width, height = x.size()
+        proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)
+        proj_key = self.key_conv(x).view(batch_size, -1, width * height)
+        energy = torch.bmm(proj_query, proj_key)
+        attention = F.softmax(energy, dim=-1)
+        proj_value = self.value_conv(x).view(batch_size, -1, width * height)
+        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
+        out = out.view(batch_size, C, width, height)
+        out = self.gamma * out + x
+        return out
+
+class EnhancedFeatureExtractor(nn.Module):
+    def __init__(self, colors=3):
+        super(EnhancedFeatureExtractor, self).__init__()
+        self.initial_layers = nn.Sequential(
+            nn.Conv2d(colors, 32, kernel_size=3, stride=1, padding=1),  # Increased number of filters
+            nn.ReLU(),
+            nn.BatchNorm2d(32),  # Batch normalization
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),
+            BasicResBlock(32, 64),
+            SEBlock(64, reduction=16),  # Squeeze-and-Excitation block
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),
+            DepthwiseSeparableConv(64, 128, kernel_size=3),
+            nn.ReLU(),
+            BasicResBlock(128, 256),
+            SEBlock(256, reduction=16),
+            nn.MaxPool2d(2, 2),
+            nn.Dropout(0.25),
+            SelfAttention(256),  # Self-attention layer
+        )
+        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)  # Global average pooling
+
+    def forward(self, x):
+        x = self.initial_layers(x)
+        x = self.global_avg_pool(x)
+        x = x.view(x.size(0), -1)  # Flatten for fully connected layers
+        return x
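Because the stack ends in AdaptiveAvgPool2d(1) over 256 channels, the extractor emits [batch, 256] for any input size large enough to survive the pooling, which is the flat_features = 256 hard-coded in FasterKANvolver. A quick check (assumes the repo root is on PYTHONPATH):

import torch
from tasks.utils.kan.feature_extractor import EnhancedFeatureExtractor

extractor = EnhancedFeatureExtractor(colors=3)
for size in (32, 64):
    x = torch.randn(2, 3, size, size)
    # Global average pooling collapses spatial dims, so the width is always 256
    print(extractor(x).shape)  # torch.Size([2, 256]) for both sizes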
tasks/utils/models.py
CHANGED
@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 from .Modules.conformer import ConformerEncoder, ConformerDecoder
 from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
+from .kan.fasterkan import FasterKAN
+from kan import KAN

 class ConvBlock(nn.Module):
     def __init__(self, args, num_layer) -> None:
@@ -111,4 +113,29 @@ class DualEncoder(nn.Module):
         x1 = self.encoder_x(x)
         x2, _ = self.encoder_f(x)
         logits = torch.cat([x1, x2], dim=-1)
-        return self.regressor(logits).squeeze()
+        return self.regressor(logits).squeeze()
+
+class CNNKan(nn.Module):
+    def __init__(self, args, conformer_args, kan_args):
+        super().__init__()
+        self.backbone = CNNEncoder(args)
+        # self.kan = KAN(width=kan_args['layers_hidden'])
+        self.kan = FasterKAN(**kan_args)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.backbone(x)
+        x = x.mean(dim=1)
+        return self.kan(x)
+
+class KanEncoder(nn.Module):
+    def __init__(self, args):
+        super().__init__()
+        self.kan_x = FasterKAN(**args)
+        self.kan_f = FasterKAN(**args)
+        self.kan_out = FasterKAN(layers_hidden=[args['layers_hidden'][-1] * 2, 8, 8, 1])
+
+    def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
+        x = self.kan_x(x)
+        f = self.kan_f(f)
+        out = torch.cat([x, f], dim=-1)
+        return self.kan_out(out)
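CNNKan, the variant this commit actually trains, mean-pools the backbone output over its sequence dimension before the KAN head; assuming CNNEncoder returns [batch, seq, features] with features matching layers_hidden[0], the head then sees one vector per clip. A shape sketch with a stand-in backbone (nn.Identity only marks where CNNEncoder would go):

import torch
import torch.nn as nn
from tasks.utils.kan.fasterkan import FasterKAN

# Stand-in for CNNEncoder: any module producing [batch, seq, features] works.
backbone = nn.Identity()
kan_head = FasterKAN(layers_hidden=[1125, 32, 8, 8, 1])

feats = backbone(torch.randn(4, 10, 1125))  # [batch, seq, features]
pooled = feats.mean(dim=1)                  # [4, 1125] after averaging over seq
print(kan_head(pooled).shape)               # torch.Size([4, 1])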
tasks/utils/train.py
CHANGED
@@ -74,8 +74,8 @@ class Trainer(object):
         lrs = []
         # self.optim_params['lr_history'] = []
         epochs_without_improvement = 0
-        main_proccess = (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) or self.device == 'cpu'
-
+        # main_proccess = (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) or self.device == 'cpu'
+        main_proccess = True  # change in a ddp setting
         print(f"Starting training for {num_epochs} epochs")
         print("is main process: ", main_proccess, flush=True)
         global_time = time.time()
@@ -221,7 +221,8 @@ class Trainer(object):
         x = x.to(device).float()
         fft = fft.to(device).float()
         y = y.to(device).float()
-
+        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
+        y_pred = self.model(x_fft).squeeze()
         loss = self.criterion(y_pred, y)
         loss.backward()
         self.optimizer.step()
@@ -230,7 +231,7 @@ class Trainer(object):
         # get predicted classes
         probs = torch.sigmoid(y_pred)
         cls_pred = (probs > 0.5).float()
-        acc = (cls_pred == y).sum()
+        acc = (cls_pred == y).sum()
         return loss, acc, y

     def eval_epoch(self, device, epoch):
@@ -257,10 +258,11 @@ class Trainer(object):
         x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
         x = x.to(device).float()
         fft = fft.to(device).float()
+        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
         y = y.to(device).float()
         with torch.no_grad():
-            y_pred = self.model(
-            loss = self.criterion(y_pred, y)
+            y_pred = self.model(x_fft).squeeze()
+            loss = self.criterion(y_pred.squeeze(), y)
         probs = torch.sigmoid(y_pred)
         cls_pred = (probs > 0.5).float()
         acc = (cls_pred == y).sum()
@@ -280,15 +282,16 @@ class Trainer(object):
         x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
         x = x.to(device).float()
         fft = fft.to(device).float()
+        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
         y = y.to(device).float()
         with torch.no_grad():
-            y_pred = self.model(
+            y_pred = self.model(x_fft).squeeze()
         loss = self.criterion(y_pred, y)
         probs = torch.sigmoid(y_pred)
         cls_pred = (probs > 0.5).float()
         acc = (cls_pred == y).sum()
-        predictions.
-        true_labels.
+        predictions.extend(cls_pred.cpu().numpy())
+        true_labels.extend(y.cpu().numpy())
         all_accs += acc
         total += len(y)
         pbar.set_description("acc: {:.4f}".format(acc))
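The recurring x_fft construction stacks the waveform and its FFT as two channels, matching in_channels: 2 in the CNNEncoder config. In isolation (with real-valued stand-ins for both tensors):

import torch

batch, time_steps = 4, 18000
x = torch.randn(batch, time_steps)    # resampled waveform
fft = torch.randn(batch, time_steps)  # FFT features (real-valued here)

# unsqueeze adds a channel axis; cat stacks the two views along it
x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
print(x_fft.shape)  # torch.Size([4, 2, 18000])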