IlayMalinyak committed
Commit · 2f54ec8
Parent(s): 92c056d
cnnkan
- tasks/audio.py +8 -9
- tasks/inr_database/inr_database.pt +3 -0
- tasks/models/frugal_2025-01-27/frugal_kan_2.pth +2 -2
- tasks/models/frugal_2025-01-28/frugal_kan_2.pth +3 -0
- tasks/models/frugal_2025-01-29/CNNEncoder_frugal_2.json +0 -0
- tasks/models/frugal_2025-01-29/frugal_kan_2.pth +3 -0
- tasks/run.py +55 -14
- tasks/run_inr.py +200 -0
- tasks/tasks/models/frugal_2025-01-28/CNNEncoder_frugal_2.json +420 -0
- tasks/tasks/models/frugal_2025-01-28/frugal_kan_2.pth +3 -0
- tasks/utils/config.yaml +45 -4
- tasks/utils/data.py +118 -12
- tasks/utils/data_utils.py +9 -2
- tasks/utils/dfs/test.csv +0 -0
- tasks/utils/dfs/train.csv +0 -0
- tasks/utils/dfs/val.csv +0 -0
- tasks/utils/graph_constructor.py +214 -0
- tasks/utils/inr.py +147 -0
- tasks/utils/models.py +103 -0
- tasks/utils/pooling.py +199 -0
- tasks/utils/probe_features.py +54 -0
- tasks/utils/relational_transformer.py +361 -0
- tasks/utils/train.py +423 -4
- tasks/utils/transforms.py +272 -0
- test +0 -0
tasks/audio.py
CHANGED
@@ -29,7 +29,6 @@ DESCRIPTION = "Conformer"
 ROUTE = "/audio"
 
 
-
 @router.post(ROUTE, tags=["Audio Task"],
              description=DESCRIPTION)
 async def evaluate_audio(request: AudioEvaluationRequest):
@@ -133,11 +132,11 @@ async def evaluate_audio(request: AudioEvaluationRequest):
 
     return results
 
-
-
-
-
-
-
-#
-
+if __name__ == "__main__":
+    sample_request = AudioEvaluationRequest(
+        dataset_name="rfcx/frugalai",  # Replace with actual dataset name
+        test_size=0.2,  # Example values
+        test_seed=42
+    )
+    #
+    asyncio.run(evaluate_audio(sample_request))
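The new __main__ block doubles as a local smoke test for the endpoint. A minimal sketch of invoking it from outside the module is below; the module path "tasks.audio" is an assumption based on the repository layout, and the dataset requires Hugging Face access.

# Hedged sketch: exercise the audio endpoint locally without starting the API server.
# "tasks.audio" as an import path is an assumption; run from the repository root.
import asyncio
from tasks.audio import evaluate_audio, AudioEvaluationRequest

request = AudioEvaluationRequest(
    dataset_name="rfcx/frugalai",
    test_size=0.2,
    test_seed=42,
)
results = asyncio.run(evaluate_audio(request))
print(results)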
tasks/inr_database/inr_database.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:13ead5ed23d1fa59062f4872fd784f609575fa7dd4876ea0ef562f9f817801c1
+size 50872350
tasks/models/frugal_2025-01-27/frugal_kan_2.pth
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:58e353129b2750993441ea459485b150b2a45b39cbdb7e49bd1839809e4671e2
+size 1363844
tasks/models/frugal_2025-01-28/frugal_kan_2.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9591abf1e617bfedd8414f7c031de3394d7e1bccc64763f7060cef0ee13fab65
+size 1710980
tasks/models/frugal_2025-01-29/CNNEncoder_frugal_2.json
ADDED
The diff for this file is too large to render.
See raw diff
tasks/models/frugal_2025-01-29/frugal_kan_2.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f11c67e65bf6acd1f7e2055dce4a92de0d4d1d88fb12c136eb70198c8ac6eab8
+size 1710980
tasks/run.py
CHANGED
@@ -1,14 +1,17 @@
 from torch.utils.data import DataLoader
 from .utils.data import FFTDataset, SplitDataset
 from datasets import load_dataset
-from .utils.train import Trainer
-from .utils.models import CNNKan, KanEncoder
+from .utils.train import Trainer, XGBoostTrainer
+from .utils.models import CNNKan, KanEncoder, CNNKanFeaturesEncoder
 from .utils.data_utils import *
 from huggingface_hub import login
 import yaml
 import datetime
 import json
 import numpy as np
+import pandas as pd
+import seaborn as sns
+import matplotlib.pyplot as plt
 from collections import OrderedDict
 
 # local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -19,9 +22,11 @@ data_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Data'])
 exp_num = data_args.exp_num
 model_name = data_args.model_name
 model_args = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder'])
+mlp_args = Container(**yaml.safe_load(open(args_dir, 'r'))['MLP'])
 model_args_f = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder_f'])
 conformer_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Conformer'])
 kan_args = Container(**yaml.safe_load(open(args_dir, 'r'))['KAN'])
+boost_args = Container(**yaml.safe_load(open(args_dir, 'r'))['XGBoost'])
 if not os.path.exists(f"{data_args.log_dir}/{datetime_dir}"):
     os.makedirs(f"{data_args.log_dir}/{datetime_dir}")
 
@@ -44,26 +49,62 @@ val_dl = DataLoader(val_ds,batch_size=data_args.batch_size, collate_fn=collate_f
 test_ds = FFTDataset(dataset["test"])
 test_dl = DataLoader(test_ds,batch_size=data_args.batch_size, collate_fn=collate_fn)
 
-#
-#
-#
-#
+# data = []
+#
+# # Iterate over the dataset
+# for i, batch in enumerate(train_ds):
+#     label = batch['label']
+#     features = batch['audio']['features']
+#
+#     # Flatten the nested dictionary structure
+#     feature_dict = {'label': label}
+#     for k, v in features.items():
+#         if isinstance(v, dict):
+#             for sub_k, sub_v in v.items():
+#                 feature_dict[f"{k}_{sub_k}"] = sub_v[0].item()  # Aggregate (e.g., mean)
+#         else:
+#             print(k, v.shape)  # Aggregate (e.g., mean)
+#
+#     data.append(feature_dict)
+#     print(i)
+#
+#     if i > 1000:  # Limit to 10 iterations
 #         break
+#
+# # Convert to DataFrame
+# df = pd.DataFrame(data)
+
+# Plot distributions colored by label
+# plt.figure()
+# for col in df.columns:
+#     if col != 'label':
+#         sns.kdeplot(df, x=col, hue='label', fill=True, alpha=0.5)
+#         plt.title(f'Distribution of {col}')
+#         plt.show()
+# exit()
+
+# trainer = XGBoostTrainer(boost_args.get_dict(), train_ds, val_ds, test_ds)
+# res = trainer.fit()
+# trainer.predict()
+# trainer.plot_results(res)
 # exit()
 
 # model = DualEncoder(model_args, model_args_f, conformer_args)
 # model = FasterKAN([18000,64,64,16,1])
 model = CNNKan(model_args, conformer_args, kan_args.get_dict())
+# model = CNNKanFeaturesEncoder(model_args, mlp_args, kan_args.get_dict())
 # model.kan.speed()
 # model = KanEncoder(kan_args.get_dict())
 model = model.to(local_rank)
-
-
-
-
-
-
-
+
+# state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
+# new_state_dict = OrderedDict()
+# for key, value in state_dict.items():
+#     if key.startswith('module.'):
+#         key = key[7:]
+#     new_state_dict[key] = value
+# missing, unexpected = model.load_state_dict(new_state_dict)
+
 # model = DDP(model, device_ids=[local_rank], output_device=local_rank)
 num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
 print(f"Number of parameters: {num_params}")
@@ -92,7 +133,7 @@ fit_res = trainer.fit(num_epochs=100, device=local_rank,
 output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
 with open(output_filename, "w") as f:
     json.dump(fit_res, f, indent=2)
-preds, acc = trainer.predict(test_dl, local_rank)
+preds, tru, acc = trainer.predict(test_dl, local_rank)
 print(f"Accuracy: {acc}")
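The commented-out block in run.py flattens per-clip features into a DataFrame and plots per-feature KDEs split by label. The sketch below shows the same idea on synthetic data so it runs without the streaming dataset; the feature names are placeholders, not the repository's actual feature keys.

# Standalone sketch of the feature-exploration idea commented out in run.py,
# using synthetic rows instead of the rfcx/frugalai stream.
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
rows = [{"label": int(rng.integers(0, 2)),
         "spectral_centroid": float(rng.normal()),   # placeholder feature name
         "rms_energy": float(rng.normal())}          # placeholder feature name
        for _ in range(500)]
df = pd.DataFrame(rows)

for col in df.columns:
    if col != "label":
        plt.figure()
        sns.kdeplot(data=df, x=col, hue="label", fill=True, alpha=0.5)
        plt.title(f"Distribution of {col}")
        plt.show()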
tasks/run_inr.py
ADDED
@@ -0,0 +1,200 @@
+import torch.nn
+from torch.utils.data import DataLoader
+from utils.data import FFTDataset, SplitDataset, AudioINRDataset
+from datasets import load_dataset
+from utils.train import Trainer, INRTrainer
+from utils.models import MultiGraph, ImplicitEncoder
+from omegaconf import OmegaConf
+
+# from .utils.models import CNNKan, KanEncoder
+from utils.inr import INR
+from utils.data_utils import *
+from huggingface_hub import login
+import yaml
+import datetime
+import json
+import numpy as np
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+from scipy.signal import savgol_filter as savgol
+from utils.kan import FasterKAN
+from utils.relational_transformer import RelationalTransformer
+from collections import OrderedDict
+
+def plot_results(dims, i, data, losses, pred_values):
+    data = savgol(data.cpu().detach().numpy(), window_length=250, polyorder=1)
+    pred_values = pred_values.transpose(-1, -2).unflatten(-1, data.shape[-2:]).squeeze(0).cpu().detach().numpy()
+    pred_values = (pred_values - np.min(pred_values)) / (np.max(pred_values) - np.min(pred_values))
+    data = (data - np.min(data)) / (np.max(data) - np.min(data))
+    plt.plot(data.squeeze())
+    plt.plot(pred_values.squeeze())
+    # axes[0].set_title('Original')
+    # axes[1].set_title('Reconstruction')
+    plt.show()
+    # plt.plot(np.arange(len(losses)), losses)
+    # plt.xlabel('Iteration')
+    # plt.ylabel('Reconstruction MSE Error')
+    # plt.show()
+
+# local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+current_date = datetime.date.today().strftime("%Y-%m-%d")
+datetime_dir = f"frugal_{current_date}"
+args_dir = 'utils/config.yaml'
+data_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Data'])
+exp_num = data_args.exp_num
+model_name = data_args.model_name
+rt_args = Container(**yaml.safe_load(open(args_dir, 'r'))['RelationalTransformer'])
+cnn_args = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder_f'])
+conformer_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Conformer'])
+kan_args = Container(**yaml.safe_load(open(args_dir, 'r'))['KAN_INR'])
+inr_args = Container(**yaml.safe_load(open(args_dir, 'r'))['INR'])
+if not os.path.exists(f"{data_args.log_dir}/{datetime_dir}"):
+    os.makedirs(f"{data_args.log_dir}/{datetime_dir}")
+
+with open("../../logs/token.txt", "r") as f:
+    api_key = f.read()
+
+# local_rank, world_size, gpus_per_node = setup()
+local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+login(api_key)
+dataset = load_dataset("rfcx/frugalai", streaming=True)
+
+train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
+
+train_dl = DataLoader(train_ds, batch_size=data_args.batch_size)
+
+val_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=False)
+
+val_dl = DataLoader(val_ds, batch_size=data_args.batch_size)
+
+test_ds = AudioINRDataset(FFTDataset(dataset["test"]))
+test_dl = DataLoader(test_ds, batch_size=data_args.batch_size)
+
+# for i, batch in enumerate(train_ds):
+#     fft_phase, fft_mag, audio = batch['audio']['fft_phase'], batch['audio']['fft_mag'], batch['audio']['array']
+#     label = batch['label']
+#     fig, axes = plt.subplots(nrows=1, ncols=3)
+#     axes = axes.flatten()
+#     axes[0].plot(fft_phase)
+#     axes[1].plot(fft_mag)
+#     axes[2].plot(audio)
+#     fig.suptitle(label)
+#     plt.tight_layout()
+#     plt.show()
+#     if i > 20:
+#         break
+# model = DualEncoder(model_args, model_args_f, conformer_args)
+# model = FasterKAN([18000,64,64,16,1])
+# model = INR(in_features=1)
+# model.kan.speed()
+# model = KanEncoder(kan_args.get_dict())
+# model = model.to(local_rank)
+
+# state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
+# new_state_dict = OrderedDict()
+# for key, value in state_dict.items():
+#     if key.startswith('module.'):
+#         key = key[7:]
+#     new_state_dict[key] = value
+# missing, unexpected = model.load_state_dict(new_state_dict)
+
+# model = DDP(model, device_ids=[local_rank], output_device=local_rank)
+# num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+# print(f"Number of parameters: {num_params}")
+#
+# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+# total_steps = int(data_args.num_epochs) * 1000
+# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
+#                                                        T_max=total_steps,
+#                                                        eta_min=float((5e-4) / 10))
+
+loss_fn = torch.nn.BCEWithLogitsLoss()
+inr_criterion = torch.nn.MSELoss()
+
+# for i, batch in enumerate(train_ds):
+#     coords, fft, audio = batch['audio']['coords'], batch['audio']['fft_mag'], batch['audio']['array']
+#     coords = coords.to(local_rank)
+#     fft = fft.to(local_rank)
+#     audio = audio.to(local_rank)
+#     values = torch.cat((audio.unsqueeze(-1), fft.unsqueeze(-1)), dim=-1)
+#     # model = INR(hidden_features=128, n_layers=3,
+#     #             in_features=1,
+#     #             out_features=1).to(local_rank)
+#     model = FasterKAN(**kan_args.get_dict()).to(local_rank)
+#     optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
+#     pbar = tqdm(range(200))
+#     losses = []
+#     print(coords.shape)
+#     for t in pbar:
+#         optimizer.zero_grad()
+#         pred_values = model(coords.to(local_rank)).float()
+#         loss = inr_criterion(pred_values, values)
+#         loss.backward()
+#         optimizer.step()
+#         pbar.set_description(f'loss: {loss.item()}')
+#         losses.append(loss.item())
+#     state_dict = model.state_dict()
+#     torch.save(state_dict, 'test')
+#     # print(f'Sample {i+offset} label {label} saved in {inr_path}')
+#     plot_results(1, i, fft, losses, pred_values)
+#     #
+#     exit()
+
+
+# missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
+# print(f"Missing keys: {missing}")
+# print(f"Unexpected keys: {unexpected}")
+layer_layout = [inr_args.in_features] + [inr_args.hidden_features for _ in range(inr_args.n_layers)] + [inr_args.out_features]
+
+graph_constructor = OmegaConf.create(
+    {
+        "_target_": "utils.graph_constructor.GraphConstructor",
+        "_recursive_": False,
+        "_convert_": "all",
+        "d_in": 1,
+        "d_edge_in": 1,
+        "zero_out_bias": False,
+        "zero_out_weights": False,
+        "sin_emb": True,
+        "sin_emb_dim": rt_args.d_node,
+        "use_pos_embed": False,
+        "input_layers": 1,
+        "inp_factor": 1,
+        "num_probe_features": 0,
+        "inr_model": None,
+        "stats": None,
+        "sparsify": False,
+        'sym_edges': False,
+    }
+)
+
+rt_model = RelationalTransformer(layer_layout=layer_layout, graph_constructor=graph_constructor,
+                                 **rt_args.get_dict()).to(local_rank)
+rt_model.proj_out = torch.nn.Identity()
+multi_graph = MultiGraph(rt_model, cnn_args)
+implicit_net = INR(**inr_args.get_dict())
+model = ImplicitEncoder(implicit_net, multi_graph).to(local_rank)
+num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+print(f"Number of parameters: {num_parameters}")
+optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
+trainer = Trainer(model=model, optimizer=optimizer,
+                  criterion=loss_fn, output_dim=1, scaler=None,
+                  scheduler=None, train_dataloader=train_dl,
+                  val_dataloader=val_dl, device=local_rank,
+                  exp_num=datetime_dir, log_path=data_args.log_dir,
+                  range_update=None,
+                  accumulation_step=1, max_iter=100,
+                  exp_name=f"frugal_kan_{exp_num}")
+fit_res = trainer.fit(num_epochs=100, device=local_rank,
+                      early_stopping=10, only_p=False, best='loss', conf=True)
+output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
+with open(output_filename, "w") as f:
+    json.dump(fit_res, f, indent=2)
+preds, acc = trainer.predict(test_dl, local_rank)
+print(f"Accuracy: {acc}")
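run_inr.py derives the node layout of the weight-space graph from the INR section added to config.yaml. A quick check of that expression with the committed values (in_features: 2, n_layers: 2, hidden_features: 64, out_features: 32):

# How layer_layout is built from the INR config values in config.yaml.
in_features, n_layers, hidden_features, out_features = 2, 2, 64, 32
layer_layout = [in_features] + [hidden_features for _ in range(n_layers)] + [out_features]
print(layer_layout)  # [2, 64, 64, 32]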
tasks/tasks/models/frugal_2025-01-28/CNNEncoder_frugal_2.json
ADDED
@@ -0,0 +1,420 @@
+{
+  "num_epochs": 100,
+  "train_loss": [
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+  ],
+  "val_loss": [
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227,
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227,
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227,
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227,
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227,
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227,
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227,
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227,
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227,
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227,
+    0.7067973613739014, 0.6959081888198853, 0.7170985341072083, 0.6722545623779297, 0.7166763544082642, 0.7339211702346802, 0.7830350399017334, 0.7616254091262817, 0.7465198040008545, 0.7474430799484253, 0.7112622261047363, 0.8131649494171143, 0.7090165019035339, 0.697528600692749, 0.7792525291442871, 0.700302243232727, 0.7315454483032227
+  ],
+  "train_acc": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+  "val_acc": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+  "lrs": [0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001]
+}
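The committed fit-results JSON can be inspected directly; a minimal sketch, assuming the file path exactly as committed (note the doubled tasks/tasks prefix):

# Quick look at the committed training log; the path is taken verbatim from this commit.
import json
import matplotlib.pyplot as plt

with open("tasks/tasks/models/frugal_2025-01-28/CNNEncoder_frugal_2.json") as f:
    fit_res = json.load(f)

plt.plot(fit_res["train_loss"], label="train_loss")
plt.plot(fit_res["val_loss"], label="val_loss")
plt.xlabel("step")
plt.legend()
plt.show()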
tasks/tasks/models/frugal_2025-01-28/frugal_kan_2.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:7b22195a96ca5dd9cf4d20bb1886c6d22327d42ea3fbe4b9d86b9a281a132024
+size 820491
tasks/utils/config.yaml
CHANGED
@@ -12,7 +12,7 @@ Data:
   max_days_lc: 270
   lc_freq: 0.0208
   create_umap: True
-  checkpoint_path: 'tasks/models/frugal_2025-01-
+  checkpoint_path: 'tasks/models/frugal_2025-01-29/frugal_kan_2.pth'
 
 CNNEncoder:
   # Model
@@ -27,9 +27,14 @@ CNNEncoder:
   load_checkpoint: False
   checkpoint_num: 1
   activation: "silu"
-  sine_w0:
+  sine_w0: 30.0
   avg_output: False
 
+MLP:
+  input_dim: 6
+  hidden_dims: [16,32]
+  dropout: 0.2
+
 KAN:
   layers_hidden: [1125,32,8,1]
   grid_min: -1.2
@@ -37,9 +42,16 @@ KAN:
   num_grids: 8
   exponent: 2
 
+KAN_INR:
+  layers_hidden: [1,1024,128,128,1]
+  grid_min: -1.2
+  grid_max: 1.2
+  num_grids: 8
+  exponent: 2
+
 CNNEncoder_f:
   # Model
-  in_channels:
+  in_channels: 32
   num_layers: 4
   stride: 1
   encoder_dims: [32,64,128]
@@ -64,10 +76,39 @@ Conformer:
   dropout_p: 0.2
   norm: "postnorm"
 
+RelationalTransformer:
+  d_node: 32
+  d_edge: 32
+  d_attn_hid: 16
+  d_node_hid: 16
+  d_edge_hid: 16
+  d_out_hid: 16
+  d_out: 1
+  n_layers: 4
+  n_heads: 4
+  dropout: 0.1
+
+
+INR:
+  in_features : 2
+  n_layers : 2
+  hidden_features : 64
+  out_features : 32
+
+XGBoost:
+  objective : 'binary:logistic'
+  eval_metric : 'logloss'
+  use_label_encoder : False
+  n_estimators : 500
+  learning_rate : 0.1
+  max_depth : 5
+  subsample : 0.8
+  colsample_bytree : 0.8
+  random_state : 42
 
 Optimization:
   # Optimization
   max_lr: 1e-5
   weight_decay: 5e-6
   warmup_pct: 0.3
-  steps_per_epoch: 3500
+  steps_per_epoch: 3500
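Each new YAML section is consumed the same way in run.py and run_inr.py: the block is read with yaml.safe_load and splatted into a Container. The real Container lives in tasks/utils and is not shown in this commit, so the stand-in below only mirrors how the call sites use it; it is an assumption, not the repository's implementation.

# Minimal stand-in for the repo's Container config object (hypothetical),
# showing how the XGBoost section added above would be accessed.
import yaml

class Container:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)   # expose YAML keys as attributes

    def get_dict(self):
        return dict(self.__dict__)

args_dir = "tasks/utils/config.yaml"
with open(args_dir, "r") as f:
    cfg = yaml.safe_load(f)

boost_args = Container(**cfg["XGBoost"])
print(boost_args.n_estimators)              # 500
print(boost_args.get_dict()["max_depth"])   # 5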
tasks/utils/data.py
CHANGED
@@ -1,11 +1,38 @@
 import torch
 from torch.utils.data import IterableDataset
-from torch.fft import fft
+from torch.fft import fft, fftshift
 import torch.nn.functional as F
 from itertools import tee
 import random
 import torchaudio.transforms as T
+import hashlib
+from typing import NamedTuple, Tuple, Union
+from .transforms import compute_all_features
 
+from scipy.signal import savgol_filter as savgol
+
+
+class WeightsBatch(NamedTuple):
+    weights: Tuple
+    biases: Tuple
+    label: Union[torch.Tensor, int]
+
+    def _assert_same_len(self):
+        assert len(set([len(t) for t in self])) == 1
+
+    def as_dict(self):
+        return self._asdict()
+
+    def to(self, device):
+        """move batch to device"""
+        return self.__class__(
+            weights=tuple(w.to(device) for w in self.weights),
+            biases=tuple(w.to(device) for w in self.biases),
+            label=self.label.to(device),
+        )
+
+    def __len__(self):
+        return len(self.weights[0])
 
 class SplitDataset(IterableDataset):
     def __init__(self, dataset, is_train=True, train_ratio=0.8):
@@ -28,22 +55,101 @@ class FFTDataset(IterableDataset):
     def __init__(self, original_dataset, max_len=72000, orig_sample_rate=12000, target_sample_rate=3000):
         self.dataset = original_dataset
         self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
+        self.target_sample_rate = target_sample_rate
         self.max_len = max_len
-
+
+
+    def normalize_audio(self, audio):
+        """Normalize audio to [0, 1] range"""
+        audio_min = audio.min()
+        audio_max = audio.max()
+        audio = (audio - audio_min) / (audio_max - audio_min)
+        return audio
+
+    def generate_unique_id(self, array):
+        # Convert the array to bytes
+        array_bytes = array.tobytes()
+        # Hash the bytes using SHA256
+        hash_object = hashlib.sha256(array_bytes)
+        # Return the hexadecimal representation of the hash
+        return hash_object.hexdigest()
+
     def __iter__(self):
         for item in self.dataset:
-            #
-
-
-
-
-            # continue
+            # audio_data = savgol(item['audio']['array'], 500, polyorder=1)
+            audio_data = item['audio']['array']
+            # item['id'] = self.generate_unique_id(audio_data)
+            audio_data = torch.tensor(audio_data).float()
+
             pad_len = self.max_len - len(audio_data)
             audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
             audio_data = self.resampler(audio_data)
+
+            audio_data = self.normalize_audio(audio_data)
             fft_data = fft(audio_data)
-
-
-
-
+            magnitude = torch.abs(fft_data)
+            phase = torch.angle(fft_data)
+            # features = compute_all_features(audio_data, sample_rate=self.target_sample_rate)
+            # features_arr = torch.tensor([v for _, v in features['frequency_domain'].items()])
+            magnitude_centered = fftshift(magnitude)
+            phase_centered = fftshift(phase)
+            # cwt = features['cwt_power']
+
+            # Optionally, remove the DC component
+            magnitude_centered[len(magnitude_centered) // 2] = 0  # Set DC component to zero
+
+            item['audio']['fft_mag'] = torch.nan_to_num(magnitude_centered, 0)
+            item['audio']['fft_phase'] = torch.nan_to_num(phase_centered, 0)
+            # item['audio']['cwt_mag'] = torch.nan_to_num(cwt, 0)
+            item['audio']['array'] = torch.nan_to_num(audio_data, 0)
+            # item['audio']['features'] = features
+            # item['audio']['features_arr'] = torch.nan_to_num(features_arr, 0)
+            yield item
+
+
+class AudioINRDataset(IterableDataset):
+    def __init__(self, original_dataset, max_len=18000, sample_size=1024, dim=1, normalize=True):
+        """
+        Convert audio data into coordinate-value pairs for INR training.
+
+        Args:
+            original_dataset: Original audio dataset
+            max_len: Maximum length of audio to process
+            batch_size: Number of points to sample per audio clip
+            normalize: Whether to normalize the audio values to [0, 1]
+        """
+        self.dataset = original_dataset
+        self.max_len = max_len
+        self.dim = dim
+        self.normalize = normalize
+        self.sample_size = sample_size
+
+    def get_coordinates(self, audio_len):
+        """Generate time coordinates"""
+        # Create normalized time coordinates in [0, 1]
+        coords = torch.linspace(0, 1, audio_len).unsqueeze(-1).expand(audio_len, self.dim)
+        return coords  # Shape: [audio_len, 1]
+
+    def sample_points(self, coords, values):
+        """Randomly sample points from the audio"""
+        if len(coords) > self.sample_size:
+            idx = torch.randperm(len(coords))[:self.sample_size]
+            coords = coords[idx]
+            values = values[idx]
+        return coords, values
+
+    def __iter__(self):
+        for item in self.dataset:
+            # Get audio data
+            audio_data = torch.tensor(item['audio']['array']).float()
+
+            # Generate coordinates
+            coords = self.get_coordinates(len(audio_data))
+
+            item['audio']['coords'] = coords
+
+            # Sample random points
+            # coords, values = self.sample_points(coords, audio_data)
+
+            # Create the INR training sample
             yield item
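The updated FFTDataset now emits a normalized waveform plus a centered FFT magnitude and phase per clip, which collate_fn pads into batches. A consumption sketch mirroring run.py is below; the "tasks.utils" import paths are assumptions from the repository layout, and the rfcx/frugalai stream requires Hugging Face authentication.

# Hedged sketch: wrap the streaming dataset in FFTDataset and batch it,
# mirroring the setup in run.py (import paths assumed from the repo layout).
from datasets import load_dataset
from torch.utils.data import DataLoader
from tasks.utils.data import FFTDataset, SplitDataset
from tasks.utils.data_utils import collate_fn

dataset = load_dataset("rfcx/frugalai", streaming=True)
train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
train_dl = DataLoader(train_ds, batch_size=16, collate_fn=collate_fn)

batch = next(iter(train_dl))
print(batch["audio"]["array"].shape)    # padded, resampled waveforms
print(batch["audio"]["fft_mag"].shape)  # centered FFT magnitude per clip
print(batch["label"].shape)             # one label per clip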
tasks/utils/data_utils.py
CHANGED
@@ -6,18 +6,25 @@ from torch.nn.utils.rnn import pad_sequence
 def collate_fn(batch):
     # Extract audio arrays and FFT data from the batch of dictionaries
     audio_arrays = [torch.tensor(item['audio']['array']) for item in batch]
-    fft_arrays = [torch.tensor(item['audio']['
+    fft_arrays = [torch.tensor(item['audio']['fft_mag']) for item in batch]
+    # cwt_arrays = [torch.tensor(item['audio']['cwt_mag']) for item in batch]
+    # features = [item['audio']['features'] for item in batch]
+    # features_arr = torch.stack([item['audio']['features_arr'] for item in batch])
     labels = [torch.tensor(item['label']) for item in batch]
 
     # Pad both sequences
     padded_audio = pad_sequence(audio_arrays, batch_first=True, padding_value=0)
     padded_fft = pad_sequence(fft_arrays, batch_first=True, padding_value=0)
+    # padded_features = pad_sequence(features_arr, batch_first=True, padding_value=0)
 
     # Return as dictionary with the same structure
     return {
         'audio': {
             'array': padded_audio,
-            '
+            'fft_mag': padded_fft,
+            # 'features': features,
+            # 'features_arr': features_arr,
+            # 'cwt_mag': padded_cwt,
         },
         'label': torch.stack(labels)
tasks/utils/dfs/test.csv
ADDED
The diff for this file is too large to render.
See raw diff
tasks/utils/dfs/train.csv
ADDED
The diff for this file is too large to render.
See raw diff
tasks/utils/dfs/val.csv
ADDED
The diff for this file is too large to render.
See raw diff
tasks/utils/graph_constructor.py
ADDED
@@ -0,0 +1,214 @@
+import torch
+import torch.nn as nn
+from rff.layers import GaussianEncoding
+
+# from nn.probe_features import GraphProbeFeatures
+
+
+def sparsify_graph(edges, fraction=0.1):
+    abs_edges = torch.abs(edges)
+    flat_abs_tensor = abs_edges.flatten()
+    sorted_tensor, _ = torch.sort(flat_abs_tensor, descending=True)
+    num_elements = flat_abs_tensor.numel()
+    top_k = int(num_elements * fraction)
+    topk_values, topk_indices = torch.topk(flat_abs_tensor, top_k)
+    mask = torch.zeros_like(flat_abs_tensor, dtype=torch.bool)
+    mask[topk_indices] = True
+    mask = mask.view(edges.shape)
+    return mask
+
+def batch_to_graphs(
+    weights,
+    biases,
+    weights_mean=None,
+    weights_std=None,
+    biases_mean=None,
+    biases_std=None,
+    sparsify=False,
+    sym_edges=False
+):
+    device = weights[0].device
+    bsz = weights[0].shape[0]
+    num_nodes = weights[0].shape[1] + sum(w.shape[2] for w in weights)
+
+    node_features = torch.zeros(bsz, num_nodes, biases[0].shape[-1], device=device)
+    edge_features = torch.zeros(
+        bsz, num_nodes, num_nodes, weights[0].shape[-1], device=device
+    )
+
+    row_offset = 0
+    col_offset = weights[0].shape[1]  # no edge to input nodes
+
+    for i, w in enumerate(weights):
+        _, num_in, num_out, _ = w.shape
+        w_mean = weights_mean[i] if weights_mean is not None else 0
+        w_std = weights_std[i] if weights_std is not None else 1
+        w = (w - w_mean) / w_std
+        if sparsify:
+            w[~sparsify_graph(w)] = 0
+        edge_features[
+            :, row_offset : row_offset + num_in, col_offset : col_offset + num_out
+        ] = w
+        if sym_edges:
+            edge_features[
+                :, col_offset: col_offset + num_out, row_offset: row_offset + num_in
+            ] = torch.swapaxes(w, 1, 2)
+        row_offset += num_in
+        col_offset += num_out
+
+    row_offset = weights[0].shape[1]  # no bias in input nodes
+    for i, b in enumerate(biases):
+        _, num_out, _ = b.shape
+        b_mean = biases_mean[i] if biases_mean is not None else 0
+        b_std = biases_std[i] if biases_std is not None else 1
+        node_features[:, row_offset : row_offset + num_out] = (b - b_mean) / b_std
+        row_offset += num_out
+
+    return node_features, edge_features
+
+
+class GraphConstructor(nn.Module):
+    def __init__(
+        self,
+        d_in,
+        d_edge_in,
+        d_node,
+        d_edge,
+        layer_layout,
+        rev_edge_features=False,
+        zero_out_bias=False,
+        zero_out_weights=False,
+        inp_factor=1,
+        input_layers=1,
+        sin_emb=False,
+        sin_emb_dim=128,
+        use_pos_embed=False,
+        num_probe_features=0,
+        inr_model=None,
+        stats=None,
+        sparsify=False,
+        sym_edges=False,
+    ):
+        super().__init__()
+        self.rev_edge_features = rev_edge_features
+        self.nodes_per_layer = layer_layout
+        self.zero_out_bias = zero_out_bias
+        self.zero_out_weights = zero_out_weights
+        self.use_pos_embed = use_pos_embed
+        self.stats = stats if stats is not None else {}
+        self._d_node = d_node
+        self._d_edge = d_edge
+        self.sparse = sparsify
+        self.sym_edges = sym_edges
+
+        self.pos_embed_layout = (
+            [1] * layer_layout[0] + layer_layout[1:-1] + [1] * layer_layout[-1]
+        )
+        self.pos_embed = nn.Parameter(torch.randn(len(self.pos_embed_layout), d_node))
+
+        if not self.zero_out_weights:
+            proj_weight = []
+            if sin_emb:
+                proj_weight.append(
+                    GaussianEncoding(
+                        sigma=inp_factor,
+                        input_size=d_edge_in
+                        + (2 * d_edge_in if rev_edge_features else 0),
+                        encoded_size=sin_emb_dim,
+                    )
+                )
+                proj_weight.append(nn.Linear(2 * sin_emb_dim, d_edge))
+            else:
+                proj_weight.append(
+                    nn.Linear(
+                        d_edge_in + (2 * d_edge_in if rev_edge_features else 0), d_edge
+                    )
+                )
+
+            for i in range(input_layers - 1):
+                proj_weight.append(nn.SiLU())
+                proj_weight.append(nn.Linear(d_edge, d_edge))
+
+            self.proj_weight = nn.Sequential(*proj_weight)
+        if not self.zero_out_bias:
+            proj_bias = []
+            if sin_emb:
+                proj_bias.append(
+                    GaussianEncoding(
+                        sigma=inp_factor,
+                        input_size=d_in,
+                        encoded_size=sin_emb_dim,
+                    )
+                )
+                proj_bias.append(nn.Linear(2 * sin_emb_dim, d_node))
+            else:
+                proj_bias.append(nn.Linear(d_in, d_node))
+
+            for i in range(input_layers - 1):
+                proj_bias.append(nn.SiLU())
+                proj_bias.append(nn.Linear(d_node, d_node))
+
+            self.proj_bias = nn.Sequential(*proj_bias)
+
+        self.proj_node_in = nn.Linear(d_node, d_node)
+        self.proj_edge_in = nn.Linear(d_edge, d_edge)
+
+        if num_probe_features > 0:
+            self.gpf = GraphProbeFeatures(
+                d_in=layer_layout[0],
+                num_inputs=num_probe_features,
+                inr_model=inr_model,
+                input_init=None,
+                proj_dim=d_node,
+            )
+        else:
+            self.gpf = None
+
+    def forward(self, inputs):
+        node_features, edge_features = batch_to_graphs(*inputs, **self.stats,
+                                                       )
+        mask = edge_features.sum(dim=-1, keepdim=True) != 0
+        if self.rev_edge_features:
+            rev_edge_features = edge_features.transpose(-2, -3)
+            edge_features = torch.cat(
+                [edge_features, rev_edge_features, edge_features + rev_edge_features],
+                dim=-1,
+            )
+            mask = mask | mask.transpose(-3, -2)
+
+        if self.zero_out_weights:
+            edge_features = torch.zeros(
+                (*edge_features.shape[:-1], self._d_edge),
+                device=edge_features.device,
+                dtype=edge_features.dtype,
+            )
+        else:
+            edge_features = self.proj_weight(edge_features)
+        if self.zero_out_bias:
+            # only zero out bias, not gpf
+            node_features = torch.zeros(
+                (*node_features.shape[:-1], self._d_node),
+                device=node_features.device,
+                dtype=node_features.dtype,
+            )
+        else:
+            node_features = self.proj_bias(node_features)
+
+        if self.gpf is not None:
+            probe_features = self.gpf(*inputs)
+            node_features = node_features + probe_features
+
+        node_features = self.proj_node_in(node_features)
+        edge_features = self.proj_edge_in(edge_features)
+
+        if self.use_pos_embed:
+            pos_embed = torch.cat(
+                [
+                    # repeat(self.pos_embed[i], "d -> 1 n d", n=n)
+                    self.pos_embed[i].unsqueeze(0).expand(1, n, -1)
+                    for i, n in enumerate(self.pos_embed_layout)
+                ],
+                dim=1,
+            )
+            node_features = node_features + pos_embed
+        return node_features, edge_features, mask
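batch_to_graphs flattens a network's per-layer weights and biases into one node-feature tensor and one edge-feature tensor over all neurons. The toy check below illustrates the resulting shapes; the weight layout [batch, in, out, 1] and bias layout [batch, out, 1] are inferred from the indexing in the function body, and the import path mirrors this commit (the module needs the rff package installed).

# Toy shape check for batch_to_graphs on a 1 -> 64 -> 32 MLP.
import torch
from tasks.utils.graph_constructor import batch_to_graphs  # path as committed

bsz = 2
weights = (torch.randn(bsz, 1, 64, 1), torch.randn(bsz, 64, 32, 1))
biases = (torch.randn(bsz, 64, 1), torch.randn(bsz, 32, 1))

node_features, edge_features = batch_to_graphs(weights, biases)
# num_nodes = 1 input node + 64 + 32 output neurons = 97
print(node_features.shape)  # torch.Size([2, 97, 1])
print(edge_features.shape)  # torch.Size([2, 97, 97, 1])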
tasks/utils/inr.py
ADDED
@@ -0,0 +1,147 @@
+import copy
+import math
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from rff.layers import GaussianEncoding, PositionalEncoding
+from torch import nn
+from .kan.fasterkan import FasterKAN
+
+
+
+class Sine(nn.Module):
+    def __init__(self, w0=1.0):
+        super().__init__()
+        self.w0 = w0
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.sin(self.w0 * x)
+
+
+def params_to_tensor(params):
+    return torch.cat([p.flatten() for p in params]), [p.shape for p in params]
+
+
+def tensor_to_params(tensor, shapes):
+    params = []
+    start = 0
+    for shape in shapes:
+        size = torch.prod(torch.tensor(shape)).item()
+        param = tensor[start : start + size].reshape(shape)
+        params.append(param)
+        start += size
+    return tuple(params)
+
+
+def wrap_func(func, shapes):
+    def wrapped_func(params, *args, **kwargs):
+        params = tensor_to_params(params, shapes)
+        return func(params, *args, **kwargs)
+
+    return wrapped_func
+
+
+class Siren(nn.Module):
+    def __init__(
+        self,
+        dim_in,
+        dim_out,
+        w0=30.0,
+        c=6.0,
+        is_first=False,
+        use_bias=True,
+        activation=None,
+    ):
+        super().__init__()
+        self.w0 = w0
+        self.c = c
+        self.dim_in = dim_in
+        self.dim_out = dim_out
+        self.is_first = is_first
+
+        weight = torch.zeros(dim_out, dim_in)
+        bias = torch.zeros(dim_out) if use_bias else None
+        self.init_(weight, bias, c=c, w0=w0)
+
+        self.weight = nn.Parameter(weight)
+        self.bias = nn.Parameter(bias) if use_bias else None
+        self.activation = Sine(w0) if activation is None else activation
+
+    def init_(self, weight: torch.Tensor, bias: torch.Tensor, c: float, w0: float):
+        dim = self.dim_in
+
+        w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
+        weight.uniform_(-w_std, w_std)
+
+        if bias is not None:
+            # bias.uniform_(-w_std, w_std)
+            bias.zero_()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        out = F.linear(x, self.weight, self.bias)
+        out = self.activation(out)
+        return out
+
+
+class INR(nn.Module):
+    def __init__(
+        self,
+        in_features: int = 2,
+        n_layers: int = 3,
+        hidden_features: int = 32,
+        out_features: int = 1,
+        pe_features: Optional[int] = None,
+        fix_pe=True,
+    ):
+        super().__init__()
+
+        if pe_features is not None:
+            if fix_pe:
+                self.layers = [PositionalEncoding(sigma=10, m=pe_features)]
+                encoded_dim = in_features * pe_features * 2
+            else:
+                self.layers = [
+                    GaussianEncoding(
+                        sigma=10, input_size=in_features, encoded_size=pe_features
+                    )
+                ]
+                encoded_dim = pe_features * 2
+            self.layers.append(Siren(dim_in=encoded_dim, dim_out=hidden_features))
+        else:
+            self.layers = [Siren(dim_in=in_features, dim_out=hidden_features)]
+        for i in range(n_layers - 2):
+            self.layers.append(Siren(hidden_features, hidden_features))
+        self.layers.append(nn.Linear(hidden_features, out_features))
+        self.seq = nn.Sequential(*self.layers)
+        self.num_layers = len(self.layers)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.seq(x) + 0.5
+
+
+class INRPerLayer(INR):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        nodes = [x]
+        for layer in self.seq:
+            nodes.append(layer(nodes[-1]))
+        nodes[-1] = nodes[-1] + 0.5
+        return nodes
+
+
+def make_functional(mod, disable_autograd_tracking=False):
+    params_dict = dict(mod.named_parameters())
+    params_names = params_dict.keys()
+    params_values = tuple(params_dict.values())
+
+    stateless_mod = copy.deepcopy(mod)
+    stateless_mod.to("meta")
+
+    def fmodel(new_params_values, *args, **kwargs):
+        new_params_dict = {
+            name: value for name, value in zip(params_names, new_params_values)
+        }
+        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)
+
+    if disable_autograd_tracking:
+        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
+    return fmodel, params_values
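The INR class stacks Siren (sine-activated) layers, ends with a linear projection, and shifts the output by 0.5. The sketch below fits one of these networks to a synthetic 1-D signal, mirroring the commented-out per-sample loop in run_inr.py; the signal and hyperparameters are illustrative, and the import path is assumed from the repository layout.

# Hedged sketch: fit the committed INR to a synthetic 1-D signal.
import torch
from tasks.utils.inr import INR  # path as committed

coords = torch.linspace(0, 1, 1000).unsqueeze(-1)            # [N, 1] time coordinates
target = 0.5 + 0.25 * torch.sin(2 * torch.pi * 8 * coords)   # synthetic waveform

model = INR(in_features=1, n_layers=3, hidden_features=32, out_features=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = torch.nn.MSELoss()

for step in range(200):
    optimizer.zero_grad()
    loss = criterion(model(coords), target)
    loss.backward()
    optimizer.step()

print(f"final reconstruction MSE: {loss.item():.4f}")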
tasks/utils/models.py
CHANGED
@@ -4,11 +4,56 @@ from .Modules.conformer import ConformerEncoder, ConformerDecoder
|
|
4 |
from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
|
5 |
from .kan.fasterkan import FasterKAN
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
class ConvBlock(nn.Module):
|
8 |
def __init__(self, args, num_layer) -> None:
|
9 |
super().__init__()
|
10 |
if args.activation == 'silu':
|
11 |
self.activation = nn.SiLU()
|
|
|
|
|
12 |
else:
|
13 |
self.activation = nn.ReLU()
|
14 |
in_channels = args.encoder_dims[num_layer-1] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
|
@@ -31,6 +76,8 @@ class CNNEncoder(nn.Module):
|
|
31 |
print("Using CNN encoder wit activation: ", args.activation, 'args avg_output: ', args.avg_output)
|
32 |
if args.activation == 'silu':
|
33 |
self.activation = nn.SiLU()
|
|
|
|
|
34 |
else:
|
35 |
self.activation = nn.ReLU()
|
36 |
self.embedding = nn.Sequential(nn.Conv1d(in_channels = args.in_channels,
|
@@ -125,6 +172,21 @@ class CNNKan(nn.Module):
|
|
125 |
x = x.mean(dim=1)
|
126 |
return self.kan(x)
|
127 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
128 |
class KanEncoder(nn.Module):
|
129 |
def __init__(self, args):
|
130 |
super().__init__()
|
@@ -138,3 +200,44 @@ class KanEncoder(nn.Module):
|
|
138 |
out = torch.cat([x, f], dim=-1)
|
139 |
return self.kan_out(out)
|
140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
|
5 |
from .kan.fasterkan import FasterKAN
|
6 |
|
7 |
+
|
8 |
+
class Sine(nn.Module):
|
9 |
+
def __init__(self, w0=1.0):
|
10 |
+
super().__init__()
|
11 |
+
self.w0 = w0
|
12 |
+
|
13 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
14 |
+
return torch.sin(self.w0 * x)
|
15 |
+
|
16 |
+
|
17 |
+
class MLPEncoder(nn.Module):
|
18 |
+
def __init__(self, args):
|
19 |
+
"""
|
20 |
+
Initialize an MLP with hidden layers, BatchNorm, and Dropout.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
input_dim (int): Dimension of the input features.
|
24 |
+
hidden_dims (list of int): List of dimensions for hidden layers.
|
25 |
+
output_dim (int): Dimension of the output.
|
26 |
+
dropout (float): Dropout probability (default: 0.0).
|
27 |
+
"""
|
28 |
+
super(MLPEncoder, self).__init__()
|
29 |
+
|
30 |
+
layers = []
|
31 |
+
prev_dim = args.input_dim
|
32 |
+
|
33 |
+
# Add hidden layers
|
34 |
+
for hidden_dim in args.hidden_dims:
|
35 |
+
layers.append(nn.Linear(prev_dim, hidden_dim))
|
36 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
37 |
+
layers.append(nn.SiLU())
|
38 |
+
if args.dropout > 0.0:
|
39 |
+
layers.append(nn.Dropout(args.dropout))
|
40 |
+
prev_dim = hidden_dim
|
41 |
+
self.model = nn.Sequential(*layers)
|
42 |
+
self.output_dim = hidden_dim
|
43 |
+
|
44 |
+
def forward(self, x):
|
45 |
+
# if x.dim() == 2:
|
46 |
+
# x = x.unsqueeze(-1)
|
47 |
+
x = self.model(x)
|
48 |
+
# x = x.mean(-1)
|
49 |
+
return x
|
50 |
class ConvBlock(nn.Module):
|
51 |
def __init__(self, args, num_layer) -> None:
|
52 |
super().__init__()
|
53 |
if args.activation == 'silu':
|
54 |
self.activation = nn.SiLU()
|
55 |
+
elif args.activation == 'sine':
|
56 |
+
self.activation = Sine(w0=args.sine_w0)
|
57 |
else:
|
58 |
self.activation = nn.ReLU()
|
59 |
in_channels = args.encoder_dims[num_layer-1] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
|
|
|
76 |
print("Using CNN encoder wit activation: ", args.activation, 'args avg_output: ', args.avg_output)
|
77 |
if args.activation == 'silu':
|
78 |
self.activation = nn.SiLU()
|
79 |
+
elif args.activation == 'sine':
|
80 |
+
self.activation = Sine(w0=args.sine_w0)
|
81 |
else:
|
82 |
self.activation = nn.ReLU()
|
83 |
self.embedding = nn.Sequential(nn.Conv1d(in_channels = args.in_channels,
|
|
|
172 |
x = x.mean(dim=1)
|
173 |
return self.kan(x)
|
174 |
|
175 |
+
class CNNKanFeaturesEncoder(nn.Module):
|
176 |
+
def __init__(self, args, mlp_args, kan_args):
|
177 |
+
super().__init__()
|
178 |
+
self.backbone = CNNEncoder(args)
|
179 |
+
self.mlp = MLPEncoder(mlp_args)
|
180 |
+
kan_args['layers_hidden'][0] += self.mlp.output_dim
|
181 |
+
self.kan = FasterKAN(**kan_args)
|
182 |
+
|
183 |
+
def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
|
184 |
+
x = self.backbone(x)
|
185 |
+
x = x.mean(dim=1)
|
186 |
+
f = self.mlp(f)
|
187 |
+
x_f = torch.cat([x, f], dim=-1)
|
188 |
+
return self.kan(x_f)
|
189 |
+
|
190 |
class KanEncoder(nn.Module):
|
191 |
def __init__(self, args):
|
192 |
super().__init__()
|
|
|
200 |
out = torch.cat([x, f], dim=-1)
|
201 |
return self.kan_out(out)
|
202 |
|
203 |
+
|
204 |
+
class MultiGraph(nn.Module):
|
205 |
+
def __init__(self, graph_net, args):
|
206 |
+
super().__init__()
|
207 |
+
self.graph_net = graph_net
|
208 |
+
self.cnn = CNNEncoder(args)
|
209 |
+
total_output_dim = args.encoder_dims[-1]
|
210 |
+
self.projection = nn.Sequential(
|
211 |
+
nn.Linear(total_output_dim, total_output_dim // 2),
|
212 |
+
nn.BatchNorm1d(total_output_dim // 2),
|
213 |
+
nn.SiLU(),
|
214 |
+
nn.Linear(total_output_dim // 2, 1)
|
215 |
+
)
|
216 |
+
|
217 |
+
def forward(self, g: torch.Tensor, x:torch.Tensor) -> torch.Tensor:
|
218 |
+
# g_out = self.graph_net(g)
|
219 |
+
x_out = self.cnn(x)
|
220 |
+
# g_out = g_out.expand(x.shape[0], -1)
|
221 |
+
# features = torch.cat([g_out, x_out], dim=-1)
|
222 |
+
return self.projection(x_out)
|
223 |
+
|
224 |
+
class ImplicitEncoder(nn.Module):
|
225 |
+
def __init__(self, transform_net, encoder_net):
|
226 |
+
super().__init__()
|
227 |
+
self.transform_net = transform_net
|
228 |
+
self.encoder_net = encoder_net
|
229 |
+
|
230 |
+
def get_weights_and_bises(self):
|
231 |
+
state_dict = self.transform_net.state_dict()
|
232 |
+
weights = tuple(
|
233 |
+
[v.permute(1, 0).unsqueeze(-1).unsqueeze(0) for w, v in state_dict.items() if "weight" in w]
|
234 |
+
)
|
235 |
+
biases = tuple([v.unsqueeze(-1).unsqueeze(0) for w, v in state_dict.items() if "bias" in w])
|
236 |
+
return weights, biases
|
237 |
+
|
238 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
239 |
+
transformed_x = self.transform_net(x.permute(0, 2, 1)).permute(0, 2, 1)
|
240 |
+
inputs = self.get_weights_and_bises()
|
241 |
+
outputs = self.encoder_net(inputs, transformed_x)
|
242 |
+
return outputs
|
243 |
+
|
tasks/utils/pooling.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch_geometric.nn.aggr import (
|
4 |
+
AttentionalAggregation,
|
5 |
+
GraphMultisetTransformer,
|
6 |
+
MaxAggregation,
|
7 |
+
MeanAggregation,
|
8 |
+
SetTransformerAggregation,
|
9 |
+
)
|
10 |
+
|
11 |
+
class CatAggregation(nn.Module):
|
12 |
+
def __init__(self):
|
13 |
+
super().__init__()
|
14 |
+
self.flatten = nn.Flatten(1, 2)
|
15 |
+
|
16 |
+
def forward(self, x, index=None):
|
17 |
+
return self.flatten(x)
|
18 |
+
|
19 |
+
|
20 |
+
class HeterogeneousAggregator(nn.Module):
|
21 |
+
def __init__(
|
22 |
+
self,
|
23 |
+
input_dim,
|
24 |
+
hidden_dim,
|
25 |
+
output_dim,
|
26 |
+
pooling_method,
|
27 |
+
pooling_layer_idx,
|
28 |
+
input_channels,
|
29 |
+
num_classes,
|
30 |
+
):
|
31 |
+
super().__init__()
|
32 |
+
self.pooling_method = pooling_method
|
33 |
+
self.pooling_layer_idx = pooling_layer_idx
|
34 |
+
self.input_channels = input_channels
|
35 |
+
self.num_classes = num_classes
|
36 |
+
|
37 |
+
if pooling_layer_idx == "all":
|
38 |
+
self._pool_layer_idx_fn = self.get_all_layer_indices
|
39 |
+
elif pooling_layer_idx == "last":
|
40 |
+
self._pool_layer_idx_fn = self.get_last_layer_indices
|
41 |
+
elif isinstance(pooling_layer_idx, int):
|
42 |
+
self._pool_layer_idx_fn = self.get_nth_layer_indices
|
43 |
+
else:
|
44 |
+
raise ValueError(f"Unknown pooling layer index {pooling_layer_idx}")
|
45 |
+
|
46 |
+
if pooling_method == "mean":
|
47 |
+
self.pool = MeanAggregation()
|
48 |
+
elif pooling_method == "max":
|
49 |
+
self.pool = MaxAggregation()
|
50 |
+
elif pooling_method == "cat":
|
51 |
+
self.pool = CatAggregation()
|
52 |
+
elif pooling_method == "attentional_aggregation":
|
53 |
+
self.pool = AttentionalAggregation(
|
54 |
+
gate_nn=nn.Sequential(
|
55 |
+
nn.Linear(input_dim, hidden_dim),
|
56 |
+
nn.SiLU(),
|
57 |
+
nn.Linear(hidden_dim, 1),
|
58 |
+
),
|
59 |
+
nn=nn.Sequential(
|
60 |
+
nn.Linear(input_dim, hidden_dim),
|
61 |
+
nn.SiLU(),
|
62 |
+
nn.Linear(hidden_dim, output_dim),
|
63 |
+
),
|
64 |
+
)
|
65 |
+
elif pooling_method == "set_transformer":
|
66 |
+
self.pool = SetTransformerAggregation(
|
67 |
+
input_dim, heads=8, num_encoder_blocks=4, num_decoder_blocks=4
|
68 |
+
)
|
69 |
+
elif pooling_method == "graph_multiset_transformer":
|
70 |
+
self.pool = GraphMultisetTransformer(input_dim, k=8, heads=8)
|
71 |
+
else:
|
72 |
+
raise ValueError(f"Unknown pooling method {pooling_method}")
|
73 |
+
|
74 |
+
def get_last_layer_indices(
|
75 |
+
self, x, layer_layouts, node_mask=None, return_dense=False
|
76 |
+
):
|
77 |
+
batch_size = x.shape[0]
|
78 |
+
device = x.device
|
79 |
+
|
80 |
+
# NOTE: node_mask needs to exist in the heterogeneous case only
|
81 |
+
if node_mask is None:
|
82 |
+
node_mask = torch.ones_like(x[..., 0], dtype=torch.bool, device=device)
|
83 |
+
|
84 |
+
valid_layer_indices = (
|
85 |
+
torch.arange(node_mask.shape[1], device=device)[None, :] * node_mask
|
86 |
+
)
|
87 |
+
last_layer_indices = valid_layer_indices.topk(
|
88 |
+
k=self.num_classes, dim=1
|
89 |
+
).values.fliplr()
|
90 |
+
|
91 |
+
if return_dense:
|
92 |
+
return torch.arange(batch_size, device=device)[:, None], last_layer_indices
|
93 |
+
|
94 |
+
batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
|
95 |
+
self.num_classes
|
96 |
+
)
|
97 |
+
return batch_indices, last_layer_indices.flatten()
|
98 |
+
|
99 |
+
def get_nth_layer_indices(
|
100 |
+
self, x, layer_layouts, node_mask=None, return_dense=False
|
101 |
+
):
|
102 |
+
batch_size = x.shape[0]
|
103 |
+
device = x.device
|
104 |
+
|
105 |
+
cum_layer_layout = [
|
106 |
+
torch.cumsum(torch.tensor([0] + layer_layout), dim=0)
|
107 |
+
for layer_layout in layer_layouts
|
108 |
+
]
|
109 |
+
|
110 |
+
layer_sizes = torch.tensor(
|
111 |
+
[layer_layout[self.pooling_layer_idx] for layer_layout in layer_layouts],
|
112 |
+
dtype=torch.long,
|
113 |
+
device=device,
|
114 |
+
)
|
115 |
+
batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
|
116 |
+
layer_sizes
|
117 |
+
)
|
118 |
+
layer_indices = torch.cat(
|
119 |
+
[
|
120 |
+
torch.arange(
|
121 |
+
layout[self.pooling_layer_idx],
|
122 |
+
layout[self.pooling_layer_idx + 1],
|
123 |
+
device=device,
|
124 |
+
)
|
125 |
+
for layout in cum_layer_layout
|
126 |
+
]
|
127 |
+
)
|
128 |
+
return batch_indices, layer_indices
|
129 |
+
|
130 |
+
def get_all_layer_indices(
|
131 |
+
self, x, layer_layouts, node_mask=None, return_dense=False
|
132 |
+
):
|
133 |
+
"""Imitate flattening with indexing"""
|
134 |
+
batch_size, num_nodes = x.shape[:2]
|
135 |
+
device = x.device
|
136 |
+
batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
|
137 |
+
num_nodes
|
138 |
+
)
|
139 |
+
layer_indices = torch.arange(num_nodes, device=device).repeat(batch_size)
|
140 |
+
return batch_indices, layer_indices
|
141 |
+
|
142 |
+
def forward(self, x, layer_layouts, node_mask=None):
|
143 |
+
# NOTE: `cat` only works with `pooling_layer_idx == "last"`
|
144 |
+
return_dense = self.pooling_method == "cat" and self.pooling_layer_idx == "last"
|
145 |
+
batch_indices, layer_indices = self._pool_layer_idx_fn(
|
146 |
+
x, layer_layouts, node_mask=node_mask, return_dense=return_dense
|
147 |
+
)
|
148 |
+
|
149 |
+
flat_x = x[batch_indices, layer_indices]
|
150 |
+
return self.pool(flat_x, index=batch_indices)
|
151 |
+
|
152 |
+
|
153 |
+
class HomogeneousAggregator(nn.Module):
|
154 |
+
def __init__(
|
155 |
+
self,
|
156 |
+
pooling_method,
|
157 |
+
pooling_layer_idx,
|
158 |
+
layer_layout,
|
159 |
+
):
|
160 |
+
super().__init__()
|
161 |
+
self.pooling_method = pooling_method
|
162 |
+
self.pooling_layer_idx = pooling_layer_idx
|
163 |
+
self.layer_layout = layer_layout
|
164 |
+
|
165 |
+
def forward(self, node_features, edge_features):
|
166 |
+
if self.pooling_method == "mean" and self.pooling_layer_idx == "all":
|
167 |
+
graph_features = node_features.mean(dim=1)
|
168 |
+
elif self.pooling_method == "max" and self.pooling_layer_idx == "all":
|
169 |
+
graph_features = node_features.max(dim=1).values
|
170 |
+
elif self.pooling_method == "mean" and self.pooling_layer_idx == "last":
|
171 |
+
graph_features = node_features[:, -self.layer_layout[-1] :].mean(dim=1)
|
172 |
+
elif self.pooling_method == "cat" and self.pooling_layer_idx == "last":
|
173 |
+
graph_features = node_features[:, -self.layer_layout[-1] :].flatten(1, 2)
|
174 |
+
elif self.pooling_method == "mean" and isinstance(self.pooling_layer_idx, int):
|
175 |
+
graph_features = node_features[
|
176 |
+
:,
|
177 |
+
self.layer_idx[self.pooling_layer_idx] : self.layer_idx[
|
178 |
+
self.pooling_layer_idx + 1
|
179 |
+
],
|
180 |
+
].mean(dim=1)
|
181 |
+
elif self.pooling_method == "cat_mean" and self.pooling_layer_idx == "all":
|
182 |
+
graph_features = torch.cat(
|
183 |
+
[
|
184 |
+
node_features[:, self.layer_idx[i] : self.layer_idx[i + 1]].mean(
|
185 |
+
dim=1
|
186 |
+
)
|
187 |
+
for i in range(len(self.layer_layout))
|
188 |
+
],
|
189 |
+
dim=1,
|
190 |
+
)
|
191 |
+
elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "all":
|
192 |
+
graph_features = edge_features.mean(dim=(1, 2))
|
193 |
+
elif self.pooling_method == "max_edge" and self.pooling_layer_idx == "all":
|
194 |
+
graph_features = edge_features.flatten(1, 2).max(dim=1).values
|
195 |
+
elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "last":
|
196 |
+
graph_features = edge_features[:, :, -self.layer_layout[-1] :].mean(
|
197 |
+
dim=(1, 2)
|
198 |
+
)
|
199 |
+
return graph_features
|
tasks/utils/probe_features.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hydra
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
from einops.layers.torch import Rearrange
|
5 |
+
|
6 |
+
from nn.inr import make_functional, params_to_tensor, wrap_func
|
7 |
+
|
8 |
+
|
9 |
+
class GraphProbeFeatures(nn.Module):
|
10 |
+
def __init__(self, d_in, num_inputs, inr_model, input_init=None, proj_dim=None):
|
11 |
+
super().__init__()
|
12 |
+
inr = hydra.utils.instantiate(inr_model)
|
13 |
+
fmodel, params = make_functional(inr)
|
14 |
+
|
15 |
+
vparams, vshapes = params_to_tensor(params)
|
16 |
+
self.sirens = torch.vmap(wrap_func(fmodel, vshapes))
|
17 |
+
|
18 |
+
inputs = (
|
19 |
+
input_init
|
20 |
+
if input_init is not None
|
21 |
+
else 2 * torch.rand(1, num_inputs, d_in) - 1
|
22 |
+
)
|
23 |
+
self.inputs = nn.Parameter(inputs, requires_grad=input_init is None)
|
24 |
+
|
25 |
+
self.reshape_weights = Rearrange("b i o 1 -> b (o i)")
|
26 |
+
self.reshape_biases = Rearrange("b o 1 -> b o")
|
27 |
+
|
28 |
+
self.proj_dim = proj_dim
|
29 |
+
if proj_dim is not None:
|
30 |
+
self.proj = nn.ModuleList(
|
31 |
+
[
|
32 |
+
nn.Sequential(
|
33 |
+
nn.Linear(num_inputs, proj_dim),
|
34 |
+
nn.LayerNorm(proj_dim),
|
35 |
+
)
|
36 |
+
for _ in range(inr.num_layers + 1)
|
37 |
+
]
|
38 |
+
)
|
39 |
+
|
40 |
+
def forward(self, weights, biases):
|
41 |
+
weights = [self.reshape_weights(w) for w in weights]
|
42 |
+
biases = [self.reshape_biases(b) for b in biases]
|
43 |
+
params_flat = torch.cat(
|
44 |
+
[w_or_b for p in zip(weights, biases) for w_or_b in p], dim=-1
|
45 |
+
)
|
46 |
+
|
47 |
+
out = self.sirens(params_flat, self.inputs.expand(params_flat.shape[0], -1, -1))
|
48 |
+
if self.proj_dim is not None:
|
49 |
+
out = [proj(out[i].permute(0, 2, 1)) for i, proj in enumerate(self.proj)]
|
50 |
+
out = torch.cat(out, dim=1)
|
51 |
+
return out
|
52 |
+
else:
|
53 |
+
out = torch.cat(out, dim=-1)
|
54 |
+
return out.permute(0, 2, 1)
|
tasks/utils/relational_transformer.py
ADDED
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import hydra
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops.layers.torch import Rearrange
|
5 |
+
|
6 |
+
from utils.pooling import HomogeneousAggregator
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
|
10 |
+
class RelationalTransformer(nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
d_node,
|
14 |
+
d_edge,
|
15 |
+
d_attn_hid,
|
16 |
+
d_node_hid,
|
17 |
+
d_edge_hid,
|
18 |
+
d_out_hid,
|
19 |
+
d_out,
|
20 |
+
n_layers,
|
21 |
+
n_heads,
|
22 |
+
layer_layout,
|
23 |
+
graph_constructor,
|
24 |
+
dropout=0.0,
|
25 |
+
node_update_type="rt",
|
26 |
+
disable_edge_updates=False,
|
27 |
+
use_cls_token=False,
|
28 |
+
pooling_method="cat",
|
29 |
+
pooling_layer_idx="last",
|
30 |
+
rev_edge_features=False,
|
31 |
+
modulate_v=True,
|
32 |
+
use_ln=True,
|
33 |
+
tfixit_init=False,
|
34 |
+
):
|
35 |
+
super().__init__()
|
36 |
+
assert use_cls_token == (pooling_method == "cls_token")
|
37 |
+
self.pooling_method = pooling_method
|
38 |
+
self.pooling_layer_idx = pooling_layer_idx
|
39 |
+
self.rev_edge_features = rev_edge_features
|
40 |
+
self.nodes_per_layer = layer_layout
|
41 |
+
self.construct_graph = hydra.utils.instantiate(
|
42 |
+
graph_constructor,
|
43 |
+
d_node=d_node,
|
44 |
+
d_edge=d_edge,
|
45 |
+
layer_layout=layer_layout,
|
46 |
+
rev_edge_features=rev_edge_features,
|
47 |
+
)
|
48 |
+
|
49 |
+
self.use_cls_token = use_cls_token
|
50 |
+
if use_cls_token:
|
51 |
+
self.cls_token = nn.Parameter(torch.randn(d_node))
|
52 |
+
|
53 |
+
self.layers = nn.ModuleList(
|
54 |
+
[
|
55 |
+
torch.jit.script(
|
56 |
+
RTLayer(
|
57 |
+
d_node,
|
58 |
+
d_edge,
|
59 |
+
d_attn_hid,
|
60 |
+
d_node_hid,
|
61 |
+
d_edge_hid,
|
62 |
+
n_heads,
|
63 |
+
dropout,
|
64 |
+
node_update_type=node_update_type,
|
65 |
+
disable_edge_updates=(
|
66 |
+
(disable_edge_updates or (i == n_layers - 1))
|
67 |
+
and pooling_method != "mean_edge"
|
68 |
+
and pooling_layer_idx != "all"
|
69 |
+
),
|
70 |
+
modulate_v=modulate_v,
|
71 |
+
use_ln=use_ln,
|
72 |
+
tfixit_init=tfixit_init,
|
73 |
+
n_layers=n_layers,
|
74 |
+
)
|
75 |
+
)
|
76 |
+
for i in range(n_layers)
|
77 |
+
]
|
78 |
+
)
|
79 |
+
|
80 |
+
if pooling_method != "cls_token":
|
81 |
+
self.pool = HomogeneousAggregator(
|
82 |
+
pooling_method,
|
83 |
+
pooling_layer_idx,
|
84 |
+
layer_layout,
|
85 |
+
)
|
86 |
+
|
87 |
+
self.num_graph_features = (
|
88 |
+
layer_layout[-1] * d_node
|
89 |
+
if pooling_method == "cat" and pooling_layer_idx == "last"
|
90 |
+
else d_edge if pooling_method in ("mean_edge", "max_edge") else d_node
|
91 |
+
)
|
92 |
+
self.proj_out = nn.Sequential(
|
93 |
+
nn.Linear(self.num_graph_features, d_out_hid),
|
94 |
+
nn.ReLU(),
|
95 |
+
# nn.Linear(d_out_hid, d_out_hid),
|
96 |
+
# nn.ReLU(),
|
97 |
+
nn.Linear(d_out_hid, d_out),
|
98 |
+
)
|
99 |
+
|
100 |
+
self.final_features = (None,None,None,None)
|
101 |
+
|
102 |
+
def forward(self, inputs):
|
103 |
+
attn_weights = None
|
104 |
+
node_features, edge_features, mask = self.construct_graph(inputs)
|
105 |
+
if self.use_cls_token:
|
106 |
+
node_features = torch.cat(
|
107 |
+
[
|
108 |
+
# repeat(self.cls_token, "d -> b 1 d", b=node_features.size(0)),
|
109 |
+
self.cls_token.unsqueeze(0).expand(node_features.size(0), 1, -1),
|
110 |
+
node_features,
|
111 |
+
],
|
112 |
+
dim=1,
|
113 |
+
)
|
114 |
+
edge_features = F.pad(edge_features, (0, 0, 1, 0, 1, 0), value=0)
|
115 |
+
for layer in self.layers:
|
116 |
+
node_features, edge_features, attn_weights = layer(node_features, edge_features, mask)
|
117 |
+
|
118 |
+
if self.pooling_method == "cls_token":
|
119 |
+
graph_features = node_features[:, 0]
|
120 |
+
else:
|
121 |
+
graph_features = self.pool(node_features, edge_features)
|
122 |
+
self.final_features = (graph_features, node_features, edge_features, attn_weights)
|
123 |
+
return self.proj_out(graph_features)
|
124 |
+
|
125 |
+
|
126 |
+
class RTLayer(nn.Module):
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
d_node,
|
130 |
+
d_edge,
|
131 |
+
d_attn_hid,
|
132 |
+
d_node_hid,
|
133 |
+
d_edge_hid,
|
134 |
+
n_heads,
|
135 |
+
dropout,
|
136 |
+
node_update_type="rt",
|
137 |
+
disable_edge_updates=False,
|
138 |
+
modulate_v=True,
|
139 |
+
use_ln=True,
|
140 |
+
tfixit_init=False,
|
141 |
+
n_layers=None,
|
142 |
+
):
|
143 |
+
super().__init__()
|
144 |
+
self.node_update_type = node_update_type
|
145 |
+
self.disable_edge_updates = disable_edge_updates
|
146 |
+
self.use_ln = use_ln
|
147 |
+
self.n_layers = n_layers
|
148 |
+
|
149 |
+
self.self_attn = torch.jit.script(
|
150 |
+
RTAttention(
|
151 |
+
d_node,
|
152 |
+
d_edge,
|
153 |
+
d_attn_hid,
|
154 |
+
n_heads,
|
155 |
+
modulate_v=modulate_v,
|
156 |
+
use_ln=use_ln,
|
157 |
+
)
|
158 |
+
)
|
159 |
+
# self.self_attn = RTAttention(d_hid, d_hid, d_hid, n_heads)
|
160 |
+
self.lin0 = Linear(d_node, d_node)
|
161 |
+
self.dropout0 = nn.Dropout(dropout)
|
162 |
+
if use_ln:
|
163 |
+
self.node_ln0 = nn.LayerNorm(d_node)
|
164 |
+
self.node_ln1 = nn.LayerNorm(d_node)
|
165 |
+
else:
|
166 |
+
self.node_ln0 = nn.Identity()
|
167 |
+
self.node_ln1 = nn.Identity()
|
168 |
+
|
169 |
+
act_fn = nn.GELU
|
170 |
+
|
171 |
+
self.node_mlp = nn.Sequential(
|
172 |
+
Linear(d_node, d_node_hid, bias=False),
|
173 |
+
act_fn(),
|
174 |
+
Linear(d_node_hid, d_node),
|
175 |
+
nn.Dropout(dropout),
|
176 |
+
)
|
177 |
+
|
178 |
+
if not self.disable_edge_updates:
|
179 |
+
self.edge_updates = EdgeLayer(
|
180 |
+
d_node=d_node,
|
181 |
+
d_edge=d_edge,
|
182 |
+
d_edge_hid=d_edge_hid,
|
183 |
+
dropout=dropout,
|
184 |
+
act_fn=act_fn,
|
185 |
+
use_ln=use_ln,
|
186 |
+
)
|
187 |
+
else:
|
188 |
+
self.edge_updates = NoEdgeLayer()
|
189 |
+
|
190 |
+
if tfixit_init:
|
191 |
+
self.fixit_init()
|
192 |
+
|
193 |
+
def fixit_init(self):
|
194 |
+
temp_state_dict = self.state_dict()
|
195 |
+
n_layers = self.n_layers
|
196 |
+
for name, param in self.named_parameters():
|
197 |
+
if "weight" in name:
|
198 |
+
if name.split(".")[0] in ["node_mlp", "edge_mlp0", "edge_mlp1"]:
|
199 |
+
temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * param
|
200 |
+
elif name.split(".")[0] in ["self_attn"]:
|
201 |
+
temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * (
|
202 |
+
param * (2**0.5)
|
203 |
+
)
|
204 |
+
|
205 |
+
self.load_state_dict(temp_state_dict)
|
206 |
+
|
207 |
+
def node_updates(self, node_features, edge_features, mask):
|
208 |
+
out = self.self_attn(node_features, edge_features, mask)
|
209 |
+
attn_out, attn_weights = out
|
210 |
+
node_features = self.node_ln0(
|
211 |
+
node_features
|
212 |
+
+ self.dropout0(
|
213 |
+
self.lin0(attn_out)
|
214 |
+
)
|
215 |
+
)
|
216 |
+
node_features = self.node_ln1(node_features + self.node_mlp(node_features))
|
217 |
+
|
218 |
+
return node_features, attn_weights
|
219 |
+
|
220 |
+
def forward(self, node_features, edge_features, mask):
|
221 |
+
node_features, attn_weights = self.node_updates(node_features, edge_features, mask)
|
222 |
+
edge_features = self.edge_updates(node_features, edge_features, mask)
|
223 |
+
|
224 |
+
return node_features, edge_features, attn_weights
|
225 |
+
|
226 |
+
|
227 |
+
class EdgeLayer(nn.Module):
|
228 |
+
def __init__(
|
229 |
+
self,
|
230 |
+
*,
|
231 |
+
d_node,
|
232 |
+
d_edge,
|
233 |
+
d_edge_hid,
|
234 |
+
dropout,
|
235 |
+
act_fn,
|
236 |
+
use_ln=True,
|
237 |
+
) -> None:
|
238 |
+
super().__init__()
|
239 |
+
self.edge_mlp0 = EdgeMLP(
|
240 |
+
d_edge=d_edge,
|
241 |
+
d_node=d_node,
|
242 |
+
d_edge_hid=d_edge_hid,
|
243 |
+
act_fn=act_fn,
|
244 |
+
dropout=dropout,
|
245 |
+
)
|
246 |
+
self.edge_mlp1 = nn.Sequential(
|
247 |
+
Linear(d_edge, d_edge_hid, bias=False),
|
248 |
+
act_fn(),
|
249 |
+
Linear(d_edge_hid, d_edge),
|
250 |
+
nn.Dropout(dropout),
|
251 |
+
)
|
252 |
+
if use_ln:
|
253 |
+
self.eln0 = nn.LayerNorm(d_edge)
|
254 |
+
self.eln1 = nn.LayerNorm(d_edge)
|
255 |
+
else:
|
256 |
+
self.eln0 = nn.Identity()
|
257 |
+
self.eln1 = nn.Identity()
|
258 |
+
|
259 |
+
def forward(self, node_features, edge_features, mask):
|
260 |
+
edge_features = self.eln0(
|
261 |
+
edge_features + self.edge_mlp0(node_features, edge_features)
|
262 |
+
)
|
263 |
+
edge_features = self.eln1(edge_features + self.edge_mlp1(edge_features))
|
264 |
+
return edge_features
|
265 |
+
|
266 |
+
|
267 |
+
class NoEdgeLayer(nn.Module):
|
268 |
+
def forward(self, node_features, edge_features, mask):
|
269 |
+
return edge_features
|
270 |
+
|
271 |
+
|
272 |
+
class EdgeMLP(nn.Module):
|
273 |
+
def __init__(self, *, d_node, d_edge, d_edge_hid, act_fn, dropout):
|
274 |
+
super().__init__()
|
275 |
+
self.reverse_edge = Rearrange("b n m d -> b m n d")
|
276 |
+
self.lin0_e = Linear(2 * d_edge, d_edge_hid)
|
277 |
+
self.lin0_s = Linear(d_node, d_edge_hid)
|
278 |
+
self.lin0_t = Linear(d_node, d_edge_hid)
|
279 |
+
self.act = act_fn()
|
280 |
+
self.lin1 = Linear(d_edge_hid, d_edge)
|
281 |
+
self.drop = nn.Dropout(dropout)
|
282 |
+
|
283 |
+
def forward(self, node_features, edge_features):
|
284 |
+
source_nodes = (
|
285 |
+
self.lin0_s(node_features)
|
286 |
+
.unsqueeze(-2)
|
287 |
+
.expand(-1, -1, node_features.size(-2), -1)
|
288 |
+
)
|
289 |
+
target_nodes = (
|
290 |
+
self.lin0_t(node_features)
|
291 |
+
.unsqueeze(-3)
|
292 |
+
.expand(-1, node_features.size(-2), -1, -1)
|
293 |
+
)
|
294 |
+
|
295 |
+
# reversed_edge_features = self.reverse_edge(edge_features)
|
296 |
+
edge_features = self.lin0_e(
|
297 |
+
torch.cat([edge_features, self.reverse_edge(edge_features)], dim=-1)
|
298 |
+
)
|
299 |
+
edge_features = edge_features + source_nodes + target_nodes
|
300 |
+
edge_features = self.act(edge_features)
|
301 |
+
edge_features = self.lin1(edge_features)
|
302 |
+
edge_features = self.drop(edge_features)
|
303 |
+
|
304 |
+
return edge_features
|
305 |
+
|
306 |
+
|
307 |
+
class RTAttention(nn.Module):
|
308 |
+
def __init__(self, d_node, d_edge, d_hid, n_heads, modulate_v=None, use_ln=True):
|
309 |
+
super().__init__()
|
310 |
+
self.n_heads = n_heads
|
311 |
+
self.d_node = d_node
|
312 |
+
self.d_edge = d_edge
|
313 |
+
self.d_hid = d_hid
|
314 |
+
self.use_ln = use_ln
|
315 |
+
self.modulate_v = modulate_v
|
316 |
+
self.scale = 1 / (d_hid**0.5)
|
317 |
+
self.split_head_node = Rearrange("b n (h d) -> b h n d", h=n_heads)
|
318 |
+
self.split_head_edge = Rearrange("b n m (h d) -> b h n m d", h=n_heads)
|
319 |
+
self.cat_head_node = Rearrange("... h n d -> ... n (h d)", h=n_heads)
|
320 |
+
|
321 |
+
self.qkv_node = Linear(d_node, 3 * d_hid, bias=False)
|
322 |
+
self.edge_factor = 4 if modulate_v else 3
|
323 |
+
self.qkv_edge = Linear(d_edge, self.edge_factor * d_hid, bias=False)
|
324 |
+
self.proj_out = Linear(d_hid, d_node)
|
325 |
+
|
326 |
+
def forward(self, node_features, edge_features, mask):
|
327 |
+
qkv_node = self.qkv_node(node_features)
|
328 |
+
# qkv_node = rearrange(qkv_node, "b n (h d) -> b h n d", h=self.n_heads)
|
329 |
+
qkv_node = self.split_head_node(qkv_node)
|
330 |
+
q_node, k_node, v_node = torch.chunk(qkv_node, 3, dim=-1)
|
331 |
+
|
332 |
+
qkv_edge = self.qkv_edge(edge_features)
|
333 |
+
# qkv_edge = rearrange(qkv_edge, "b n m (h d) -> b h n m d", h=self.n_heads)
|
334 |
+
qkv_edge = self.split_head_edge(qkv_edge)
|
335 |
+
qkv_edge = torch.chunk(qkv_edge, self.edge_factor, dim=-1)
|
336 |
+
# q_edge, k_edge, v_edge, q_edge_b, k_edge_b, v_edge_b = torch.chunk(
|
337 |
+
# qkv_edge, 6, dim=-1
|
338 |
+
# )
|
339 |
+
# qkv_edge = [item.masked_fill(mask.unsqueeze(1) == 0, 0) for item in qkv_edge]
|
340 |
+
|
341 |
+
q = q_node.unsqueeze(-2) + qkv_edge[0] # + q_edge_b
|
342 |
+
k = k_node.unsqueeze(-3) + qkv_edge[1] # + k_edge_b
|
343 |
+
if self.modulate_v:
|
344 |
+
v = v_node.unsqueeze(-3) * qkv_edge[3] + qkv_edge[2]
|
345 |
+
else:
|
346 |
+
v = v_node.unsqueeze(-3) + qkv_edge[2]
|
347 |
+
dots = self.scale * torch.einsum("b h i j d, b h i j d -> b h i j", q, k)
|
348 |
+
# dots.masked_fill_(mask.unsqueeze(1).squeeze(-1) == 0, -1e-9)
|
349 |
+
|
350 |
+
attn = F.softmax(dots, dim=-1)
|
351 |
+
out = torch.einsum("b h i j, b h i j d -> b h i d", attn, v)
|
352 |
+
out = self.cat_head_node(out)
|
353 |
+
return self.proj_out(out), attn
|
354 |
+
|
355 |
+
|
356 |
+
def Linear(in_features, out_features, bias=True):
|
357 |
+
m = nn.Linear(in_features, out_features, bias)
|
358 |
+
nn.init.xavier_uniform_(m.weight) # , gain=1 / math.sqrt(2))
|
359 |
+
if bias:
|
360 |
+
nn.init.constant_(m.bias, 0.0)
|
361 |
+
return m
|
tasks/utils/train.py
CHANGED
@@ -9,6 +9,14 @@ import glob
|
|
9 |
from collections import OrderedDict
|
10 |
from tqdm import tqdm
|
11 |
import torch.distributed as dist
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
class Trainer(object):
|
14 |
"""
|
@@ -217,9 +225,12 @@ class Trainer(object):
|
|
217 |
return train_loss, all_accs/total
|
218 |
|
219 |
def train_batch(self, batch, batch_idx, device):
|
220 |
-
x, fft, y = batch['audio']['array'], batch['audio']['
|
|
|
|
|
221 |
x = x.to(device).float()
|
222 |
fft = fft.to(device).float()
|
|
|
223 |
y = y.to(device).float()
|
224 |
x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
|
225 |
y_pred = self.model(x_fft).squeeze()
|
@@ -255,7 +266,8 @@ class Trainer(object):
|
|
255 |
return val_loss, all_accs/total
|
256 |
|
257 |
def eval_batch(self, batch, batch_idx, device):
|
258 |
-
x, fft, y = batch['audio']['array'], batch['audio']['
|
|
|
259 |
x = x.to(device).float()
|
260 |
fft = fft.to(device).float()
|
261 |
x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
|
@@ -279,7 +291,7 @@ class Trainer(object):
|
|
279 |
true_labels = []
|
280 |
pbar = tqdm(test_dataloader)
|
281 |
for i,batch in enumerate(pbar):
|
282 |
-
x, fft, y = batch['audio']['array'], batch['audio']['
|
283 |
x = x.to(device).float()
|
284 |
fft = fft.to(device).float()
|
285 |
x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
|
@@ -297,4 +309,411 @@ class Trainer(object):
|
|
297 |
pbar.set_description("acc: {:.4f}".format(acc))
|
298 |
if i > self.max_iter:
|
299 |
break
|
300 |
-
return predictions, true_labels, all_accs/total
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from collections import OrderedDict
|
10 |
from tqdm import tqdm
|
11 |
import torch.distributed as dist
|
12 |
+
import pandas as pd
|
13 |
+
import xgboost as xgb
|
14 |
+
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
|
15 |
+
|
16 |
+
|
17 |
+
from torch.nn import ModuleList
|
18 |
+
# from inr import INR
|
19 |
+
# from kan import FasterKAN
|
20 |
|
21 |
class Trainer(object):
|
22 |
"""
|
|
|
225 |
return train_loss, all_accs/total
|
226 |
|
227 |
def train_batch(self, batch, batch_idx, device):
|
228 |
+
x, fft, y = batch['audio']['array'], batch['audio']['fft_mag'], batch['label']
|
229 |
+
# features = batch['audio']['features_arr'].to(device).float()
|
230 |
+
# cwt = batch['audio']['cwt_mag']
|
231 |
x = x.to(device).float()
|
232 |
fft = fft.to(device).float()
|
233 |
+
# cwt = cwt.to(device).float()
|
234 |
y = y.to(device).float()
|
235 |
x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
|
236 |
y_pred = self.model(x_fft).squeeze()
|
|
|
266 |
return val_loss, all_accs/total
|
267 |
|
268 |
def eval_batch(self, batch, batch_idx, device):
|
269 |
+
x, fft, y = batch['audio']['array'], batch['audio']['fft_mag'], batch['label']
|
270 |
+
# features = batch['audio']['features_arr'].to(device).float()
|
271 |
x = x.to(device).float()
|
272 |
fft = fft.to(device).float()
|
273 |
x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
|
|
|
291 |
true_labels = []
|
292 |
pbar = tqdm(test_dataloader)
|
293 |
for i,batch in enumerate(pbar):
|
294 |
+
x, fft, y = batch['audio']['array'], batch['audio']['fft_mag'], batch['label']
|
295 |
x = x.to(device).float()
|
296 |
fft = fft.to(device).float()
|
297 |
x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
|
|
|
309 |
pbar.set_description("acc: {:.4f}".format(acc))
|
310 |
if i > self.max_iter:
|
311 |
break
|
312 |
+
return predictions, true_labels, all_accs/total
|
313 |
+
|
314 |
+
|
315 |
+
class INRDatabase:
|
316 |
+
"""Database to store and manage INRs persistently."""
|
317 |
+
|
318 |
+
def __init__(self, save_dir='./inr_database'):
|
319 |
+
self.inrs = {} # Maps sample_id -> INR
|
320 |
+
self.optimizers = {} # Maps sample_id -> optimizer state
|
321 |
+
self.save_dir = save_dir
|
322 |
+
os.makedirs(save_dir, exist_ok=True)
|
323 |
+
|
324 |
+
def get_or_create_inr(self, sample_id, create_fn, device):
|
325 |
+
"""Get existing INR or create new one if not exists."""
|
326 |
+
if sample_id not in self.inrs:
|
327 |
+
# Create new INR
|
328 |
+
inr = create_fn().to(device)
|
329 |
+
optimizer = torch.optim.Adam(inr.parameters())
|
330 |
+
self.inrs[sample_id] = inr
|
331 |
+
self.optimizers[sample_id] = optimizer
|
332 |
+
return self.inrs[sample_id], self.optimizers[sample_id]
|
333 |
+
|
334 |
+
def set_inr(self, sample_id, inr, optimizer):
|
335 |
+
self.inrs[sample_id] = inr
|
336 |
+
self.optimizers[sample_id] = optimizer
|
337 |
+
|
338 |
+
def save_state(self):
|
339 |
+
"""Save all INRs and optimizer states to disk."""
|
340 |
+
state = {
|
341 |
+
'inrs': {
|
342 |
+
sample_id: inr.state_dict()
|
343 |
+
for sample_id, inr in self.inrs.items()
|
344 |
+
},
|
345 |
+
'optimizers': {
|
346 |
+
sample_id: opt.state_dict()
|
347 |
+
for sample_id, opt in self.optimizers.items()
|
348 |
+
}
|
349 |
+
}
|
350 |
+
torch.save(state, os.path.join(self.save_dir, 'inr_database.pt'))
|
351 |
+
|
352 |
+
def load_state(self, create_fn, device):
|
353 |
+
"""Load INRs and optimizer states from disk."""
|
354 |
+
path = os.path.join(self.save_dir, 'inr_database.pt')
|
355 |
+
if os.path.exists(path):
|
356 |
+
state = torch.load(path, map_location=device)
|
357 |
+
|
358 |
+
# Restore INRs
|
359 |
+
for sample_id, inr_state in state['inrs'].items():
|
360 |
+
inr = create_fn().to(device)
|
361 |
+
inr.load_state_dict(inr_state)
|
362 |
+
self.inrs[sample_id] = inr
|
363 |
+
|
364 |
+
# Restore optimizers
|
365 |
+
for sample_id, opt_state in state['optimizers'].items():
|
366 |
+
optimizer = torch.optim.Adam(self.inrs[sample_id].parameters())
|
367 |
+
optimizer.load_state_dict(opt_state)
|
368 |
+
self.optimizers[sample_id] = optimizer
|
369 |
+
|
370 |
+
|
371 |
+
class INRTrainer(Trainer):
|
372 |
+
def __init__(self, hidden_features=128, n_layers=3, in_features=1, out_features=1,
|
373 |
+
num_steps=5000, lr=1e-3, inr_criterion=torch.nn.MSELoss(), save_dir='./inr_database', *args, **kwargs):
|
374 |
+
super().__init__(*args, **kwargs)
|
375 |
+
self.hidden_features = hidden_features
|
376 |
+
self.n_layers = n_layers
|
377 |
+
self.in_features = in_features
|
378 |
+
self.out_features = out_features
|
379 |
+
self.num_steps = num_steps
|
380 |
+
self.lr = lr
|
381 |
+
self.inr_criterion = inr_criterion
|
382 |
+
|
383 |
+
# Initialize INR database
|
384 |
+
self.db = INRDatabase(save_dir)
|
385 |
+
|
386 |
+
# Load existing INRs if available
|
387 |
+
self.db.load_state(self.create_inr, self.device)
|
388 |
+
|
389 |
+
def create_inr(self):
|
390 |
+
"""Factory function to create new INR instances."""
|
391 |
+
return INR(
|
392 |
+
hidden_features=self.hidden_features,
|
393 |
+
n_layers=self.n_layers,
|
394 |
+
in_features=self.in_features,
|
395 |
+
out_features=self.out_features
|
396 |
+
)
|
397 |
+
|
398 |
+
def create_kan(self):
|
399 |
+
return FasterKAN(layers_hidden=[self.in_features] + [self.hidden_features] * (self.n_layers) + [self.out_features],)
|
400 |
+
|
401 |
+
def get_sample_id(self, batch, idx):
|
402 |
+
"""Extract unique identifier for a sample in the batch.
|
403 |
+
Override this method based on your data structure."""
|
404 |
+
# Example: if your batch contains unique IDs
|
405 |
+
if 'id' in batch:
|
406 |
+
return batch['id'][idx]
|
407 |
+
# Fallback: create hash from the sample data
|
408 |
+
sample_data = batch['audio']['array'][idx]
|
409 |
+
return hash(sample_data.cpu().numpy().tobytes())
|
410 |
+
|
411 |
+
def train_inr(self, optimizer, model, coords, values, num_iters=10, plot=False):
|
412 |
+
# pbar = tqdm(range(num_iters))
|
413 |
+
for _ in range(num_iters):
|
414 |
+
optimizer.zero_grad()
|
415 |
+
pred_values = model(coords.to(self.device)).float()
|
416 |
+
loss = self.inr_criterion(pred_values.squeeze(), values)
|
417 |
+
loss.backward()
|
418 |
+
optimizer.step()
|
419 |
+
# pbar.set_description(f'loss: {loss.item()}')
|
420 |
+
if plot:
|
421 |
+
plt.plot(values.cpu().detach().numpy())
|
422 |
+
plt.plot(pred_values.cpu().detach().numpy())
|
423 |
+
plt.title(loss.item())
|
424 |
+
plt.show()
|
425 |
+
return model, optimizer
|
426 |
+
|
427 |
+
def train_batch(self, batch, batch_idx, device):
|
428 |
+
"""Train INRs for each sample in batch, persisting progress."""
|
429 |
+
coords = batch['audio']['coords'].to(device) # [B, N, 1]
|
430 |
+
fft = batch['audio']['fft_mag'].to(device) # [B, N]
|
431 |
+
audio = batch['audio']['array'].to(device) # [B, N]
|
432 |
+
y = batch['label'].to(device).float()
|
433 |
+
|
434 |
+
batch_size = coords.shape[0]
|
435 |
+
|
436 |
+
values = audio
|
437 |
+
|
438 |
+
batch_losses = []
|
439 |
+
batch_optimizers = []
|
440 |
+
batch_inrs = []
|
441 |
+
batch_weights = tuple()
|
442 |
+
batch_biases = tuple()
|
443 |
+
# Training loop
|
444 |
+
# pbar = tqdm(range(self.num_steps), desc="Training INRs")
|
445 |
+
plot = batch_idx == 0
|
446 |
+
for i in range(batch_size):
|
447 |
+
sample_id = self.get_sample_id(batch, i)
|
448 |
+
inr, optimizer = self.db.get_or_create_inr(sample_id, self.create_inr, device)
|
449 |
+
inr, optimizer = self.train_inr(optimizer, inr, coords[i], values[i])
|
450 |
+
self.db.set_inr(sample_id, inr, optimizer)
|
451 |
+
# pred_values = inr(coords[i]).squeeze()
|
452 |
+
# batch_losses.append(self.inr_criterion(pred_values, values[i]))
|
453 |
+
# batch_optimizers.append(optimizer)
|
454 |
+
state_dict = inr.state_dict()
|
455 |
+
weights = tuple(
|
456 |
+
[v.permute(1, 0).unsqueeze(-1).unsqueeze(0).to(device) for w, v in state_dict.items() if "weight" in w]
|
457 |
+
)
|
458 |
+
biases = tuple([v.unsqueeze(-1).unsqueeze(0).to(device) for w, v in state_dict.items() if "bias" in w])
|
459 |
+
if not len(batch_weights):
|
460 |
+
batch_weights = weights
|
461 |
+
else:
|
462 |
+
batch_weights = tuple(
|
463 |
+
[torch.cat((weights[i], batch_weights[i]), dim=0) for i in range(len(weights))]
|
464 |
+
)
|
465 |
+
if not len(batch_biases):
|
466 |
+
batch_biases = biases
|
467 |
+
else:
|
468 |
+
batch_biases = tuple(
|
469 |
+
[torch.cat((biases[i], batch_biases[i]), dim=0) for i in range(len(biases))]
|
470 |
+
)
|
471 |
+
# loss_preds = torch.tensor([0])
|
472 |
+
# acc = 0
|
473 |
+
y_pred = self.model(inputs=(batch_weights, batch_biases)).squeeze()
|
474 |
+
loss_preds = self.criterion(y_pred, y)
|
475 |
+
self.optimizer.zero_grad()
|
476 |
+
loss_preds.backward()
|
477 |
+
self.optimizer.step()
|
478 |
+
# for i in range(batch_size):
|
479 |
+
# batch_optimizers[i].zero_grad()
|
480 |
+
# batch_losses[i] += loss_preds
|
481 |
+
# batch_losses[i].backward()
|
482 |
+
# batch_optimizers[i].step()
|
483 |
+
|
484 |
+
|
485 |
+
if batch_idx % 10 == 0: # Adjust frequency as needed
|
486 |
+
self.db.save_state()
|
487 |
+
|
488 |
+
probs = torch.sigmoid(y_pred)
|
489 |
+
cls_pred = (probs > 0.5).float()
|
490 |
+
acc = (cls_pred == y).sum()
|
491 |
+
|
492 |
+
|
493 |
+
return loss_preds, acc, y
|
494 |
+
|
495 |
+
def eval_batch(self, batch, batch_idx, device):
|
496 |
+
"""Evaluate INRs for each sample in batch."""
|
497 |
+
coords = batch['audio']['coords'].to(device)
|
498 |
+
fft = batch['audio']['fft_mag'].to(device)
|
499 |
+
audio = batch['audio']['array'].to(device)
|
500 |
+
|
501 |
+
batch_size = coords.shape[0]
|
502 |
+
# values = torch.cat((
|
503 |
+
# audio.unsqueeze(-1),
|
504 |
+
# fft.unsqueeze(-1)
|
505 |
+
# ), dim=-1)
|
506 |
+
values = audio
|
507 |
+
# Get INRs for each sample
|
508 |
+
batch_inrs = []
|
509 |
+
for i in range(batch_size):
|
510 |
+
sample_id = self.get_sample_id(batch, i)
|
511 |
+
inr, _ = self.db.get_or_create_inr(sample_id, self.create_inr, device)
|
512 |
+
batch_inrs.append(inr)
|
513 |
+
|
514 |
+
# Evaluate
|
515 |
+
with torch.no_grad():
|
516 |
+
all_preds = torch.stack([
|
517 |
+
inr(coords[i])
|
518 |
+
for i, inr in enumerate(batch_inrs)
|
519 |
+
])
|
520 |
+
|
521 |
+
batch_losses = torch.stack([
|
522 |
+
self.criterion(all_preds[i].squeeze(), values[i])
|
523 |
+
for i in range(batch_size)
|
524 |
+
])
|
525 |
+
|
526 |
+
avg_loss = batch_losses.mean().item()
|
527 |
+
|
528 |
+
acc = torch.zeros(self.output_dim, device=device)
|
529 |
+
y = values
|
530 |
+
|
531 |
+
return torch.tensor(avg_loss), acc, y
|
532 |
+
|
533 |
+
|
534 |
+
def verify_parallel_gradient_isolation(trainer, batch_size=4, sequence_length=1000):
|
535 |
+
"""
|
536 |
+
Verify that gradients remain isolated in parallel training.
|
537 |
+
"""
|
538 |
+
device = trainer.device
|
539 |
+
|
540 |
+
# Create test data
|
541 |
+
coords = torch.linspace(0, 1, sequence_length).unsqueeze(-1) # [N, 1]
|
542 |
+
coords = coords.unsqueeze(0).repeat(batch_size, 1, 1) # [B, N, 1]
|
543 |
+
|
544 |
+
# Create synthetic signals
|
545 |
+
targets = torch.stack([
|
546 |
+
torch.sin(2 * torch.pi * (i + 1) * coords.squeeze(-1))
|
547 |
+
for i in range(batch_size)
|
548 |
+
]).to(device)
|
549 |
+
|
550 |
+
# Create batch of INRs
|
551 |
+
inrs = trainer.create_batch_inrs()
|
552 |
+
|
553 |
+
# Store initial parameters
|
554 |
+
initial_params = [{name: param.clone().detach()
|
555 |
+
for name, param in inr.named_parameters()}
|
556 |
+
for inr in inrs]
|
557 |
+
|
558 |
+
# Create mock batch
|
559 |
+
batch = {
|
560 |
+
'audio': {
|
561 |
+
'coords': coords.to(device),
|
562 |
+
'fft_mag': targets.clone(),
|
563 |
+
'array': targets.clone()
|
564 |
+
}
|
565 |
+
}
|
566 |
+
|
567 |
+
# Run one training step
|
568 |
+
trainer.train_batch(batch, 0, device)
|
569 |
+
|
570 |
+
# Verify parameter changes
|
571 |
+
isolation_verified = True
|
572 |
+
for i, inr in enumerate(inrs):
|
573 |
+
params_changed = False
|
574 |
+
for name, param in inr.named_parameters():
|
575 |
+
if not torch.allclose(param, initial_params[i][name]):
|
576 |
+
params_changed = True
|
577 |
+
# Verify that the changes are unique to this INR
|
578 |
+
for j, other_inr in enumerate(inrs):
|
579 |
+
if i != j:
|
580 |
+
other_param = dict(other_inr.named_parameters())[name]
|
581 |
+
if not torch.allclose(other_param, initial_params[j][name]):
|
582 |
+
isolation_verified = False
|
583 |
+
print(f"Warning: Parameter {name} of INR {j} changed when only INR {i} should have changed")
|
584 |
+
|
585 |
+
return isolation_verified
|
586 |
+
|
587 |
+
class XGBoostTrainer():
|
588 |
+
def __init__(self, model_args, train_ds, val_ds, test_ds):
|
589 |
+
self.train_ds = train_ds
|
590 |
+
self.test_ds = test_ds
|
591 |
+
print("creating train dataframe...")
|
592 |
+
self.x_train, self.y_train = self.create_dataframe(train_ds, save_name='train')
|
593 |
+
print("creating validation dataframe...")
|
594 |
+
self.x_val, self.y_val = self.create_dataframe(val_ds, save_name='val')
|
595 |
+
print("creating test dataframe...")
|
596 |
+
self.x_test, self.y_test = self.create_dataframe(test_ds, save_name='test')
|
597 |
+
|
598 |
+
# Convert the data to DMatrix format
|
599 |
+
self.dtrain = xgb.DMatrix(self.x_train, label=self.y_train)
|
600 |
+
self.dval = xgb.DMatrix(self.x_val, label=self.y_val)
|
601 |
+
self.dtest = xgb.DMatrix(self.x_test, label=self.y_test)
|
602 |
+
|
603 |
+
# Model initialization
|
604 |
+
self.model_args = model_args
|
605 |
+
self.model = xgb.XGBClassifier(**model_args)
|
606 |
+
|
607 |
+
def create_dataframe(self, ds, save_name='train'):
|
608 |
+
try:
|
609 |
+
df = pd.read_csv(f"tasks/utils/dfs/{save_name}.csv")
|
610 |
+
except FileNotFoundError:
|
611 |
+
data = []
|
612 |
+
|
613 |
+
# Iterate over the dataset
|
614 |
+
pbar = tqdm(enumerate(ds))
|
615 |
+
for i, batch in pbar:
|
616 |
+
label = batch['label']
|
617 |
+
features = batch['audio']['features']
|
618 |
+
|
619 |
+
# Flatten the nested dictionary structure
|
620 |
+
feature_dict = {'label': label}
|
621 |
+
for k, v in features.items():
|
622 |
+
if isinstance(v, dict):
|
623 |
+
for sub_k, sub_v in v.items():
|
624 |
+
feature_dict[f"{k}_{sub_k}"] = sub_v[0].item() # Aggregate (e.g., mean)
|
625 |
+
data.append(feature_dict)
|
626 |
+
# Convert to DataFrame
|
627 |
+
df = pd.DataFrame(data)
|
628 |
+
print(os.getcwd())
|
629 |
+
df.to_csv(f"tasks/utils/dfs/{save_name}.csv", index=False)
|
630 |
+
X = df.drop(columns=['label'])
|
631 |
+
y = df['label']
|
632 |
+
return X, y
|
633 |
+
|
634 |
+
def fit(self):
|
635 |
+
# Train using the `train` method with early stopping
|
636 |
+
params = {
|
637 |
+
'objective': 'binary:logistic',
|
638 |
+
'eval_metric': 'logloss',
|
639 |
+
**self.model_args
|
640 |
+
}
|
641 |
+
|
642 |
+
evals_result = {}
|
643 |
+
num_boost_round = 1000 # Set a large number of boosting rounds
|
644 |
+
|
645 |
+
# Watchlist to monitor performance on train and validation data
|
646 |
+
watchlist = [(self.dtrain, 'train'), (self.dval, 'eval')]
|
647 |
+
|
648 |
+
# Train the model
|
649 |
+
self.model = xgb.train(
|
650 |
+
params,
|
651 |
+
self.dtrain,
|
652 |
+
num_boost_round=num_boost_round,
|
653 |
+
evals=watchlist,
|
654 |
+
early_stopping_rounds=10, # Early stopping after 10 rounds with no improvement
|
655 |
+
evals_result=evals_result,
|
656 |
+
verbose_eval=True # Show evaluation results for each iteration
|
657 |
+
)
|
658 |
+
|
659 |
+
return evals_result
|
660 |
+
|
661 |
+
def train_xgboost_in_batches(self, dataloader, eval_metric="logloss"):
|
662 |
+
evals_result = {}
|
663 |
+
for i, batch in enumerate(dataloader):
|
664 |
+
# Convert batch data to NumPy arrays
|
665 |
+
X_batch = torch.cat([batch[key].view(batch[key].size(0), -1) for key in batch if key != "label"],
|
666 |
+
dim=1).numpy()
|
667 |
+
y_batch = batch["label"].numpy()
|
668 |
+
|
669 |
+
# Create DMatrix for XGBoost
|
670 |
+
dtrain = xgb.DMatrix(X_batch, label=y_batch)
|
671 |
+
|
672 |
+
# Use `train` with each batch
|
673 |
+
self.model = xgb.train(
|
674 |
+
params,
|
675 |
+
dtrain,
|
676 |
+
num_boost_round=1000, # Use a large number of rounds
|
677 |
+
evals=[(self.dval, 'eval')],
|
678 |
+
eval_metric=eval_metric,
|
679 |
+
early_stopping_rounds=10,
|
680 |
+
evals_result=evals_result,
|
681 |
+
verbose_eval=False # Avoid printing every iteration
|
682 |
+
)
|
683 |
+
|
684 |
+
# Optionally print progress
|
685 |
+
if i % 10 == 0:
|
686 |
+
print(f"Batch {i + 1}/{len(dataloader)} processed.")
|
687 |
+
|
688 |
+
return evals_result
|
689 |
+
|
690 |
+
def predict(self):
|
691 |
+
# Predict probabilities for class 1
|
692 |
+
y_prob = self.model.predict(self.dtest, output_margin=False)
|
693 |
+
|
694 |
+
# Convert probabilities to binary labels (0 or 1) using a threshold (e.g., 0.5)
|
695 |
+
y_pred = (y_prob >= 0.5).astype(int)
|
696 |
+
|
697 |
+
# Evaluate performance
|
698 |
+
accuracy = accuracy_score(self.y_test, y_pred)
|
699 |
+
roc_auc = roc_auc_score(self.y_test, y_prob)
|
700 |
+
|
701 |
+
print(f'Accuracy: {accuracy:.4f}')
|
702 |
+
print(f'ROC AUC Score: {roc_auc:.4f}')
|
703 |
+
print(classification_report(self.y_test, y_pred))
|
704 |
+
|
705 |
+
def plot_results(self, evals_result):
|
706 |
+
train_logloss = evals_result["train"]["logloss"]
|
707 |
+
val_logloss = evals_result["eval"]["logloss"]
|
708 |
+
iterations = range(1, len(train_logloss) + 1)
|
709 |
+
|
710 |
+
# Plot
|
711 |
+
plt.figure(figsize=(8, 5))
|
712 |
+
plt.plot(iterations, train_logloss, label="Train LogLoss", color="blue")
|
713 |
+
plt.plot(iterations, val_logloss, label="Validation LogLoss", color="red")
|
714 |
+
plt.xlabel("Boosting Round (Iteration)")
|
715 |
+
plt.ylabel("Log Loss")
|
716 |
+
plt.title("Training and Validation Log Loss over Iterations")
|
717 |
+
plt.legend()
|
718 |
+
plt.grid()
|
719 |
+
plt.show()
|
tasks/utils/transforms.py
ADDED
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import librosa
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import pywt
|
6 |
+
from scipy import signal
|
7 |
+
|
8 |
+
|
9 |
+
|
10 |
+
def compute_cwt_power_spectrum(audio, sample_rate, num_freqs=128, f_min=20, f_max=None):
|
11 |
+
"""
|
12 |
+
Compute the power spectrum of continuous wavelet transform using Morlet wavelet.
|
13 |
+
|
14 |
+
Parameters:
|
15 |
+
audio: torch.Tensor
|
16 |
+
Input audio signal
|
17 |
+
sample_rate: int
|
18 |
+
Sampling rate of the audio
|
19 |
+
num_freqs: int
|
20 |
+
Number of frequency bins for the CWT
|
21 |
+
f_min: float
|
22 |
+
Minimum frequency to analyze
|
23 |
+
f_max: float or None
|
24 |
+
Maximum frequency to analyze (defaults to Nyquist frequency)
|
25 |
+
|
26 |
+
Returns:
|
27 |
+
torch.Tensor: CWT power spectrum
|
28 |
+
"""
|
29 |
+
# Convert to numpy
|
30 |
+
audio_np = audio.cpu().numpy()
|
31 |
+
|
32 |
+
# Set default f_max to Nyquist frequency if not specified
|
33 |
+
if f_max is None:
|
34 |
+
f_max = sample_rate // 2
|
35 |
+
|
36 |
+
# Generate frequency bins (logarithmically spaced)
|
37 |
+
frequencies = np.logspace(
|
38 |
+
np.log10(f_min),
|
39 |
+
np.log10(f_max),
|
40 |
+
num=num_freqs
|
41 |
+
)
|
42 |
+
|
43 |
+
# Compute the width of the wavelet (in samples)
|
44 |
+
widths = sample_rate / (2 * frequencies * np.pi)
|
45 |
+
|
46 |
+
# Compute CWT using Morlet wavelet
|
47 |
+
cwt = signal.cwt(
|
48 |
+
audio_np,
|
49 |
+
signal.morlet2,
|
50 |
+
widths,
|
51 |
+
w=5.0 # Width parameter of Morlet wavelet
|
52 |
+
)
|
53 |
+
|
54 |
+
# Compute power spectrum (magnitude squared)
|
55 |
+
power_spectrum = np.abs(cwt) ** 2
|
56 |
+
|
57 |
+
# Convert to torch tensor
|
58 |
+
power_spectrum_tensor = torch.FloatTensor(power_spectrum)
|
59 |
+
|
60 |
+
return power_spectrum_tensor
|
61 |
+
|
62 |
+
def compute_wavelet_transform(audio, wavelet, decompos_level):
|
63 |
+
"""Compute wavelet decomposition of the audio signal."""
|
64 |
+
# Convert to numpy and ensure 1D
|
65 |
+
audio_np = audio.cpu().numpy()
|
66 |
+
|
67 |
+
# Perform wavelet decomposition
|
68 |
+
coeffs = pywt.wavedec(audio_np, wavelet, level=decompos_level)
|
69 |
+
|
70 |
+
# Stack coefficients into a 2D array
|
71 |
+
# First, pad all coefficient arrays to the same length
|
72 |
+
max_len = max(len(c) for c in coeffs)
|
73 |
+
padded_coeffs = []
|
74 |
+
for coeff in coeffs:
|
75 |
+
pad_len = max_len - len(coeff)
|
76 |
+
if pad_len > 0:
|
77 |
+
padded_coeff = np.pad(coeff, (0, pad_len), mode='constant')
|
78 |
+
else:
|
79 |
+
padded_coeff = coeff
|
80 |
+
padded_coeffs.append(padded_coeff)
|
81 |
+
|
82 |
+
# Stack into 2D array where each row is a different scale
|
83 |
+
wavelet_features = np.stack(padded_coeffs)
|
84 |
+
|
85 |
+
# Convert to tensor
|
86 |
+
return torch.FloatTensor(wavelet_features)
|
87 |
+
|
88 |
+
|
89 |
+
def compute_melspectrogram(audio, sample_rate):
|
90 |
+
mel_spec = librosa.feature.melspectrogram(
|
91 |
+
y=audio.cpu().numpy(),
|
92 |
+
sr=sample_rate,
|
93 |
+
n_mels=128
|
94 |
+
)
|
95 |
+
return torch.FloatTensor(librosa.power_to_db(mel_spec))
|
96 |
+
|
97 |
+
|
98 |
+
def compute_mfcc(audio, sample_rate):
|
99 |
+
mfcc = librosa.feature.mfcc(
|
100 |
+
y=audio.cpu().numpy(),
|
101 |
+
sr=sample_rate,
|
102 |
+
n_mfcc=20
|
103 |
+
)
|
104 |
+
return torch.FloatTensor(mfcc)
|
105 |
+
|
106 |
+
|
107 |
+
def compute_chroma(audio, sample_rate):
|
108 |
+
chroma = librosa.feature.chroma_stft(
|
109 |
+
y=audio.cpu().numpy(),
|
110 |
+
sr=sample_rate
|
111 |
+
)
|
112 |
+
return torch.FloatTensor(chroma)
|
113 |
+
|
114 |
+
|
115 |
+
def compute_time_domain_features(audio, sample_rate, frame_length=2048, hop_length=128):
    """
    Compute time-domain features from the audio signal.
    Returns a dictionary of features.
    """
    # Convert to numpy
    audio_np = audio.cpu().numpy()

    # Initialize dictionary for features
    features = {}

    # 1. Zero Crossing Rate
    zcr = librosa.feature.zero_crossing_rate(
        y=audio_np,
        frame_length=frame_length,
        hop_length=hop_length
    )
    features['zcr'] = torch.Tensor([zcr.sum()])

    # 2. Root Mean Square Energy
    rms = librosa.feature.rms(
        y=audio_np,
        frame_length=frame_length,
        hop_length=hop_length
    )
    features['rms_energy'] = torch.Tensor([rms.mean()])

    # 3. Temporal Statistics
    frames = librosa.util.frame(audio_np, frame_length=frame_length, hop_length=hop_length)
    features['mean'] = torch.Tensor([np.mean(frames, axis=0).mean()])
    features['std'] = torch.Tensor([np.std(frames, axis=0).mean()])
    features['max'] = torch.Tensor([np.max(frames, axis=0).mean()])

    # 4. Tempo and Beat Features
    onset_env = librosa.onset.onset_strength(y=audio_np, sr=sample_rate)
    tempo = librosa.beat.tempo(onset_envelope=onset_env, sr=sample_rate)
    features['tempo'] = torch.Tensor(tempo)

    # 5. Amplitude Envelope
    envelope = np.abs(librosa.stft(audio_np, n_fft=frame_length, hop_length=hop_length))
    features['envelope'] = torch.Tensor([np.mean(envelope, axis=0).mean()])

    return features
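
# Illustrative sketch (hypothetical _demo_* helper, added for clarity): every entry
# in the returned dict is a 1-element tensor, including 'tempo', which librosa
# reports as a single global estimate by default. Note that librosa.beat.tempo is
# deprecated in newer librosa releases in favour of librosa.feature.rhythm.tempo,
# so this assumes a librosa version that still exposes the old name.
def _demo_time_domain_features(sample_rate=12000):
    noise = torch.randn(sample_rate)
    feats = compute_time_domain_features(noise, sample_rate)
    for name, value in feats.items():
        print(name, value.shape)  # e.g. zcr torch.Size([1]), tempo torch.Size([1])
    return feats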
def compute_frequency_domain_features(audio, sample_rate, n_fft=2048, hop_length=512):
    """
    Compute frequency-domain features from the audio signal.
    Returns a dictionary of features.
    """
    # Convert to numpy
    audio_np = audio.cpu().numpy()

    # Initialize dictionary for features
    features = {}

    # 1. Spectral Centroid
    try:
        spectral_centroids = librosa.feature.spectral_centroid(
            y=audio_np,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length
        )
        features['spectral_centroid'] = torch.FloatTensor([spectral_centroids.max()])
    except Exception:
        features['spectral_centroid'] = torch.FloatTensor([np.nan])

    # 2. Spectral Rolloff
    try:
        spectral_rolloff = librosa.feature.spectral_rolloff(
            y=audio_np,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length
        )
        features['spectral_rolloff'] = torch.FloatTensor([spectral_rolloff.max()])
    except Exception:
        features['spectral_rolloff'] = torch.FloatTensor([np.nan])

    # 3. Spectral Bandwidth
    try:
        spectral_bandwidth = librosa.feature.spectral_bandwidth(
            y=audio_np,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length
        )
        features['spectral_bandwidth'] = torch.FloatTensor([spectral_bandwidth.max()])
    except Exception:
        features['spectral_bandwidth'] = torch.FloatTensor([np.nan])

    # 4. Spectral Contrast
    try:
        spectral_contrast = librosa.feature.spectral_contrast(
            y=audio_np,
            sr=sample_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            fmin=20,       # Lower minimum frequency
            n_bands=4,     # Reduce number of bands
            quantile=0.02
        )
        features['spectral_contrast'] = torch.FloatTensor([spectral_contrast.mean()])
    except Exception:
        features['spectral_contrast'] = torch.FloatTensor([np.nan])

    # 5. Spectral Flatness
    try:
        spectral_flatness = librosa.feature.spectral_flatness(
            y=audio_np,
            n_fft=n_fft,
            hop_length=hop_length
        )
        features['spectral_flatness'] = torch.FloatTensor([spectral_flatness.max()])
    except Exception:
        features['spectral_flatness'] = torch.FloatTensor([np.nan])

    # 6. Spectral Flux
    try:
        stft = np.abs(librosa.stft(audio_np, n_fft=n_fft, hop_length=hop_length))
        spectral_flux = np.diff(stft, axis=1)
        spectral_flux = np.pad(spectral_flux, ((0, 0), (1, 0)), mode='constant')
        features['spectral_flux'] = torch.FloatTensor([np.std(spectral_flux)])
    except Exception:
        features['spectral_flux'] = torch.FloatTensor([np.nan])

    return features
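
# Illustrative sketch (hypothetical _demo_* helper, added for clarity): each spectral
# feature is reduced to a single scalar (max, mean or std) and wrapped in a 1-element
# tensor; if a librosa call fails, the corresponding entry is NaN instead of raising.
def _demo_frequency_domain_features(sample_rate=12000):
    noise = torch.randn(sample_rate)
    feats = compute_frequency_domain_features(noise, sample_rate)
    vector = torch.cat(list(feats.values()))  # shape (6,): one scalar per feature
    print(vector)
    return vector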
def compute_all_features(audio, sample_rate, wavelet='db1', decompos_level=4):
    """
    Compute all available features and return them in a dictionary.
    Only the frequency-domain features are currently enabled; the other
    transforms are kept below, commented out, as optional extras.
    """
    features = {}

    # Basic transformations (currently disabled)
    # features['wavelet'] = compute_wavelet_transform(audio, wavelet, decompos_level)
    # features['melspectrogram'] = compute_melspectrogram(audio, sample_rate)
    # features['mfcc'] = compute_mfcc(audio, sample_rate)
    # features['chroma'] = compute_chroma(audio, sample_rate)

    # features['cwt_power'] = compute_cwt_power_spectrum(
    #     audio,
    #     sample_rate,
    #     num_freqs=128,           # Same as mel bands for consistency
    #     f_min=20,                # Standard lower frequency bound
    #     f_max=sample_rate // 2   # Nyquist frequency
    # )

    # Time-domain features (currently disabled)
    # features['time_domain'] = compute_time_domain_features(audio, sample_rate)

    # Frequency-domain features
    features['frequency_domain'] = compute_frequency_domain_features(audio, sample_rate)

    return features
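# Illustrative end-to-end sketch (hypothetical _demo_* helper, added for clarity):
# flatten the nested feature dictionary into a single 1-D tensor, e.g. to feed a
# downstream classifier head. With the defaults above only 'frequency_domain' is set.
def _demo_all_features(sample_rate=12000):
    noise = torch.randn(sample_rate)
    all_feats = compute_all_features(noise, sample_rate)
    flat = torch.cat([v for group in all_feats.values() for v in group.values()])
    print(flat.shape)  # torch.Size([6]) with the default configuration
    return flat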
test
ADDED
Binary file (70 kB). View file