IlayMalinyak committed on
Commit 2f54ec8 · 1 Parent(s): 92c056d
tasks/audio.py CHANGED
@@ -29,7 +29,6 @@ DESCRIPTION = "Conformer"
  ROUTE = "/audio"
 
 
- 
  @router.post(ROUTE, tags=["Audio Task"],
               description=DESCRIPTION)
  async def evaluate_audio(request: AudioEvaluationRequest):
@@ -133,11 +132,11 @@ async def evaluate_audio(request: AudioEvaluationRequest):
 
      return results
 
- # if __name__ == "__main__":
- #     sample_request = AudioEvaluationRequest(
- #         dataset_name="rfcx/frugalai",  # Replace with actual dataset name
- #         test_size=0.2,  # Example values
- #         test_seed=42
- #     )
- #     #
- #     asyncio.run(evaluate_audio(sample_request))
+ if __name__ == "__main__":
+     sample_request = AudioEvaluationRequest(
+         dataset_name="rfcx/frugalai",  # Replace with actual dataset name
+         test_size=0.2,  # Example values
+         test_seed=42
+     )
+     #
+     asyncio.run(evaluate_audio(sample_request))
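The uncommented entry point runs the evaluation coroutine directly; the same route can also be exercised over HTTP once the FastAPI app is served. A minimal sketch, assuming the app is running locally on port 7860 and that the payload mirrors the AudioEvaluationRequest fields above (the host, port, and use of the requests package are assumptions, not part of this commit):

import requests

payload = {
    "dataset_name": "rfcx/frugalai",  # same dataset as sample_request above
    "test_size": 0.2,
    "test_seed": 42,
}
# POST to the route registered as ROUTE = "/audio"; URL is an assumption
resp = requests.post("http://localhost:7860/audio", json=payload, timeout=3600)
print(resp.json())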
tasks/inr_database/inr_database.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:13ead5ed23d1fa59062f4872fd784f609575fa7dd4876ea0ef562f9f817801c1
+ size 50872350
tasks/models/frugal_2025-01-27/frugal_kan_2.pth CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:3f520cff8b9531981e16a8b009b6a55fb8ca98573fc4d3dc6806df60b07a49c2
- size 1710980
+ oid sha256:58e353129b2750993441ea459485b150b2a45b39cbdb7e49bd1839809e4671e2
+ size 1363844
tasks/models/frugal_2025-01-28/frugal_kan_2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9591abf1e617bfedd8414f7c031de3394d7e1bccc64763f7060cef0ee13fab65
+ size 1710980
tasks/models/frugal_2025-01-29/CNNEncoder_frugal_2.json ADDED
The diff for this file is too large to render. See raw diff
 
tasks/models/frugal_2025-01-29/frugal_kan_2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f11c67e65bf6acd1f7e2055dce4a92de0d4d1d88fb12c136eb70198c8ac6eab8
+ size 1710980
tasks/run.py CHANGED
@@ -1,14 +1,17 @@
  from torch.utils.data import DataLoader
  from .utils.data import FFTDataset, SplitDataset
  from datasets import load_dataset
- from .utils.train import Trainer
- from .utils.models import CNNKan, KanEncoder
+ from .utils.train import Trainer, XGBoostTrainer
+ from .utils.models import CNNKan, KanEncoder, CNNKanFeaturesEncoder
  from .utils.data_utils import *
  from huggingface_hub import login
  import yaml
  import datetime
  import json
  import numpy as np
+ import pandas as pd
+ import seaborn as sns
+ import matplotlib.pyplot as plt
  from collections import OrderedDict
 
  # local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -19,9 +22,11 @@ data_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Data'])
  exp_num = data_args.exp_num
  model_name = data_args.model_name
  model_args = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder'])
+ mlp_args = Container(**yaml.safe_load(open(args_dir, 'r'))['MLP'])
  model_args_f = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder_f'])
  conformer_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Conformer'])
  kan_args = Container(**yaml.safe_load(open(args_dir, 'r'))['KAN'])
+ boost_args = Container(**yaml.safe_load(open(args_dir, 'r'))['XGBoost'])
  if not os.path.exists(f"{data_args.log_dir}/{datetime_dir}"):
      os.makedirs(f"{data_args.log_dir}/{datetime_dir}")
 
@@ -44,26 +49,62 @@ val_dl = DataLoader(val_ds,batch_size=data_args.batch_size, collate_fn=collate_f
  test_ds = FFTDataset(dataset["test"])
  test_dl = DataLoader(test_ds,batch_size=data_args.batch_size, collate_fn=collate_fn)
 
- # for i, batch in enumerate(train_dl):
- #     x, x_f, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
- #     print(x.shape, x_f.shape, y.shape)
- #     if i > 10:
+ # data = []
+ #
+ # # Iterate over the dataset
+ # for i, batch in enumerate(train_ds):
+ #     label = batch['label']
+ #     features = batch['audio']['features']
+ #
+ #     # Flatten the nested dictionary structure
+ #     feature_dict = {'label': label}
+ #     for k, v in features.items():
+ #         if isinstance(v, dict):
+ #             for sub_k, sub_v in v.items():
+ #                 feature_dict[f"{k}_{sub_k}"] = sub_v[0].item()  # Aggregate (e.g., mean)
+ #         else:
+ #             print(k, v.shape)  # Aggregate (e.g., mean)
+ #
+ #     data.append(feature_dict)
+ #     print(i)
+ #
+ #     if i > 1000:  # Limit to 10 iterations
  #         break
+ #
+ # # Convert to DataFrame
+ # df = pd.DataFrame(data)
+ 
+ # Plot distributions colored by label
+ # plt.figure()
+ # for col in df.columns:
+ #     if col != 'label':
+ #         sns.kdeplot(df, x=col, hue='label', fill=True, alpha=0.5)
+ #         plt.title(f'Distribution of {col}')
+ #         plt.show()
+ # exit()
+ 
+ # trainer = XGBoostTrainer(boost_args.get_dict(), train_ds, val_ds, test_ds)
+ # res = trainer.fit()
+ # trainer.predict()
+ # trainer.plot_results(res)
  # exit()
 
  # model = DualEncoder(model_args, model_args_f, conformer_args)
  # model = FasterKAN([18000,64,64,16,1])
  model = CNNKan(model_args, conformer_args, kan_args.get_dict())
+ # model = CNNKanFeaturesEncoder(model_args, mlp_args, kan_args.get_dict())
  # model.kan.speed()
  # model = KanEncoder(kan_args.get_dict())
  model = model.to(local_rank)
- state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
- new_state_dict = OrderedDict()
- for key, value in state_dict.items():
-     if key.startswith('module.'):
-         key = key[7:]
-     new_state_dict[key] = value
- missing, unexpected = model.load_state_dict(new_state_dict)
+ 
+ # state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
+ # new_state_dict = OrderedDict()
+ # for key, value in state_dict.items():
+ #     if key.startswith('module.'):
+ #         key = key[7:]
+ #     new_state_dict[key] = value
+ # missing, unexpected = model.load_state_dict(new_state_dict)
+ 
  # model = DDP(model, device_ids=[local_rank], output_device=local_rank)
  num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  print(f"Number of parameters: {num_params}")
@@ -92,7 +133,7 @@ fit_res = trainer.fit(num_epochs=100, device=local_rank,
  output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
  with open(output_filename, "w") as f:
      json.dump(fit_res, f, indent=2)
- preds, acc = trainer.predict(test_dl, local_rank)
+ preds, tru, acc = trainer.predict(test_dl, local_rank)
  print(f"Accuracy: {acc}")
 
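The checkpoint-loading block commented out above strips the 'module.' prefix that torch.nn.parallel.DistributedDataParallel adds to parameter names. A minimal sketch of the same idea as a reusable helper (the function name and the strict=False flag are choices made here for illustration, not part of the commit):

import torch
from collections import OrderedDict

def load_unwrapped_state_dict(model, checkpoint_path, device='cpu'):
    """Load a checkpoint saved from a DDP-wrapped model into a plain model."""
    state_dict = torch.load(checkpoint_path, map_location=torch.device(device))
    new_state_dict = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith('module.'):
            key = key[7:]  # drop the 'module.' prefix added by DDP
        new_state_dict[key] = value
    missing, unexpected = model.load_state_dict(new_state_dict, strict=False)
    return missing, unexpected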
tasks/run_inr.py ADDED
@@ -0,0 +1,200 @@
1
+ import torch.nn
2
+ from torch.utils.data import DataLoader
3
+ from utils.data import FFTDataset, SplitDataset, AudioINRDataset
4
+ from datasets import load_dataset
5
+ from utils.train import Trainer, INRTrainer
6
+ from utils.models import MultiGraph, ImplicitEncoder
7
+ from omegaconf import OmegaConf
8
+
9
+ # from .utils.models import CNNKan, KanEncoder
10
+ from utils.inr import INR
11
+ from utils.data_utils import *
12
+ from huggingface_hub import login
13
+ import yaml
14
+ import datetime
15
+ import json
16
+ import numpy as np
17
+ from tqdm import tqdm
18
+ import matplotlib.pyplot as plt
19
+ from scipy.signal import savgol_filter as savgol
20
+ from utils.kan import FasterKAN
21
+ from utils.relational_transformer import RelationalTransformer
22
+ from collections import OrderedDict
23
+ def plot_results(dims, i, data, losses, pred_values):
24
+ data = savgol(data.cpu().detach().numpy(), window_length=250, polyorder=1)
25
+ pred_values = pred_values.transpose(-1, -2).unflatten(-1, data.shape[-2:]).squeeze(0).cpu().detach().numpy()
26
+ pred_values = (pred_values - np.min(pred_values)) / (np.max(pred_values) - np.min(pred_values))
27
+ data = (data - np.min(data)) / (np.max(data) - np.min(data))
28
+ plt.plot(data.squeeze())
29
+ plt.plot(pred_values.squeeze())
30
+ # axes[0].set_title('Original')
31
+ # axes[1].set_title('Reconstruction')
32
+ plt.show()
33
+ # plt.plot(np.arange(len(losses)), losses)
34
+ # plt.xlabel('Iteration')
35
+ # plt.ylabel('Reconstruction MSE Error')
36
+ # plt.show()
37
+
38
+ # local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+ current_date = datetime.date.today().strftime("%Y-%m-%d")
40
+ datetime_dir = f"frugal_{current_date}"
41
+ args_dir = 'utils/config.yaml'
42
+ data_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Data'])
43
+ exp_num = data_args.exp_num
44
+ model_name = data_args.model_name
45
+ rt_args = Container(**yaml.safe_load(open(args_dir, 'r'))['RelationalTransformer'])
46
+ cnn_args = Container(**yaml.safe_load(open(args_dir, 'r'))['CNNEncoder_f'])
47
+ conformer_args = Container(**yaml.safe_load(open(args_dir, 'r'))['Conformer'])
48
+ kan_args = Container(**yaml.safe_load(open(args_dir, 'r'))['KAN_INR'])
49
+ inr_args = Container(**yaml.safe_load(open(args_dir, 'r'))['INR'])
50
+ if not os.path.exists(f"{data_args.log_dir}/{datetime_dir}"):
51
+ os.makedirs(f"{data_args.log_dir}/{datetime_dir}")
52
+
53
+ with open("../../logs/token.txt", "r") as f:
54
+ api_key = f.read()
55
+
56
+ # local_rank, world_size, gpus_per_node = setup()
57
+ local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
58
+ login(api_key)
59
+ dataset = load_dataset("rfcx/frugalai", streaming=True)
60
+
61
+ train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
62
+
63
+ train_dl = DataLoader(train_ds, batch_size=data_args.batch_size)
64
+
65
+ val_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=False)
66
+
67
+ val_dl = DataLoader(val_ds, batch_size=data_args.batch_size)
68
+
69
+ test_ds = AudioINRDataset(FFTDataset(dataset["test"]))
70
+ test_dl = DataLoader(test_ds, batch_size=data_args.batch_size)
71
+
72
+ # for i, batch in enumerate(train_ds):
73
+ # fft_phase, fft_mag, audio = batch['audio']['fft_phase'], batch['audio']['fft_mag'], batch['audio']['array']
74
+ # label = batch['label']
75
+ # fig, axes = plt.subplots(nrows=1, ncols=3)
76
+ # axes = axes.flatten()
77
+ # axes[0].plot(fft_phase)
78
+ # axes[1].plot(fft_mag)
79
+ # axes[2].plot(audio)
80
+ # fig.suptitle(label)
81
+ # plt.tight_layout()
82
+ # plt.show()
83
+ # if i > 20:
84
+ # break
85
+ # model = DualEncoder(model_args, model_args_f, conformer_args)
86
+ # model = FasterKAN([18000,64,64,16,1])
87
+ # model = INR(in_features=1)
88
+ # model.kan.speed()
89
+ # model = KanEncoder(kan_args.get_dict())
90
+ # model = model.to(local_rank)
91
+
92
+ # state_dict = torch.load(data_args.checkpoint_path, map_location=torch.device('cpu'))
93
+ # new_state_dict = OrderedDict()
94
+ # for key, value in state_dict.items():
95
+ # if key.startswith('module.'):
96
+ # key = key[7:]
97
+ # new_state_dict[key] = value
98
+ # missing, unexpected = model.load_state_dict(new_state_dict)
99
+
100
+ # model = DDP(model, device_ids=[local_rank], output_device=local_rank)
101
+ # num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
102
+ # print(f"Number of parameters: {num_params}")
103
+ #
104
+ # optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
105
+ # total_steps = int(data_args.num_epochs) * 1000
106
+ # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
107
+ # T_max=total_steps,
108
+ # eta_min=float((5e-4) / 10))
109
+
110
+ loss_fn = torch.nn.BCEWithLogitsLoss()
111
+ inr_criterion = torch.nn.MSELoss()
112
+
113
+ # for i, batch in enumerate(train_ds):
114
+ # coords, fft, audio = batch['audio']['coords'], batch['audio']['fft_mag'], batch['audio']['array']
115
+ # coords = coords.to(local_rank)
116
+ # fft = fft.to(local_rank)
117
+ # audio = audio.to(local_rank)
118
+ # values = torch.cat((audio.unsqueeze(-1), fft.unsqueeze(-1)), dim=-1)
119
+ # # model = INR(hidden_features=128, n_layers=3,
120
+ # # in_features=1,
121
+ # # out_features=1).to(local_rank)
122
+ # model = FasterKAN(**kan_args.get_dict()).to(local_rank)
123
+ # optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
124
+ # pbar = tqdm(range(200))
125
+ # losses = []
126
+ # print(coords.shape)
127
+ # for t in pbar:
128
+ # optimizer.zero_grad()
129
+ # pred_values = model(coords.to(local_rank)).float()
130
+ # loss = inr_criterion(pred_values, values)
131
+ # loss.backward()
132
+ # optimizer.step()
133
+ # pbar.set_description(f'loss: {loss.item()}')
134
+ # losses.append(loss.item())
135
+ # state_dict = model.state_dict()
136
+ # torch.save(state_dict, 'test')
137
+ # # print(f'Sample {i+offset} label {label} saved in {inr_path}')
138
+ # plot_results(1, i, fft, losses, pred_values)
139
+ # #
140
+ # exit()
141
+
142
+
143
+ # missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
144
+ # print(f"Missing keys: {missing}")
145
+ # print(f"Unexpected keys: {unexpected}")
146
+ layer_layout = [inr_args.in_features] + [inr_args.hidden_features for _ in range(inr_args.n_layers)] + [inr_args.out_features]
147
+
148
+ graph_constructor = OmegaConf.create(
149
+ {
150
+ "_target_": "utils.graph_constructor.GraphConstructor",
151
+ "_recursive_": False,
152
+ "_convert_": "all",
153
+ "d_in": 1,
154
+ "d_edge_in": 1,
155
+ "zero_out_bias": False,
156
+ "zero_out_weights": False,
157
+ "sin_emb": True,
158
+ "sin_emb_dim": rt_args.d_node,
159
+ "use_pos_embed": False,
160
+ "input_layers": 1,
161
+ "inp_factor": 1,
162
+ "num_probe_features": 0,
163
+ "inr_model": None,
164
+ "stats": None,
165
+ "sparsify": False,
166
+ 'sym_edges': False,
167
+ }
168
+ )
169
+
170
+ rt_model = RelationalTransformer(layer_layout=layer_layout, graph_constructor=graph_constructor,
171
+ **rt_args.get_dict()).to(local_rank)
172
+ rt_model.proj_out= torch.nn.Identity()
173
+ multi_graph = MultiGraph(rt_model, cnn_args)
174
+ implicit_net = INR(**inr_args.get_dict())
175
+ model = ImplicitEncoder(implicit_net, multi_graph).to(local_rank)
176
+ num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
177
+ print(f"Number of parameters: {num_parameters}")
178
+ optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
179
+ trainer = Trainer(model=model, optimizer=optimizer,
180
+ criterion=loss_fn, output_dim=1, scaler=None,
181
+ scheduler=None, train_dataloader=train_dl,
182
+ val_dataloader=val_dl, device=local_rank,
183
+ exp_num=datetime_dir, log_path=data_args.log_dir,
184
+ range_update=None,
185
+ accumulation_step=1, max_iter=100,
186
+ exp_name=f"frugal_kan_{exp_num}")
187
+ fit_res = trainer.fit(num_epochs=100, device=local_rank,
188
+ early_stopping=10, only_p=False, best='loss', conf=True)
189
+ output_filename = f'{data_args.log_dir}/{datetime_dir}/{model_name}_frugal_{exp_num}.json'
190
+ with open(output_filename, "w") as f:
191
+ json.dump(fit_res, f, indent=2)
192
+ preds, acc = trainer.predict(test_dl, local_rank)
193
+ print(f"Accuracy: {acc}")
194
+
195
+
196
+
197
+
198
+
199
+
200
+
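run_inr.py builds layer_layout as [in_features] plus n_layers copies of hidden_features plus [out_features] and hands it to both the GraphConstructor config and the RelationalTransformer. A minimal sketch of what that evaluates to with the INR section of config.yaml (pure arithmetic, no project imports):

in_features, n_layers, hidden_features, out_features = 2, 2, 64, 32  # INR section of config.yaml
layer_layout = [in_features] + [hidden_features for _ in range(n_layers)] + [out_features]
print(layer_layout)       # [2, 64, 64, 32]
print(sum(layer_layout))  # 162 graph nodes, one per neuron of the implicit network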
tasks/tasks/models/frugal_2025-01-28/CNNEncoder_frugal_2.json ADDED
@@ -0,0 +1,420 @@
1
+ {
2
+ "num_epochs": 100,
3
+ "train_loss": [
4
+ 0,
5
+ 0,
6
+ 0,
7
+ 0,
8
+ 0,
9
+ 0,
10
+ 0,
11
+ 0,
12
+ 0,
13
+ 0,
14
+ 0,
15
+ 0,
16
+ 0,
17
+ 0,
18
+ 0,
19
+ 0,
20
+ 0,
21
+ 0,
22
+ 0,
23
+ 0,
24
+ 0,
25
+ 0,
26
+ 0,
27
+ 0,
28
+ 0,
29
+ 0,
30
+ 0,
31
+ 0,
32
+ 0,
33
+ 0,
34
+ 0,
35
+ 0,
36
+ 0,
37
+ 0,
38
+ 0,
39
+ 0,
40
+ 0,
41
+ 0,
42
+ 0,
43
+ 0,
44
+ 0,
45
+ 0,
46
+ 0,
47
+ 0,
48
+ 0,
49
+ 0,
50
+ 0,
51
+ 0,
52
+ 0,
53
+ 0,
54
+ 0,
55
+ 0,
56
+ 0,
57
+ 0,
58
+ 0,
59
+ 0,
60
+ 0,
61
+ 0,
62
+ 0,
63
+ 0,
64
+ 0,
65
+ 0,
66
+ 0,
67
+ 0,
68
+ 0,
69
+ 0,
70
+ 0,
71
+ 0,
72
+ 0,
73
+ 0,
74
+ 0,
75
+ 0,
76
+ 0,
77
+ 0,
78
+ 0,
79
+ 0,
80
+ 0,
81
+ 0,
82
+ 0,
83
+ 0,
84
+ 0,
85
+ 0,
86
+ 0,
87
+ 0,
88
+ 0,
89
+ 0,
90
+ 0,
91
+ 0,
92
+ 0,
93
+ 0,
94
+ 0,
95
+ 0,
96
+ 0,
97
+ 0,
98
+ 0,
99
+ 0,
100
+ 0,
101
+ 0,
102
+ 0,
103
+ 0,
104
+ 0,
105
+ 0,
106
+ 0,
107
+ 0,
108
+ 0,
109
+ 0,
110
+ 0,
111
+ 0,
112
+ 0,
113
+ 0,
114
+ 0,
115
+ 0,
116
+ 0,
117
+ 0,
118
+ 0,
119
+ 0,
120
+ 0,
121
+ 0,
122
+ 0,
123
+ 0,
124
+ 0,
125
+ 0,
126
+ 0,
127
+ 0,
128
+ 0,
129
+ 0,
130
+ 0,
131
+ 0,
132
+ 0,
133
+ 0,
134
+ 0,
135
+ 0,
136
+ 0,
137
+ 0,
138
+ 0,
139
+ 0,
140
+ 0,
141
+ 0,
142
+ 0,
143
+ 0,
144
+ 0,
145
+ 0,
146
+ 0,
147
+ 0,
148
+ 0,
149
+ 0,
150
+ 0,
151
+ 0,
152
+ 0,
153
+ 0,
154
+ 0,
155
+ 0,
156
+ 0,
157
+ 0,
158
+ 0,
159
+ 0,
160
+ 0,
161
+ 0,
162
+ 0,
163
+ 0,
164
+ 0,
165
+ 0,
166
+ 0,
167
+ 0,
168
+ 0,
169
+ 0,
170
+ 0,
171
+ 0,
172
+ 0,
173
+ 0,
174
+ 0,
175
+ 0,
176
+ 0,
177
+ 0,
178
+ 0,
179
+ 0,
180
+ 0,
181
+ 0,
182
+ 0,
183
+ 0,
184
+ 0,
185
+ 0,
186
+ 0,
187
+ 0,
188
+ 0,
189
+ 0,
190
+ 0
191
+ ],
192
+ "val_loss": [
193
+ 0.7067973613739014,
194
+ 0.6959081888198853,
195
+ 0.7170985341072083,
196
+ 0.6722545623779297,
197
+ 0.7166763544082642,
198
+ 0.7339211702346802,
199
+ 0.7830350399017334,
200
+ 0.7616254091262817,
201
+ 0.7465198040008545,
202
+ 0.7474430799484253,
203
+ 0.7112622261047363,
204
+ 0.8131649494171143,
205
+ 0.7090165019035339,
206
+ 0.697528600692749,
207
+ 0.7792525291442871,
208
+ 0.700302243232727,
209
+ 0.7315454483032227,
210
+ 0.7067973613739014,
211
+ 0.6959081888198853,
212
+ 0.7170985341072083,
213
+ 0.6722545623779297,
214
+ 0.7166763544082642,
215
+ 0.7339211702346802,
216
+ 0.7830350399017334,
217
+ 0.7616254091262817,
218
+ 0.7465198040008545,
219
+ 0.7474430799484253,
220
+ 0.7112622261047363,
221
+ 0.8131649494171143,
222
+ 0.7090165019035339,
223
+ 0.697528600692749,
224
+ 0.7792525291442871,
225
+ 0.700302243232727,
226
+ 0.7315454483032227,
227
+ 0.7067973613739014,
228
+ 0.6959081888198853,
229
+ 0.7170985341072083,
230
+ 0.6722545623779297,
231
+ 0.7166763544082642,
232
+ 0.7339211702346802,
233
+ 0.7830350399017334,
234
+ 0.7616254091262817,
235
+ 0.7465198040008545,
236
+ 0.7474430799484253,
237
+ 0.7112622261047363,
238
+ 0.8131649494171143,
239
+ 0.7090165019035339,
240
+ 0.697528600692749,
241
+ 0.7792525291442871,
242
+ 0.700302243232727,
243
+ 0.7315454483032227,
244
+ 0.7067973613739014,
245
+ 0.6959081888198853,
246
+ 0.7170985341072083,
247
+ 0.6722545623779297,
248
+ 0.7166763544082642,
249
+ 0.7339211702346802,
250
+ 0.7830350399017334,
251
+ 0.7616254091262817,
252
+ 0.7465198040008545,
253
+ 0.7474430799484253,
254
+ 0.7112622261047363,
255
+ 0.8131649494171143,
256
+ 0.7090165019035339,
257
+ 0.697528600692749,
258
+ 0.7792525291442871,
259
+ 0.700302243232727,
260
+ 0.7315454483032227,
261
+ 0.7067973613739014,
262
+ 0.6959081888198853,
263
+ 0.7170985341072083,
264
+ 0.6722545623779297,
265
+ 0.7166763544082642,
266
+ 0.7339211702346802,
267
+ 0.7830350399017334,
268
+ 0.7616254091262817,
269
+ 0.7465198040008545,
270
+ 0.7474430799484253,
271
+ 0.7112622261047363,
272
+ 0.8131649494171143,
273
+ 0.7090165019035339,
274
+ 0.697528600692749,
275
+ 0.7792525291442871,
276
+ 0.700302243232727,
277
+ 0.7315454483032227,
278
+ 0.7067973613739014,
279
+ 0.6959081888198853,
280
+ 0.7170985341072083,
281
+ 0.6722545623779297,
282
+ 0.7166763544082642,
283
+ 0.7339211702346802,
284
+ 0.7830350399017334,
285
+ 0.7616254091262817,
286
+ 0.7465198040008545,
287
+ 0.7474430799484253,
288
+ 0.7112622261047363,
289
+ 0.8131649494171143,
290
+ 0.7090165019035339,
291
+ 0.697528600692749,
292
+ 0.7792525291442871,
293
+ 0.700302243232727,
294
+ 0.7315454483032227,
295
+ 0.7067973613739014,
296
+ 0.6959081888198853,
297
+ 0.7170985341072083,
298
+ 0.6722545623779297,
299
+ 0.7166763544082642,
300
+ 0.7339211702346802,
301
+ 0.7830350399017334,
302
+ 0.7616254091262817,
303
+ 0.7465198040008545,
304
+ 0.7474430799484253,
305
+ 0.7112622261047363,
306
+ 0.8131649494171143,
307
+ 0.7090165019035339,
308
+ 0.697528600692749,
309
+ 0.7792525291442871,
310
+ 0.700302243232727,
311
+ 0.7315454483032227,
312
+ 0.7067973613739014,
313
+ 0.6959081888198853,
314
+ 0.7170985341072083,
315
+ 0.6722545623779297,
316
+ 0.7166763544082642,
317
+ 0.7339211702346802,
318
+ 0.7830350399017334,
319
+ 0.7616254091262817,
320
+ 0.7465198040008545,
321
+ 0.7474430799484253,
322
+ 0.7112622261047363,
323
+ 0.8131649494171143,
324
+ 0.7090165019035339,
325
+ 0.697528600692749,
326
+ 0.7792525291442871,
327
+ 0.700302243232727,
328
+ 0.7315454483032227,
329
+ 0.7067973613739014,
330
+ 0.6959081888198853,
331
+ 0.7170985341072083,
332
+ 0.6722545623779297,
333
+ 0.7166763544082642,
334
+ 0.7339211702346802,
335
+ 0.7830350399017334,
336
+ 0.7616254091262817,
337
+ 0.7465198040008545,
338
+ 0.7474430799484253,
339
+ 0.7112622261047363,
340
+ 0.8131649494171143,
341
+ 0.7090165019035339,
342
+ 0.697528600692749,
343
+ 0.7792525291442871,
344
+ 0.700302243232727,
345
+ 0.7315454483032227,
346
+ 0.7067973613739014,
347
+ 0.6959081888198853,
348
+ 0.7170985341072083,
349
+ 0.6722545623779297,
350
+ 0.7166763544082642,
351
+ 0.7339211702346802,
352
+ 0.7830350399017334,
353
+ 0.7616254091262817,
354
+ 0.7465198040008545,
355
+ 0.7474430799484253,
356
+ 0.7112622261047363,
357
+ 0.8131649494171143,
358
+ 0.7090165019035339,
359
+ 0.697528600692749,
360
+ 0.7792525291442871,
361
+ 0.700302243232727,
362
+ 0.7315454483032227,
363
+ 0.7067973613739014,
364
+ 0.6959081888198853,
365
+ 0.7170985341072083,
366
+ 0.6722545623779297,
367
+ 0.7166763544082642,
368
+ 0.7339211702346802,
369
+ 0.7830350399017334,
370
+ 0.7616254091262817,
371
+ 0.7465198040008545,
372
+ 0.7474430799484253,
373
+ 0.7112622261047363,
374
+ 0.8131649494171143,
375
+ 0.7090165019035339,
376
+ 0.697528600692749,
377
+ 0.7792525291442871,
378
+ 0.700302243232727,
379
+ 0.7315454483032227
380
+ ],
381
+ "train_acc": [
382
+ 0.0,
383
+ 0.0,
384
+ 0.0,
385
+ 0.0,
386
+ 0.0,
387
+ 0.0,
388
+ 0.0,
389
+ 0.0,
390
+ 0.0,
391
+ 0.0,
392
+ 0.0
393
+ ],
394
+ "val_acc": [
395
+ 0.0,
396
+ 0.0,
397
+ 0.0,
398
+ 0.0,
399
+ 0.0,
400
+ 0.0,
401
+ 0.0,
402
+ 0.0,
403
+ 0.0,
404
+ 0.0,
405
+ 0.0
406
+ ],
407
+ "lrs": [
408
+ 0.001,
409
+ 0.001,
410
+ 0.001,
411
+ 0.001,
412
+ 0.001,
413
+ 0.001,
414
+ 0.001,
415
+ 0.001,
416
+ 0.001,
417
+ 0.001,
418
+ 0.001
419
+ ]
420
+ }
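The JSON above is the fit_res dictionary that run.py serializes with json.dump(fit_res, f, indent=2). A minimal sketch for inspecting such a log after a run (the path is this commit's added file; the matplotlib usage is an assumption, though the package is already imported by run.py):

import json
import matplotlib.pyplot as plt

with open("tasks/tasks/models/frugal_2025-01-28/CNNEncoder_frugal_2.json") as f:
    fit_res = json.load(f)

plt.plot(fit_res["train_loss"], label="train_loss")
plt.plot(fit_res["val_loss"], label="val_loss")
plt.xlabel("step")
plt.ylabel("loss")
plt.legend()
plt.show()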
tasks/tasks/models/frugal_2025-01-28/frugal_kan_2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7b22195a96ca5dd9cf4d20bb1886c6d22327d42ea3fbe4b9d86b9a281a132024
+ size 820491
tasks/utils/config.yaml CHANGED
@@ -12,7 +12,7 @@ Data:
   max_days_lc: 270
   lc_freq: 0.0208
   create_umap: True
-  checkpoint_path: 'tasks/models/frugal_2025-01-27/frugal_kan_2.pth'
+  checkpoint_path: 'tasks/models/frugal_2025-01-29/frugal_kan_2.pth'
 
 CNNEncoder:
   # Model
@@ -27,9 +27,14 @@ CNNEncoder:
   load_checkpoint: False
   checkpoint_num: 1
   activation: "silu"
-  sine_w0: 1.0
+  sine_w0: 30.0
   avg_output: False
 
+ MLP:
+   input_dim: 6
+   hidden_dims: [16,32]
+   dropout: 0.2
+
 KAN:
   layers_hidden: [1125,32,8,1]
   grid_min: -1.2
@@ -37,9 +42,16 @@ KAN:
   num_grids: 8
   exponent: 2
 
+ KAN_INR:
+   layers_hidden: [1,1024,128,128,1]
+   grid_min: -1.2
+   grid_max: 1.2
+   num_grids: 8
+   exponent: 2
+
 CNNEncoder_f:
   # Model
-  in_channels: 1
+  in_channels: 32
   num_layers: 4
   stride: 1
   encoder_dims: [32,64,128]
@@ -64,10 +76,39 @@ Conformer:
   dropout_p: 0.2
   norm: "postnorm"
 
+ RelationalTransformer:
+   d_node: 32
+   d_edge: 32
+   d_attn_hid: 16
+   d_node_hid: 16
+   d_edge_hid: 16
+   d_out_hid: 16
+   d_out: 1
+   n_layers: 4
+   n_heads: 4
+   dropout: 0.1
+
+
+ INR:
+   in_features : 2
+   n_layers : 2
+   hidden_features : 64
+   out_features : 32
+
+ XGBoost:
+   objective : 'binary:logistic'
+   eval_metric : 'logloss'
+   use_label_encoder : False
+   n_estimators : 500
+   learning_rate : 0.1
+   max_depth : 5
+   subsample : 0.8
+   colsample_bytree : 0.8
+   random_state : 42
 
 Optimization:
   # Optimization
   max_lr: 1e-5
   weight_decay: 5e-6
   warmup_pct: 0.3
-  steps_per_epoch: 3500
+  steps_per_epoch: 3500
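run.py and run_inr.py read each of these sections into a Container via yaml.safe_load. A minimal sketch of that pattern for the new sections, assuming the Container helper exported by utils/data_utils.py accepts the section dict as keyword arguments, as in the training scripts:

import yaml

args_dir = 'tasks/utils/config.yaml'
with open(args_dir, 'r') as f:
    cfg = yaml.safe_load(f)

mlp_args = Container(**cfg['MLP'])          # input_dim: 6, hidden_dims: [16,32], dropout: 0.2
boost_args = Container(**cfg['XGBoost'])    # passed to XGBoostTrainer via boost_args.get_dict()
rt_args = Container(**cfg['RelationalTransformer'])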
tasks/utils/data.py CHANGED
@@ -1,11 +1,38 @@
  import torch
  from torch.utils.data import IterableDataset
- from torch.fft import fft
+ from torch.fft import fft, fftshift
  import torch.nn.functional as F
  from itertools import tee
  import random
  import torchaudio.transforms as T
+ import hashlib
+ from typing import NamedTuple, Tuple, Union
+ from .transforms import compute_all_features
 
+ from scipy.signal import savgol_filter as savgol
+ 
+ 
+ class WeightsBatch(NamedTuple):
+     weights: Tuple
+     biases: Tuple
+     label: Union[torch.Tensor, int]
+ 
+     def _assert_same_len(self):
+         assert len(set([len(t) for t in self])) == 1
+ 
+     def as_dict(self):
+         return self._asdict()
+ 
+     def to(self, device):
+         """move batch to device"""
+         return self.__class__(
+             weights=tuple(w.to(device) for w in self.weights),
+             biases=tuple(w.to(device) for w in self.biases),
+             label=self.label.to(device),
+         )
+ 
+     def __len__(self):
+         return len(self.weights[0])
 
  class SplitDataset(IterableDataset):
      def __init__(self, dataset, is_train=True, train_ratio=0.8):
@@ -28,22 +55,101 @@ class FFTDataset(IterableDataset):
      def __init__(self, original_dataset, max_len=72000, orig_sample_rate=12000, target_sample_rate=3000):
          self.dataset = original_dataset
          self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
+         self.target_sample_rate = target_sample_rate
          self.max_len = max_len
- 
+ 
+ 
+     def normalize_audio(self, audio):
+         """Normalize audio to [0, 1] range"""
+         audio_min = audio.min()
+         audio_max = audio.max()
+         audio = (audio - audio_min) / (audio_max - audio_min)
+         return audio
+ 
+     def generate_unique_id(self, array):
+         # Convert the array to bytes
+         array_bytes = array.tobytes()
+         # Hash the bytes using SHA256
+         hash_object = hashlib.sha256(array_bytes)
+         # Return the hexadecimal representation of the hash
+         return hash_object.hexdigest()
+ 
      def __iter__(self):
          for item in self.dataset:
-             # Assuming your audio data is in item['audio']
-             # Modify this based on your actual data structure
-             audio_data = torch.tensor(item['audio']['array']).float()
-             # pad audio
-             # if len(audio_data) == 0:
-             #     continue
+             # audio_data = savgol(item['audio']['array'], 500, polyorder=1)
+             audio_data = item['audio']['array']
+             # item['id'] = self.generate_unique_id(audio_data)
+             audio_data = torch.tensor(audio_data).float()
+ 
              pad_len = self.max_len - len(audio_data)
              audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
              audio_data = self.resampler(audio_data)
+ 
+             audio_data = self.normalize_audio(audio_data)
              fft_data = fft(audio_data)
- 
-             # Update the item with FFT data
-             item['audio']['fft'] = fft_data
-             item['audio']['array'] = audio_data
+             magnitude = torch.abs(fft_data)
+             phase = torch.angle(fft_data)
+             # features = compute_all_features(audio_data, sample_rate=self.target_sample_rate)
+             # features_arr = torch.tensor([v for _, v in features['frequency_domain'].items()])
+             magnitude_centered = fftshift(magnitude)
+             phase_centered = fftshift(phase)
+             # cwt = features['cwt_power']
+ 
+             # Optionally, remove the DC component
+             magnitude_centered[len(magnitude_centered) // 2] = 0  # Set DC component to zero
+ 
+             item['audio']['fft_mag'] = torch.nan_to_num(magnitude_centered, 0)
+             item['audio']['fft_phase'] = torch.nan_to_num(phase_centered, 0)
+             # item['audio']['cwt_mag'] = torch.nan_to_num(cwt, 0)
+             item['audio']['array'] = torch.nan_to_num(audio_data, 0)
+             # item['audio']['features'] = features
+             # item['audio']['features_arr'] = torch.nan_to_num(features_arr, 0)
+             yield item
+ 
+ 
+ class AudioINRDataset(IterableDataset):
+     def __init__(self, original_dataset, max_len=18000, sample_size=1024, dim=1, normalize=True):
+         """
+         Convert audio data into coordinate-value pairs for INR training.
+ 
+         Args:
+             original_dataset: Original audio dataset
+             max_len: Maximum length of audio to process
+             batch_size: Number of points to sample per audio clip
+             normalize: Whether to normalize the audio values to [0, 1]
+         """
+         self.dataset = original_dataset
+         self.max_len = max_len
+         self.dim = dim
+         self.normalize = normalize
+         self.sample_size = sample_size
+ 
+     def get_coordinates(self, audio_len):
+         """Generate time coordinates"""
+         # Create normalized time coordinates in [0, 1]
+         coords = torch.linspace(0, 1, audio_len).unsqueeze(-1).expand(audio_len, self.dim)
+         return coords  # Shape: [audio_len, 1]
+ 
+     def sample_points(self, coords, values):
+         """Randomly sample points from the audio"""
+         if len(coords) > self.sample_size:
+             idx = torch.randperm(len(coords))[:self.sample_size]
+             coords = coords[idx]
+             values = values[idx]
+         return coords, values
+ 
+     def __iter__(self):
+         for item in self.dataset:
+             # Get audio data
+             audio_data = torch.tensor(item['audio']['array']).float()
+ 
+             # Generate coordinates
+             coords = self.get_coordinates(len(audio_data))
+ 
+             item['audio']['coords'] = coords
+ 
+             # Sample random points
+             # coords, values = self.sample_points(coords, audio_data)
+ 
+             # Create the INR training sample
              yield item
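The new FFTDataset pipeline normalizes each clip, takes the FFT, and stores the shifted magnitude and phase with the DC bin zeroed. A minimal standalone sketch of that transform on a dummy waveform (the 18000-sample length follows the defaults above after resampling; the random input is an assumption for illustration):

import torch
from torch.fft import fft, fftshift

audio = torch.randn(18000)                                   # stand-in for a padded, resampled clip
audio = (audio - audio.min()) / (audio.max() - audio.min())  # normalize to [0, 1], as normalize_audio does
spectrum = fft(audio)
magnitude = fftshift(torch.abs(spectrum))
phase = fftshift(torch.angle(spectrum))
magnitude[len(magnitude) // 2] = 0                           # zero the DC component, as in __iter__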
tasks/utils/data_utils.py CHANGED
@@ -6,18 +6,25 @@ from torch.nn.utils.rnn import pad_sequence
  def collate_fn(batch):
      # Extract audio arrays and FFT data from the batch of dictionaries
      audio_arrays = [torch.tensor(item['audio']['array']) for item in batch]
-     fft_arrays = [torch.tensor(item['audio']['fft']) for item in batch]
+     fft_arrays = [torch.tensor(item['audio']['fft_mag']) for item in batch]
+     # cwt_arrays = [torch.tensor(item['audio']['cwt_mag']) for item in batch]
+     # features = [item['audio']['features'] for item in batch]
+     # features_arr = torch.stack([item['audio']['features_arr'] for item in batch])
      labels = [torch.tensor(item['label']) for item in batch]
 
      # Pad both sequences
      padded_audio = pad_sequence(audio_arrays, batch_first=True, padding_value=0)
      padded_fft = pad_sequence(fft_arrays, batch_first=True, padding_value=0)
+     # padded_features = pad_sequence(features_arr, batch_first=True, padding_value=0)
 
      # Return as dictionary with the same structure
      return {
          'audio': {
              'array': padded_audio,
-             'fft': padded_fft
+             'fft_mag': padded_fft,
+             # 'features': features,
+             # 'features_arr': features_arr,
+             # 'cwt_mag': padded_cwt,
          },
          'label': torch.stack(labels)
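collate_fn is what run.py passes to each DataLoader so that variable-length clips in a batch are padded together. A minimal usage sketch, assuming train_ds is the FFTDataset-wrapped streaming split built in run.py (batch size and printed shapes are illustrative assumptions):

from torch.utils.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=32, collate_fn=collate_fn)
batch = next(iter(train_dl))
print(batch['audio']['array'].shape)    # [batch, max_len_in_batch] padded waveforms
print(batch['audio']['fft_mag'].shape)  # padded, shifted FFT magnitudes
print(batch['label'].shape)             # [batch]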
tasks/utils/dfs/test.csv ADDED
The diff for this file is too large to render. See raw diff
 
tasks/utils/dfs/train.csv ADDED
The diff for this file is too large to render. See raw diff
 
tasks/utils/dfs/val.csv ADDED
The diff for this file is too large to render. See raw diff
 
tasks/utils/graph_constructor.py ADDED
@@ -0,0 +1,214 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from rff.layers import GaussianEncoding
4
+
5
+ # from nn.probe_features import GraphProbeFeatures
6
+
7
+
8
+ def sparsify_graph(edges, fraction=0.1):
9
+ abs_edges = torch.abs(edges)
10
+ flat_abs_tensor = abs_edges.flatten()
11
+ sorted_tensor, _ = torch.sort(flat_abs_tensor, descending=True)
12
+ num_elements = flat_abs_tensor.numel()
13
+ top_k = int(num_elements * fraction)
14
+ topk_values, topk_indices = torch.topk(flat_abs_tensor, top_k)
15
+ mask = torch.zeros_like(flat_abs_tensor, dtype=torch.bool)
16
+ mask[topk_indices] = True
17
+ mask = mask.view(edges.shape)
18
+ return mask
19
+
20
+ def batch_to_graphs(
21
+ weights,
22
+ biases,
23
+ weights_mean=None,
24
+ weights_std=None,
25
+ biases_mean=None,
26
+ biases_std=None,
27
+ sparsify=False,
28
+ sym_edges=False
29
+ ):
30
+ device = weights[0].device
31
+ bsz = weights[0].shape[0]
32
+ num_nodes = weights[0].shape[1] + sum(w.shape[2] for w in weights)
33
+
34
+ node_features = torch.zeros(bsz, num_nodes, biases[0].shape[-1], device=device)
35
+ edge_features = torch.zeros(
36
+ bsz, num_nodes, num_nodes, weights[0].shape[-1], device=device
37
+ )
38
+
39
+ row_offset = 0
40
+ col_offset = weights[0].shape[1] # no edge to input nodes
41
+
42
+ for i, w in enumerate(weights):
43
+ _, num_in, num_out, _ = w.shape
44
+ w_mean = weights_mean[i] if weights_mean is not None else 0
45
+ w_std = weights_std[i] if weights_std is not None else 1
46
+ w = (w - w_mean) / w_std
47
+ if sparsify:
48
+ w[~sparsify_graph(w)] = 0
49
+ edge_features[
50
+ :, row_offset : row_offset + num_in, col_offset : col_offset + num_out
51
+ ] = w
52
+ if sym_edges:
53
+ edge_features[
54
+ :, col_offset: col_offset + num_out, row_offset: row_offset + num_in
55
+ ] = torch.swapaxes(w, 1,2)
56
+ row_offset += num_in
57
+ col_offset += num_out
58
+
59
+ row_offset = weights[0].shape[1] # no bias in input nodes
60
+ for i, b in enumerate(biases):
61
+ _, num_out, _ = b.shape
62
+ b_mean = biases_mean[i] if biases_mean is not None else 0
63
+ b_std = biases_std[i] if biases_std is not None else 1
64
+ node_features[:, row_offset : row_offset + num_out] = (b - b_mean) / b_std
65
+ row_offset += num_out
66
+
67
+ return node_features, edge_features
68
+
69
+
70
+ class GraphConstructor(nn.Module):
71
+ def __init__(
72
+ self,
73
+ d_in,
74
+ d_edge_in,
75
+ d_node,
76
+ d_edge,
77
+ layer_layout,
78
+ rev_edge_features=False,
79
+ zero_out_bias=False,
80
+ zero_out_weights=False,
81
+ inp_factor=1,
82
+ input_layers=1,
83
+ sin_emb=False,
84
+ sin_emb_dim=128,
85
+ use_pos_embed=False,
86
+ num_probe_features=0,
87
+ inr_model=None,
88
+ stats=None,
89
+ sparsify=False,
90
+ sym_edges=False,
91
+ ):
92
+ super().__init__()
93
+ self.rev_edge_features = rev_edge_features
94
+ self.nodes_per_layer = layer_layout
95
+ self.zero_out_bias = zero_out_bias
96
+ self.zero_out_weights = zero_out_weights
97
+ self.use_pos_embed = use_pos_embed
98
+ self.stats = stats if stats is not None else {}
99
+ self._d_node = d_node
100
+ self._d_edge = d_edge
101
+ self.sparse = sparsify
102
+ self.sym_edges = sym_edges
103
+
104
+ self.pos_embed_layout = (
105
+ [1] * layer_layout[0] + layer_layout[1:-1] + [1] * layer_layout[-1]
106
+ )
107
+ self.pos_embed = nn.Parameter(torch.randn(len(self.pos_embed_layout), d_node))
108
+
109
+ if not self.zero_out_weights:
110
+ proj_weight = []
111
+ if sin_emb:
112
+ proj_weight.append(
113
+ GaussianEncoding(
114
+ sigma=inp_factor,
115
+ input_size=d_edge_in
116
+ + (2 * d_edge_in if rev_edge_features else 0),
117
+ encoded_size=sin_emb_dim,
118
+ )
119
+ )
120
+ proj_weight.append(nn.Linear(2 * sin_emb_dim, d_edge))
121
+ else:
122
+ proj_weight.append(
123
+ nn.Linear(
124
+ d_edge_in + (2 * d_edge_in if rev_edge_features else 0), d_edge
125
+ )
126
+ )
127
+
128
+ for i in range(input_layers - 1):
129
+ proj_weight.append(nn.SiLU())
130
+ proj_weight.append(nn.Linear(d_edge, d_edge))
131
+
132
+ self.proj_weight = nn.Sequential(*proj_weight)
133
+ if not self.zero_out_bias:
134
+ proj_bias = []
135
+ if sin_emb:
136
+ proj_bias.append(
137
+ GaussianEncoding(
138
+ sigma=inp_factor,
139
+ input_size=d_in,
140
+ encoded_size=sin_emb_dim,
141
+ )
142
+ )
143
+ proj_bias.append(nn.Linear(2 * sin_emb_dim, d_node))
144
+ else:
145
+ proj_bias.append(nn.Linear(d_in, d_node))
146
+
147
+ for i in range(input_layers - 1):
148
+ proj_bias.append(nn.SiLU())
149
+ proj_bias.append(nn.Linear(d_node, d_node))
150
+
151
+ self.proj_bias = nn.Sequential(*proj_bias)
152
+
153
+ self.proj_node_in = nn.Linear(d_node, d_node)
154
+ self.proj_edge_in = nn.Linear(d_edge, d_edge)
155
+
156
+ if num_probe_features > 0:
157
+ self.gpf = GraphProbeFeatures(
158
+ d_in=layer_layout[0],
159
+ num_inputs=num_probe_features,
160
+ inr_model=inr_model,
161
+ input_init=None,
162
+ proj_dim=d_node,
163
+ )
164
+ else:
165
+ self.gpf = None
166
+
167
+ def forward(self, inputs):
168
+ node_features, edge_features = batch_to_graphs(*inputs, **self.stats,
169
+ )
170
+ mask = edge_features.sum(dim=-1, keepdim=True) != 0
171
+ if self.rev_edge_features:
172
+ rev_edge_features = edge_features.transpose(-2, -3)
173
+ edge_features = torch.cat(
174
+ [edge_features, rev_edge_features, edge_features + rev_edge_features],
175
+ dim=-1,
176
+ )
177
+ mask = mask | mask.transpose(-3, -2)
178
+
179
+ if self.zero_out_weights:
180
+ edge_features = torch.zeros(
181
+ (*edge_features.shape[:-1], self._d_edge),
182
+ device=edge_features.device,
183
+ dtype=edge_features.dtype,
184
+ )
185
+ else:
186
+ edge_features = self.proj_weight(edge_features)
187
+ if self.zero_out_bias:
188
+ # only zero out bias, not gpf
189
+ node_features = torch.zeros(
190
+ (*node_features.shape[:-1], self._d_node),
191
+ device=node_features.device,
192
+ dtype=node_features.dtype,
193
+ )
194
+ else:
195
+ node_features = self.proj_bias(node_features)
196
+
197
+ if self.gpf is not None:
198
+ probe_features = self.gpf(*inputs)
199
+ node_features = node_features + probe_features
200
+
201
+ node_features = self.proj_node_in(node_features)
202
+ edge_features = self.proj_edge_in(edge_features)
203
+
204
+ if self.use_pos_embed:
205
+ pos_embed = torch.cat(
206
+ [
207
+ # repeat(self.pos_embed[i], "d -> 1 n d", n=n)
208
+ self.pos_embed[i].unsqueeze(0).expand(1, n, -1)
209
+ for i, n in enumerate(self.pos_embed_layout)
210
+ ],
211
+ dim=1,
212
+ )
213
+ node_features = node_features + pos_embed
214
+ return node_features, edge_features, mask
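batch_to_graphs flattens a batch of per-layer weight and bias tensors into dense node and edge features, one node per neuron, which GraphConstructor then projects. A minimal sketch with a toy [3, 4, 1] MLP layout (the import path and the exact tensor layout are assumptions inferred from how the function indexes its inputs):

import torch
from tasks.utils.graph_constructor import batch_to_graphs  # import path is an assumption

# toy 2-layer MLP, layout [3, 4, 1]; weights: [batch, fan_in, fan_out, 1], biases: [batch, fan_out, 1]
weights = (torch.randn(2, 3, 4, 1), torch.randn(2, 4, 1, 1))
biases = (torch.randn(2, 4, 1), torch.randn(2, 1, 1))
node_feats, edge_feats = batch_to_graphs(weights, biases)
print(node_feats.shape)  # torch.Size([2, 8, 1]) -> 3 + 4 + 1 neurons as nodes
print(edge_feats.shape)  # torch.Size([2, 8, 8, 1]) -> dense weighted adjacency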
tasks/utils/inr.py ADDED
@@ -0,0 +1,147 @@
1
+ import copy
2
+ import math
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from rff.layers import GaussianEncoding, PositionalEncoding
8
+ from torch import nn
9
+ from .kan.fasterkan import FasterKAN
10
+
11
+
12
+
13
+ class Sine(nn.Module):
14
+ def __init__(self, w0=1.0):
15
+ super().__init__()
16
+ self.w0 = w0
17
+
18
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
19
+ return torch.sin(self.w0 * x)
20
+
21
+
22
+ def params_to_tensor(params):
23
+ return torch.cat([p.flatten() for p in params]), [p.shape for p in params]
24
+
25
+
26
+ def tensor_to_params(tensor, shapes):
27
+ params = []
28
+ start = 0
29
+ for shape in shapes:
30
+ size = torch.prod(torch.tensor(shape)).item()
31
+ param = tensor[start : start + size].reshape(shape)
32
+ params.append(param)
33
+ start += size
34
+ return tuple(params)
35
+
36
+
37
+ def wrap_func(func, shapes):
38
+ def wrapped_func(params, *args, **kwargs):
39
+ params = tensor_to_params(params, shapes)
40
+ return func(params, *args, **kwargs)
41
+
42
+ return wrapped_func
43
+
44
+
45
+ class Siren(nn.Module):
46
+ def __init__(
47
+ self,
48
+ dim_in,
49
+ dim_out,
50
+ w0=30.0,
51
+ c=6.0,
52
+ is_first=False,
53
+ use_bias=True,
54
+ activation=None,
55
+ ):
56
+ super().__init__()
57
+ self.w0 = w0
58
+ self.c = c
59
+ self.dim_in = dim_in
60
+ self.dim_out = dim_out
61
+ self.is_first = is_first
62
+
63
+ weight = torch.zeros(dim_out, dim_in)
64
+ bias = torch.zeros(dim_out) if use_bias else None
65
+ self.init_(weight, bias, c=c, w0=w0)
66
+
67
+ self.weight = nn.Parameter(weight)
68
+ self.bias = nn.Parameter(bias) if use_bias else None
69
+ self.activation = Sine(w0) if activation is None else activation
70
+
71
+ def init_(self, weight: torch.Tensor, bias: torch.Tensor, c: float, w0: float):
72
+ dim = self.dim_in
73
+
74
+ w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0)
75
+ weight.uniform_(-w_std, w_std)
76
+
77
+ if bias is not None:
78
+ # bias.uniform_(-w_std, w_std)
79
+ bias.zero_()
80
+
81
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
82
+ out = F.linear(x, self.weight, self.bias)
83
+ out = self.activation(out)
84
+ return out
85
+
86
+
87
+ class INR(nn.Module):
88
+ def __init__(
89
+ self,
90
+ in_features: int = 2,
91
+ n_layers: int = 3,
92
+ hidden_features: int = 32,
93
+ out_features: int = 1,
94
+ pe_features: Optional[int] = None,
95
+ fix_pe=True,
96
+ ):
97
+ super().__init__()
98
+
99
+ if pe_features is not None:
100
+ if fix_pe:
101
+ self.layers = [PositionalEncoding(sigma=10, m=pe_features)]
102
+ encoded_dim = in_features * pe_features * 2
103
+ else:
104
+ self.layers = [
105
+ GaussianEncoding(
106
+ sigma=10, input_size=in_features, encoded_size=pe_features
107
+ )
108
+ ]
109
+ encoded_dim = pe_features * 2
110
+ self.layers.append(Siren(dim_in=encoded_dim, dim_out=hidden_features))
111
+ else:
112
+ self.layers = [Siren(dim_in=in_features, dim_out=hidden_features)]
113
+ for i in range(n_layers - 2):
114
+ self.layers.append(Siren(hidden_features, hidden_features))
115
+ self.layers.append(nn.Linear(hidden_features, out_features))
116
+ self.seq = nn.Sequential(*self.layers)
117
+ self.num_layers = len(self.layers)
118
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
119
+ return self.seq(x) + 0.5
120
+
121
+
122
+ class INRPerLayer(INR):
123
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
124
+ nodes = [x]
125
+ for layer in self.seq:
126
+ nodes.append(layer(nodes[-1]))
127
+ nodes[-1] = nodes[-1] + 0.5
128
+ return nodes
129
+
130
+
131
+ def make_functional(mod, disable_autograd_tracking=False):
132
+ params_dict = dict(mod.named_parameters())
133
+ params_names = params_dict.keys()
134
+ params_values = tuple(params_dict.values())
135
+
136
+ stateless_mod = copy.deepcopy(mod)
137
+ stateless_mod.to("meta")
138
+
139
+ def fmodel(new_params_values, *args, **kwargs):
140
+ new_params_dict = {
141
+ name: value for name, value in zip(params_names, new_params_values)
142
+ }
143
+ return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)
144
+
145
+ if disable_autograd_tracking:
146
+ params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
147
+ return fmodel, params_values
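inr.py defines a SIREN-style implicit network whose forward pass adds 0.5 so outputs land near the [0, 1] range of the normalized audio. A minimal sketch of querying it on normalized time coordinates (assumes this module is importable and the rff package is installed; the coordinate shape follows AudioINRDataset.get_coordinates):

import torch
from tasks.utils.inr import INR  # import path is an assumption

model = INR(in_features=1, n_layers=3, hidden_features=32, out_features=1)
coords = torch.linspace(0, 1, 18000).unsqueeze(-1)  # [18000, 1] normalized time stamps
recon = model(coords)                               # [18000, 1] reconstruction, shifted by +0.5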
tasks/utils/models.py CHANGED
@@ -4,11 +4,56 @@ from .Modules.conformer import ConformerEncoder, ConformerDecoder
  from .Modules.mhsa_pro import RotaryEmbedding, ContinuousRotaryEmbedding
  from .kan.fasterkan import FasterKAN
 
+ 
+ class Sine(nn.Module):
+     def __init__(self, w0=1.0):
+         super().__init__()
+         self.w0 = w0
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         return torch.sin(self.w0 * x)
+ 
+ 
+ class MLPEncoder(nn.Module):
+     def __init__(self, args):
+         """
+         Initialize an MLP with hidden layers, BatchNorm, and Dropout.
+ 
+         Args:
+             input_dim (int): Dimension of the input features.
+             hidden_dims (list of int): List of dimensions for hidden layers.
+             output_dim (int): Dimension of the output.
+             dropout (float): Dropout probability (default: 0.0).
+         """
+         super(MLPEncoder, self).__init__()
+ 
+         layers = []
+         prev_dim = args.input_dim
+ 
+         # Add hidden layers
+         for hidden_dim in args.hidden_dims:
+             layers.append(nn.Linear(prev_dim, hidden_dim))
+             layers.append(nn.BatchNorm1d(hidden_dim))
+             layers.append(nn.SiLU())
+             if args.dropout > 0.0:
+                 layers.append(nn.Dropout(args.dropout))
+             prev_dim = hidden_dim
+         self.model = nn.Sequential(*layers)
+         self.output_dim = hidden_dim
+ 
+     def forward(self, x):
+         # if x.dim() == 2:
+         #     x = x.unsqueeze(-1)
+         x = self.model(x)
+         # x = x.mean(-1)
+         return x
  class ConvBlock(nn.Module):
      def __init__(self, args, num_layer) -> None:
          super().__init__()
          if args.activation == 'silu':
              self.activation = nn.SiLU()
+         elif args.activation == 'sine':
+             self.activation = Sine(w0=args.sine_w0)
          else:
              self.activation = nn.ReLU()
          in_channels = args.encoder_dims[num_layer-1] if num_layer < len(args.encoder_dims) else args.encoder_dims[-1]
@@ -31,6 +76,8 @@ class CNNEncoder(nn.Module):
          print("Using CNN encoder wit activation: ", args.activation, 'args avg_output: ', args.avg_output)
          if args.activation == 'silu':
              self.activation = nn.SiLU()
+         elif args.activation == 'sine':
+             self.activation = Sine(w0=args.sine_w0)
          else:
              self.activation = nn.ReLU()
          self.embedding = nn.Sequential(nn.Conv1d(in_channels = args.in_channels,
@@ -125,6 +172,21 @@ class CNNKan(nn.Module):
          x = x.mean(dim=1)
          return self.kan(x)
 
+ class CNNKanFeaturesEncoder(nn.Module):
+     def __init__(self, args, mlp_args, kan_args):
+         super().__init__()
+         self.backbone = CNNEncoder(args)
+         self.mlp = MLPEncoder(mlp_args)
+         kan_args['layers_hidden'][0] += self.mlp.output_dim
+         self.kan = FasterKAN(**kan_args)
+ 
+     def forward(self, x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
+         x = self.backbone(x)
+         x = x.mean(dim=1)
+         f = self.mlp(f)
+         x_f = torch.cat([x, f], dim=-1)
+         return self.kan(x_f)
+ 
  class KanEncoder(nn.Module):
      def __init__(self, args):
          super().__init__()
@@ -138,3 +200,44 @@ class KanEncoder(nn.Module):
          out = torch.cat([x, f], dim=-1)
          return self.kan_out(out)
 
+ 
+ class MultiGraph(nn.Module):
+     def __init__(self, graph_net, args):
+         super().__init__()
+         self.graph_net = graph_net
+         self.cnn = CNNEncoder(args)
+         total_output_dim = args.encoder_dims[-1]
+         self.projection = nn.Sequential(
+             nn.Linear(total_output_dim, total_output_dim // 2),
+             nn.BatchNorm1d(total_output_dim // 2),
+             nn.SiLU(),
+             nn.Linear(total_output_dim // 2, 1)
+         )
+ 
+     def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+         # g_out = self.graph_net(g)
+         x_out = self.cnn(x)
+         # g_out = g_out.expand(x.shape[0], -1)
+         # features = torch.cat([g_out, x_out], dim=-1)
+         return self.projection(x_out)
+ 
+ class ImplicitEncoder(nn.Module):
+     def __init__(self, transform_net, encoder_net):
+         super().__init__()
+         self.transform_net = transform_net
+         self.encoder_net = encoder_net
+ 
+     def get_weights_and_bises(self):
+         state_dict = self.transform_net.state_dict()
+         weights = tuple(
+             [v.permute(1, 0).unsqueeze(-1).unsqueeze(0) for w, v in state_dict.items() if "weight" in w]
+         )
+         biases = tuple([v.unsqueeze(-1).unsqueeze(0) for w, v in state_dict.items() if "bias" in w])
+         return weights, biases
+ 
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         transformed_x = self.transform_net(x.permute(0, 2, 1)).permute(0, 2, 1)
+         inputs = self.get_weights_and_bises()
+         outputs = self.encoder_net(inputs, transformed_x)
+         return outputs
+ 
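The new CNNKanFeaturesEncoder concatenates the pooled CNN embedding with the MLP-encoded hand-crafted features before the FasterKAN head, widening the KAN input layer by MLPEncoder.output_dim. A minimal sketch of the forward contract, assuming the config Containers from run.py; the batch size, channel layout, and the 6-feature vector width are assumptions tied to the MLP section of config.yaml:

import torch

model = CNNKanFeaturesEncoder(model_args, mlp_args, kan_args.get_dict())
x = torch.randn(8, 1, 18000)  # audio branch; [batch, in_channels, time] layout is an assumption
f = torch.randn(8, 6)         # per-clip features, matching MLP.input_dim: 6
logits = model(x, f)          # one logit per clip, trained with BCEWithLogitsLoss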
tasks/utils/pooling.py ADDED
@@ -0,0 +1,199 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch_geometric.nn.aggr import (
4
+ AttentionalAggregation,
5
+ GraphMultisetTransformer,
6
+ MaxAggregation,
7
+ MeanAggregation,
8
+ SetTransformerAggregation,
9
+ )
10
+
11
+ class CatAggregation(nn.Module):
12
+ def __init__(self):
13
+ super().__init__()
14
+ self.flatten = nn.Flatten(1, 2)
15
+
16
+ def forward(self, x, index=None):
17
+ return self.flatten(x)
18
+
19
+
20
+ class HeterogeneousAggregator(nn.Module):
21
+ def __init__(
22
+ self,
23
+ input_dim,
24
+ hidden_dim,
25
+ output_dim,
26
+ pooling_method,
27
+ pooling_layer_idx,
28
+ input_channels,
29
+ num_classes,
30
+ ):
31
+ super().__init__()
32
+ self.pooling_method = pooling_method
33
+ self.pooling_layer_idx = pooling_layer_idx
34
+ self.input_channels = input_channels
35
+ self.num_classes = num_classes
36
+
37
+ if pooling_layer_idx == "all":
38
+ self._pool_layer_idx_fn = self.get_all_layer_indices
39
+ elif pooling_layer_idx == "last":
40
+ self._pool_layer_idx_fn = self.get_last_layer_indices
41
+ elif isinstance(pooling_layer_idx, int):
42
+ self._pool_layer_idx_fn = self.get_nth_layer_indices
43
+ else:
44
+ raise ValueError(f"Unknown pooling layer index {pooling_layer_idx}")
45
+
46
+ if pooling_method == "mean":
47
+ self.pool = MeanAggregation()
48
+ elif pooling_method == "max":
49
+ self.pool = MaxAggregation()
50
+ elif pooling_method == "cat":
51
+ self.pool = CatAggregation()
52
+ elif pooling_method == "attentional_aggregation":
53
+ self.pool = AttentionalAggregation(
54
+ gate_nn=nn.Sequential(
55
+ nn.Linear(input_dim, hidden_dim),
56
+ nn.SiLU(),
57
+ nn.Linear(hidden_dim, 1),
58
+ ),
59
+ nn=nn.Sequential(
60
+ nn.Linear(input_dim, hidden_dim),
61
+ nn.SiLU(),
62
+ nn.Linear(hidden_dim, output_dim),
63
+ ),
64
+ )
65
+ elif pooling_method == "set_transformer":
66
+ self.pool = SetTransformerAggregation(
67
+ input_dim, heads=8, num_encoder_blocks=4, num_decoder_blocks=4
68
+ )
69
+ elif pooling_method == "graph_multiset_transformer":
70
+ self.pool = GraphMultisetTransformer(input_dim, k=8, heads=8)
71
+ else:
72
+ raise ValueError(f"Unknown pooling method {pooling_method}")
73
+
74
+ def get_last_layer_indices(
75
+ self, x, layer_layouts, node_mask=None, return_dense=False
76
+ ):
77
+ batch_size = x.shape[0]
78
+ device = x.device
79
+
80
+ # NOTE: node_mask needs to exist in the heterogeneous case only
81
+ if node_mask is None:
82
+ node_mask = torch.ones_like(x[..., 0], dtype=torch.bool, device=device)
83
+
84
+ valid_layer_indices = (
85
+ torch.arange(node_mask.shape[1], device=device)[None, :] * node_mask
86
+ )
87
+ last_layer_indices = valid_layer_indices.topk(
88
+ k=self.num_classes, dim=1
89
+ ).values.fliplr()
90
+
91
+ if return_dense:
92
+ return torch.arange(batch_size, device=device)[:, None], last_layer_indices
93
+
94
+ batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
95
+ self.num_classes
96
+ )
97
+ return batch_indices, last_layer_indices.flatten()
98
+
99
+ def get_nth_layer_indices(
100
+ self, x, layer_layouts, node_mask=None, return_dense=False
101
+ ):
102
+ batch_size = x.shape[0]
103
+ device = x.device
104
+
105
+ cum_layer_layout = [
106
+ torch.cumsum(torch.tensor([0] + layer_layout), dim=0)
107
+ for layer_layout in layer_layouts
108
+ ]
109
+
110
+ layer_sizes = torch.tensor(
111
+ [layer_layout[self.pooling_layer_idx] for layer_layout in layer_layouts],
112
+ dtype=torch.long,
113
+ device=device,
114
+ )
115
+ batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
116
+ layer_sizes
117
+ )
118
+ layer_indices = torch.cat(
119
+ [
120
+ torch.arange(
121
+ layout[self.pooling_layer_idx],
122
+ layout[self.pooling_layer_idx + 1],
123
+ device=device,
124
+ )
125
+ for layout in cum_layer_layout
126
+ ]
127
+ )
128
+ return batch_indices, layer_indices
129
+
130
+ def get_all_layer_indices(
131
+ self, x, layer_layouts, node_mask=None, return_dense=False
132
+ ):
133
+ """Imitate flattening with indexing"""
134
+ batch_size, num_nodes = x.shape[:2]
135
+ device = x.device
136
+ batch_indices = torch.arange(batch_size, device=device).repeat_interleave(
137
+ num_nodes
138
+ )
139
+ layer_indices = torch.arange(num_nodes, device=device).repeat(batch_size)
140
+ return batch_indices, layer_indices
141
+
142
+ def forward(self, x, layer_layouts, node_mask=None):
143
+ # NOTE: `cat` only works with `pooling_layer_idx == "last"`
144
+ return_dense = self.pooling_method == "cat" and self.pooling_layer_idx == "last"
145
+ batch_indices, layer_indices = self._pool_layer_idx_fn(
146
+ x, layer_layouts, node_mask=node_mask, return_dense=return_dense
147
+ )
148
+
149
+ flat_x = x[batch_indices, layer_indices]
150
+ return self.pool(flat_x, index=batch_indices)
151
+
152
+
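A hedged usage sketch of HeterogeneousAggregator: mean-pooling the last-layer nodes of a padded batch. The tensor shapes and the layer layout are illustrative assumptions.

    import torch

    pool = HeterogeneousAggregator(input_dim=32, hidden_dim=64, output_dim=32,
                                   pooling_method="mean", pooling_layer_idx="last",
                                   input_channels=1, num_classes=2)
    x = torch.randn(4, 10, 32)             # [batch, nodes, features]
    layer_layouts = [[2, 6, 2]] * 4        # assumed layouts; 2 last-layer nodes per graph
    out = pool(x, layer_layouts)           # mean over the 2 last-layer nodes of each graph
    print(out.shape)                       # torch.Size([4, 32])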
153
+ class HomogeneousAggregator(nn.Module):
154
+ def __init__(
155
+ self,
156
+ pooling_method,
157
+ pooling_layer_idx,
158
+ layer_layout,
159
+ ):
160
+ super().__init__()
161
+ self.pooling_method = pooling_method
162
+ self.pooling_layer_idx = pooling_layer_idx
163
+ self.layer_layout = layer_layout
+ # cumulative node offsets per layer; used by the integer and cat_mean pooling branches in forward
+ self.layer_idx = torch.cumsum(torch.tensor([0] + list(layer_layout)), dim=0).tolist()
164
+
165
+ def forward(self, node_features, edge_features):
166
+ if self.pooling_method == "mean" and self.pooling_layer_idx == "all":
167
+ graph_features = node_features.mean(dim=1)
168
+ elif self.pooling_method == "max" and self.pooling_layer_idx == "all":
169
+ graph_features = node_features.max(dim=1).values
170
+ elif self.pooling_method == "mean" and self.pooling_layer_idx == "last":
171
+ graph_features = node_features[:, -self.layer_layout[-1] :].mean(dim=1)
172
+ elif self.pooling_method == "cat" and self.pooling_layer_idx == "last":
173
+ graph_features = node_features[:, -self.layer_layout[-1] :].flatten(1, 2)
174
+ elif self.pooling_method == "mean" and isinstance(self.pooling_layer_idx, int):
175
+ graph_features = node_features[
176
+ :,
177
+ self.layer_idx[self.pooling_layer_idx] : self.layer_idx[
178
+ self.pooling_layer_idx + 1
179
+ ],
180
+ ].mean(dim=1)
181
+ elif self.pooling_method == "cat_mean" and self.pooling_layer_idx == "all":
182
+ graph_features = torch.cat(
183
+ [
184
+ node_features[:, self.layer_idx[i] : self.layer_idx[i + 1]].mean(
185
+ dim=1
186
+ )
187
+ for i in range(len(self.layer_layout))
188
+ ],
189
+ dim=1,
190
+ )
191
+ elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "all":
192
+ graph_features = edge_features.mean(dim=(1, 2))
193
+ elif self.pooling_method == "max_edge" and self.pooling_layer_idx == "all":
194
+ graph_features = edge_features.flatten(1, 2).max(dim=1).values
195
+ elif self.pooling_method == "mean_edge" and self.pooling_layer_idx == "last":
196
+ graph_features = edge_features[:, :, -self.layer_layout[-1] :].mean(
197
+ dim=(1, 2)
198
+ )
199
+ return graph_features
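A hedged usage sketch of HomogeneousAggregator in its simplest configuration (mean over all nodes); the tensor shapes are illustrative assumptions.

    import torch

    layer_layout = [2, 6, 2]                       # assumed: 10 nodes total
    pool = HomogeneousAggregator(pooling_method="mean",
                                 pooling_layer_idx="all",
                                 layer_layout=layer_layout)
    node_features = torch.randn(4, 10, 32)
    edge_features = torch.randn(4, 10, 10, 8)
    graph_features = pool(node_features, edge_features)
    print(graph_features.shape)                    # torch.Size([4, 32])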
tasks/utils/probe_features.py ADDED
@@ -0,0 +1,54 @@
1
+ import hydra
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops.layers.torch import Rearrange
5
+
6
+ from nn.inr import make_functional, params_to_tensor, wrap_func
7
+
8
+
9
+ class GraphProbeFeatures(nn.Module):
10
+ def __init__(self, d_in, num_inputs, inr_model, input_init=None, proj_dim=None):
11
+ super().__init__()
12
+ inr = hydra.utils.instantiate(inr_model)
13
+ fmodel, params = make_functional(inr)
14
+
15
+ vparams, vshapes = params_to_tensor(params)
16
+ self.sirens = torch.vmap(wrap_func(fmodel, vshapes))
17
+
18
+ inputs = (
19
+ input_init
20
+ if input_init is not None
21
+ else 2 * torch.rand(1, num_inputs, d_in) - 1
22
+ )
23
+ self.inputs = nn.Parameter(inputs, requires_grad=input_init is None)
24
+
25
+ self.reshape_weights = Rearrange("b i o 1 -> b (o i)")
26
+ self.reshape_biases = Rearrange("b o 1 -> b o")
27
+
28
+ self.proj_dim = proj_dim
29
+ if proj_dim is not None:
30
+ self.proj = nn.ModuleList(
31
+ [
32
+ nn.Sequential(
33
+ nn.Linear(num_inputs, proj_dim),
34
+ nn.LayerNorm(proj_dim),
35
+ )
36
+ for _ in range(inr.num_layers + 1)
37
+ ]
38
+ )
39
+
40
+ def forward(self, weights, biases):
41
+ weights = [self.reshape_weights(w) for w in weights]
42
+ biases = [self.reshape_biases(b) for b in biases]
43
+ params_flat = torch.cat(
44
+ [w_or_b for p in zip(weights, biases) for w_or_b in p], dim=-1
45
+ )
46
+
47
+ out = self.sirens(params_flat, self.inputs.expand(params_flat.shape[0], -1, -1))
48
+ if self.proj_dim is not None:
49
+ out = [proj(out[i].permute(0, 2, 1)) for i, proj in enumerate(self.proj)]
50
+ out = torch.cat(out, dim=1)
51
+ return out
52
+ else:
53
+ out = torch.cat(out, dim=-1)
54
+ return out.permute(0, 2, 1)
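A short sketch of the reshape convention GraphProbeFeatures assumes for its inputs: per-layer weights arrive as [batch, in, out, 1] and biases as [batch, out, 1], and are flattened before being concatenated into one parameter vector per sample. The shapes below are illustrative.

    import torch
    from einops.layers.torch import Rearrange

    reshape_weights = Rearrange("b i o 1 -> b (o i)")
    reshape_biases = Rearrange("b o 1 -> b o")
    w = torch.randn(4, 1, 16, 1)          # e.g. first layer of a 1-in / 16-hidden INR
    b = torch.randn(4, 16, 1)
    print(reshape_weights(w).shape)       # torch.Size([4, 16])
    print(reshape_biases(b).shape)        # torch.Size([4, 16])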
tasks/utils/relational_transformer.py ADDED
@@ -0,0 +1,361 @@
1
+ import hydra
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from einops.layers.torch import Rearrange
5
+
6
+ from utils.pooling import HomogeneousAggregator
7
+ import torch.nn as nn
8
+
9
+
10
+ class RelationalTransformer(nn.Module):
11
+ def __init__(
12
+ self,
13
+ d_node,
14
+ d_edge,
15
+ d_attn_hid,
16
+ d_node_hid,
17
+ d_edge_hid,
18
+ d_out_hid,
19
+ d_out,
20
+ n_layers,
21
+ n_heads,
22
+ layer_layout,
23
+ graph_constructor,
24
+ dropout=0.0,
25
+ node_update_type="rt",
26
+ disable_edge_updates=False,
27
+ use_cls_token=False,
28
+ pooling_method="cat",
29
+ pooling_layer_idx="last",
30
+ rev_edge_features=False,
31
+ modulate_v=True,
32
+ use_ln=True,
33
+ tfixit_init=False,
34
+ ):
35
+ super().__init__()
36
+ assert use_cls_token == (pooling_method == "cls_token")
37
+ self.pooling_method = pooling_method
38
+ self.pooling_layer_idx = pooling_layer_idx
39
+ self.rev_edge_features = rev_edge_features
40
+ self.nodes_per_layer = layer_layout
41
+ self.construct_graph = hydra.utils.instantiate(
42
+ graph_constructor,
43
+ d_node=d_node,
44
+ d_edge=d_edge,
45
+ layer_layout=layer_layout,
46
+ rev_edge_features=rev_edge_features,
47
+ )
48
+
49
+ self.use_cls_token = use_cls_token
50
+ if use_cls_token:
51
+ self.cls_token = nn.Parameter(torch.randn(d_node))
52
+
53
+ self.layers = nn.ModuleList(
54
+ [
55
+ torch.jit.script(
56
+ RTLayer(
57
+ d_node,
58
+ d_edge,
59
+ d_attn_hid,
60
+ d_node_hid,
61
+ d_edge_hid,
62
+ n_heads,
63
+ dropout,
64
+ node_update_type=node_update_type,
65
+ disable_edge_updates=(
66
+ (disable_edge_updates or (i == n_layers - 1))
67
+ and pooling_method != "mean_edge"
68
+ and pooling_layer_idx != "all"
69
+ ),
70
+ modulate_v=modulate_v,
71
+ use_ln=use_ln,
72
+ tfixit_init=tfixit_init,
73
+ n_layers=n_layers,
74
+ )
75
+ )
76
+ for i in range(n_layers)
77
+ ]
78
+ )
79
+
80
+ if pooling_method != "cls_token":
81
+ self.pool = HomogeneousAggregator(
82
+ pooling_method,
83
+ pooling_layer_idx,
84
+ layer_layout,
85
+ )
86
+
87
+ self.num_graph_features = (
88
+ layer_layout[-1] * d_node
89
+ if pooling_method == "cat" and pooling_layer_idx == "last"
90
+ else d_edge if pooling_method in ("mean_edge", "max_edge") else d_node
91
+ )
92
+ self.proj_out = nn.Sequential(
93
+ nn.Linear(self.num_graph_features, d_out_hid),
94
+ nn.ReLU(),
95
+ # nn.Linear(d_out_hid, d_out_hid),
96
+ # nn.ReLU(),
97
+ nn.Linear(d_out_hid, d_out),
98
+ )
99
+
100
+ self.final_features = (None,None,None,None)
101
+
102
+ def forward(self, inputs):
103
+ attn_weights = None
104
+ node_features, edge_features, mask = self.construct_graph(inputs)
105
+ if self.use_cls_token:
106
+ node_features = torch.cat(
107
+ [
108
+ # repeat(self.cls_token, "d -> b 1 d", b=node_features.size(0)),
109
+ self.cls_token.unsqueeze(0).expand(node_features.size(0), 1, -1),
110
+ node_features,
111
+ ],
112
+ dim=1,
113
+ )
114
+ edge_features = F.pad(edge_features, (0, 0, 1, 0, 1, 0), value=0)
115
+ for layer in self.layers:
116
+ node_features, edge_features, attn_weights = layer(node_features, edge_features, mask)
117
+
118
+ if self.pooling_method == "cls_token":
119
+ graph_features = node_features[:, 0]
120
+ else:
121
+ graph_features = self.pool(node_features, edge_features)
122
+ self.final_features = (graph_features, node_features, edge_features, attn_weights)
123
+ return self.proj_out(graph_features)
124
+
125
+
126
+ class RTLayer(nn.Module):
127
+ def __init__(
128
+ self,
129
+ d_node,
130
+ d_edge,
131
+ d_attn_hid,
132
+ d_node_hid,
133
+ d_edge_hid,
134
+ n_heads,
135
+ dropout,
136
+ node_update_type="rt",
137
+ disable_edge_updates=False,
138
+ modulate_v=True,
139
+ use_ln=True,
140
+ tfixit_init=False,
141
+ n_layers=None,
142
+ ):
143
+ super().__init__()
144
+ self.node_update_type = node_update_type
145
+ self.disable_edge_updates = disable_edge_updates
146
+ self.use_ln = use_ln
147
+ self.n_layers = n_layers
148
+
149
+ self.self_attn = torch.jit.script(
150
+ RTAttention(
151
+ d_node,
152
+ d_edge,
153
+ d_attn_hid,
154
+ n_heads,
155
+ modulate_v=modulate_v,
156
+ use_ln=use_ln,
157
+ )
158
+ )
159
+ # self.self_attn = RTAttention(d_hid, d_hid, d_hid, n_heads)
160
+ self.lin0 = Linear(d_node, d_node)
161
+ self.dropout0 = nn.Dropout(dropout)
162
+ if use_ln:
163
+ self.node_ln0 = nn.LayerNorm(d_node)
164
+ self.node_ln1 = nn.LayerNorm(d_node)
165
+ else:
166
+ self.node_ln0 = nn.Identity()
167
+ self.node_ln1 = nn.Identity()
168
+
169
+ act_fn = nn.GELU
170
+
171
+ self.node_mlp = nn.Sequential(
172
+ Linear(d_node, d_node_hid, bias=False),
173
+ act_fn(),
174
+ Linear(d_node_hid, d_node),
175
+ nn.Dropout(dropout),
176
+ )
177
+
178
+ if not self.disable_edge_updates:
179
+ self.edge_updates = EdgeLayer(
180
+ d_node=d_node,
181
+ d_edge=d_edge,
182
+ d_edge_hid=d_edge_hid,
183
+ dropout=dropout,
184
+ act_fn=act_fn,
185
+ use_ln=use_ln,
186
+ )
187
+ else:
188
+ self.edge_updates = NoEdgeLayer()
189
+
190
+ if tfixit_init:
191
+ self.fixit_init()
192
+
193
+ def fixit_init(self):
194
+ temp_state_dict = self.state_dict()
195
+ n_layers = self.n_layers
196
+ for name, param in self.named_parameters():
197
+ if "weight" in name:
198
+ if name.split(".")[0] in ["node_mlp", "edge_mlp0", "edge_mlp1"]:
199
+ temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * param
200
+ elif name.split(".")[0] in ["self_attn"]:
201
+ temp_state_dict[name] = (0.67 * (n_layers) ** (-1.0 / 4.0)) * (
202
+ param * (2**0.5)
203
+ )
204
+
205
+ self.load_state_dict(temp_state_dict)
206
+
207
+ def node_updates(self, node_features, edge_features, mask):
208
+ out = self.self_attn(node_features, edge_features, mask)
209
+ attn_out, attn_weights = out
210
+ node_features = self.node_ln0(
211
+ node_features
212
+ + self.dropout0(
213
+ self.lin0(attn_out)
214
+ )
215
+ )
216
+ node_features = self.node_ln1(node_features + self.node_mlp(node_features))
217
+
218
+ return node_features, attn_weights
219
+
220
+ def forward(self, node_features, edge_features, mask):
221
+ node_features, attn_weights = self.node_updates(node_features, edge_features, mask)
222
+ edge_features = self.edge_updates(node_features, edge_features, mask)
223
+
224
+ return node_features, edge_features, attn_weights
225
+
226
+
227
+ class EdgeLayer(nn.Module):
228
+ def __init__(
229
+ self,
230
+ *,
231
+ d_node,
232
+ d_edge,
233
+ d_edge_hid,
234
+ dropout,
235
+ act_fn,
236
+ use_ln=True,
237
+ ) -> None:
238
+ super().__init__()
239
+ self.edge_mlp0 = EdgeMLP(
240
+ d_edge=d_edge,
241
+ d_node=d_node,
242
+ d_edge_hid=d_edge_hid,
243
+ act_fn=act_fn,
244
+ dropout=dropout,
245
+ )
246
+ self.edge_mlp1 = nn.Sequential(
247
+ Linear(d_edge, d_edge_hid, bias=False),
248
+ act_fn(),
249
+ Linear(d_edge_hid, d_edge),
250
+ nn.Dropout(dropout),
251
+ )
252
+ if use_ln:
253
+ self.eln0 = nn.LayerNorm(d_edge)
254
+ self.eln1 = nn.LayerNorm(d_edge)
255
+ else:
256
+ self.eln0 = nn.Identity()
257
+ self.eln1 = nn.Identity()
258
+
259
+ def forward(self, node_features, edge_features, mask):
260
+ edge_features = self.eln0(
261
+ edge_features + self.edge_mlp0(node_features, edge_features)
262
+ )
263
+ edge_features = self.eln1(edge_features + self.edge_mlp1(edge_features))
264
+ return edge_features
265
+
266
+
267
+ class NoEdgeLayer(nn.Module):
268
+ def forward(self, node_features, edge_features, mask):
269
+ return edge_features
270
+
271
+
272
+ class EdgeMLP(nn.Module):
273
+ def __init__(self, *, d_node, d_edge, d_edge_hid, act_fn, dropout):
274
+ super().__init__()
275
+ self.reverse_edge = Rearrange("b n m d -> b m n d")
276
+ self.lin0_e = Linear(2 * d_edge, d_edge_hid)
277
+ self.lin0_s = Linear(d_node, d_edge_hid)
278
+ self.lin0_t = Linear(d_node, d_edge_hid)
279
+ self.act = act_fn()
280
+ self.lin1 = Linear(d_edge_hid, d_edge)
281
+ self.drop = nn.Dropout(dropout)
282
+
283
+ def forward(self, node_features, edge_features):
284
+ source_nodes = (
285
+ self.lin0_s(node_features)
286
+ .unsqueeze(-2)
287
+ .expand(-1, -1, node_features.size(-2), -1)
288
+ )
289
+ target_nodes = (
290
+ self.lin0_t(node_features)
291
+ .unsqueeze(-3)
292
+ .expand(-1, node_features.size(-2), -1, -1)
293
+ )
294
+
295
+ # reversed_edge_features = self.reverse_edge(edge_features)
296
+ edge_features = self.lin0_e(
297
+ torch.cat([edge_features, self.reverse_edge(edge_features)], dim=-1)
298
+ )
299
+ edge_features = edge_features + source_nodes + target_nodes
300
+ edge_features = self.act(edge_features)
301
+ edge_features = self.lin1(edge_features)
302
+ edge_features = self.drop(edge_features)
303
+
304
+ return edge_features
305
+
306
+
307
+ class RTAttention(nn.Module):
308
+ def __init__(self, d_node, d_edge, d_hid, n_heads, modulate_v=None, use_ln=True):
309
+ super().__init__()
310
+ self.n_heads = n_heads
311
+ self.d_node = d_node
312
+ self.d_edge = d_edge
313
+ self.d_hid = d_hid
314
+ self.use_ln = use_ln
315
+ self.modulate_v = modulate_v
316
+ self.scale = 1 / (d_hid**0.5)
317
+ self.split_head_node = Rearrange("b n (h d) -> b h n d", h=n_heads)
318
+ self.split_head_edge = Rearrange("b n m (h d) -> b h n m d", h=n_heads)
319
+ self.cat_head_node = Rearrange("... h n d -> ... n (h d)", h=n_heads)
320
+
321
+ self.qkv_node = Linear(d_node, 3 * d_hid, bias=False)
322
+ self.edge_factor = 4 if modulate_v else 3
323
+ self.qkv_edge = Linear(d_edge, self.edge_factor * d_hid, bias=False)
324
+ self.proj_out = Linear(d_hid, d_node)
325
+
326
+ def forward(self, node_features, edge_features, mask):
327
+ qkv_node = self.qkv_node(node_features)
328
+ # qkv_node = rearrange(qkv_node, "b n (h d) -> b h n d", h=self.n_heads)
329
+ qkv_node = self.split_head_node(qkv_node)
330
+ q_node, k_node, v_node = torch.chunk(qkv_node, 3, dim=-1)
331
+
332
+ qkv_edge = self.qkv_edge(edge_features)
333
+ # qkv_edge = rearrange(qkv_edge, "b n m (h d) -> b h n m d", h=self.n_heads)
334
+ qkv_edge = self.split_head_edge(qkv_edge)
335
+ qkv_edge = torch.chunk(qkv_edge, self.edge_factor, dim=-1)
336
+ # q_edge, k_edge, v_edge, q_edge_b, k_edge_b, v_edge_b = torch.chunk(
337
+ # qkv_edge, 6, dim=-1
338
+ # )
339
+ # qkv_edge = [item.masked_fill(mask.unsqueeze(1) == 0, 0) for item in qkv_edge]
340
+
341
+ q = q_node.unsqueeze(-2) + qkv_edge[0] # + q_edge_b
342
+ k = k_node.unsqueeze(-3) + qkv_edge[1] # + k_edge_b
343
+ if self.modulate_v:
344
+ v = v_node.unsqueeze(-3) * qkv_edge[3] + qkv_edge[2]
345
+ else:
346
+ v = v_node.unsqueeze(-3) + qkv_edge[2]
347
+ dots = self.scale * torch.einsum("b h i j d, b h i j d -> b h i j", q, k)
348
+ # dots.masked_fill_(mask.unsqueeze(1).squeeze(-1) == 0, -1e-9)
349
+
350
+ attn = F.softmax(dots, dim=-1)
351
+ out = torch.einsum("b h i j, b h i j d -> b h i d", attn, v)
352
+ out = self.cat_head_node(out)
353
+ return self.proj_out(out), attn
354
+
355
+
356
+ def Linear(in_features, out_features, bias=True):
357
+ m = nn.Linear(in_features, out_features, bias)
358
+ nn.init.xavier_uniform_(m.weight) # , gain=1 / math.sqrt(2))
359
+ if bias:
360
+ nn.init.constant_(m.bias, 0.0)
361
+ return m
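A minimal sketch (assumed shapes, per-head scaling) of the edge-conditioned attention computed in RTAttention above: node queries and keys are broadcast against per-edge terms, scores are a dot product over the feature dimension, and values can additionally be modulated by a fourth edge projection.

    import torch
    import torch.nn.functional as F

    b, h, n, d = 2, 4, 5, 8                       # batch, heads, nodes, head dim (illustrative)
    q_node = torch.randn(b, h, n, d)
    k_node = torch.randn(b, h, n, d)
    q_edge = torch.randn(b, h, n, n, d)
    k_edge = torch.randn(b, h, n, n, d)
    v = torch.randn(b, h, n, n, d)

    q = q_node.unsqueeze(-2) + q_edge             # [b, h, n, n, d]
    k = k_node.unsqueeze(-3) + k_edge
    dots = (d ** -0.5) * torch.einsum("bhijd,bhijd->bhij", q, k)
    attn = F.softmax(dots, dim=-1)
    out = torch.einsum("bhij,bhijd->bhid", attn, v)
    print(out.shape)                              # torch.Size([2, 4, 5, 8])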
tasks/utils/train.py CHANGED
@@ -9,6 +9,14 @@ import glob
9
  from collections import OrderedDict
10
  from tqdm import tqdm
11
  import torch.distributed as dist
12
 
13
  class Trainer(object):
14
  """
@@ -217,9 +225,12 @@ class Trainer(object):
217
  return train_loss, all_accs/total
218
 
219
  def train_batch(self, batch, batch_idx, device):
220
- x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
221
  x = x.to(device).float()
222
  fft = fft.to(device).float()
 
223
  y = y.to(device).float()
224
  x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
225
  y_pred = self.model(x_fft).squeeze()
@@ -255,7 +266,8 @@ class Trainer(object):
255
  return val_loss, all_accs/total
256
 
257
  def eval_batch(self, batch, batch_idx, device):
258
- x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
 
259
  x = x.to(device).float()
260
  fft = fft.to(device).float()
261
  x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
@@ -279,7 +291,7 @@ class Trainer(object):
279
  true_labels = []
280
  pbar = tqdm(test_dataloader)
281
  for i,batch in enumerate(pbar):
282
- x, fft, y = batch['audio']['array'], batch['audio']['fft'], batch['label']
283
  x = x.to(device).float()
284
  fft = fft.to(device).float()
285
  x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
@@ -297,4 +309,411 @@ class Trainer(object):
297
  pbar.set_description("acc: {:.4f}".format(acc))
298
  if i > self.max_iter:
299
  break
300
- return predictions, true_labels, all_accs/total
9
  from collections import OrderedDict
10
  from tqdm import tqdm
11
  import torch.distributed as dist
12
+ import pandas as pd
13
+ import xgboost as xgb
14
+ from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
15
+ import matplotlib.pyplot as plt
16
+
17
+ from torch.nn import ModuleList
18
+ # from inr import INR
19
+ # from kan import FasterKAN
20
 
21
  class Trainer(object):
22
  """
 
225
  return train_loss, all_accs/total
226
 
227
  def train_batch(self, batch, batch_idx, device):
228
+ x, fft, y = batch['audio']['array'], batch['audio']['fft_mag'], batch['label']
229
+ # features = batch['audio']['features_arr'].to(device).float()
230
+ # cwt = batch['audio']['cwt_mag']
231
  x = x.to(device).float()
232
  fft = fft.to(device).float()
233
+ # cwt = cwt.to(device).float()
234
  y = y.to(device).float()
235
  x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
236
  y_pred = self.model(x_fft).squeeze()
 
266
  return val_loss, all_accs/total
267
 
268
  def eval_batch(self, batch, batch_idx, device):
269
+ x, fft, y = batch['audio']['array'], batch['audio']['fft_mag'], batch['label']
270
+ # features = batch['audio']['features_arr'].to(device).float()
271
  x = x.to(device).float()
272
  fft = fft.to(device).float()
273
  x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
 
291
  true_labels = []
292
  pbar = tqdm(test_dataloader)
293
  for i,batch in enumerate(pbar):
294
+ x, fft, y = batch['audio']['array'], batch['audio']['fft_mag'], batch['label']
295
  x = x.to(device).float()
296
  fft = fft.to(device).float()
297
  x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
 
309
  pbar.set_description("acc: {:.4f}".format(acc))
310
  if i > self.max_iter:
311
  break
312
+ return predictions, true_labels, all_accs/total
313
+
314
+
315
+ class INRDatabase:
316
+ """Database to store and manage INRs persistently."""
317
+
318
+ def __init__(self, save_dir='./inr_database'):
319
+ self.inrs = {} # Maps sample_id -> INR
320
+ self.optimizers = {} # Maps sample_id -> optimizer state
321
+ self.save_dir = save_dir
322
+ os.makedirs(save_dir, exist_ok=True)
323
+
324
+ def get_or_create_inr(self, sample_id, create_fn, device):
325
+ """Get existing INR or create new one if not exists."""
326
+ if sample_id not in self.inrs:
327
+ # Create new INR
328
+ inr = create_fn().to(device)
329
+ optimizer = torch.optim.Adam(inr.parameters())
330
+ self.inrs[sample_id] = inr
331
+ self.optimizers[sample_id] = optimizer
332
+ return self.inrs[sample_id], self.optimizers[sample_id]
333
+
334
+ def set_inr(self, sample_id, inr, optimizer):
335
+ self.inrs[sample_id] = inr
336
+ self.optimizers[sample_id] = optimizer
337
+
338
+ def save_state(self):
339
+ """Save all INRs and optimizer states to disk."""
340
+ state = {
341
+ 'inrs': {
342
+ sample_id: inr.state_dict()
343
+ for sample_id, inr in self.inrs.items()
344
+ },
345
+ 'optimizers': {
346
+ sample_id: opt.state_dict()
347
+ for sample_id, opt in self.optimizers.items()
348
+ }
349
+ }
350
+ torch.save(state, os.path.join(self.save_dir, 'inr_database.pt'))
351
+
352
+ def load_state(self, create_fn, device):
353
+ """Load INRs and optimizer states from disk."""
354
+ path = os.path.join(self.save_dir, 'inr_database.pt')
355
+ if os.path.exists(path):
356
+ state = torch.load(path, map_location=device)
357
+
358
+ # Restore INRs
359
+ for sample_id, inr_state in state['inrs'].items():
360
+ inr = create_fn().to(device)
361
+ inr.load_state_dict(inr_state)
362
+ self.inrs[sample_id] = inr
363
+
364
+ # Restore optimizers
365
+ for sample_id, opt_state in state['optimizers'].items():
366
+ optimizer = torch.optim.Adam(self.inrs[sample_id].parameters())
367
+ optimizer.load_state_dict(opt_state)
368
+ self.optimizers[sample_id] = optimizer
369
+
370
+
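A hedged usage sketch of INRDatabase: fetch (or lazily create) the INR and Adam optimizer for a given sample id, then persist and restore the whole collection. make_inr is a hypothetical factory; any nn.Module constructor works.

    import torch
    import torch.nn as nn

    def make_inr():
        return nn.Sequential(nn.Linear(1, 64), nn.SiLU(), nn.Linear(64, 1))  # stand-in INR

    db = INRDatabase(save_dir='./inr_database')
    inr, opt = db.get_or_create_inr(sample_id=1234, create_fn=make_inr, device='cpu')
    db.save_state()                                   # writes inr_database.pt under save_dir
    db.load_state(create_fn=make_inr, device='cpu')   # restores INRs and optimizer states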
371
+ class INRTrainer(Trainer):
372
+ def __init__(self, hidden_features=128, n_layers=3, in_features=1, out_features=1,
373
+ num_steps=5000, lr=1e-3, inr_criterion=torch.nn.MSELoss(), save_dir='./inr_database', *args, **kwargs):
374
+ super().__init__(*args, **kwargs)
375
+ self.hidden_features = hidden_features
376
+ self.n_layers = n_layers
377
+ self.in_features = in_features
378
+ self.out_features = out_features
379
+ self.num_steps = num_steps
380
+ self.lr = lr
381
+ self.inr_criterion = inr_criterion
382
+
383
+ # Initialize INR database
384
+ self.db = INRDatabase(save_dir)
385
+
386
+ # Load existing INRs if available
387
+ self.db.load_state(self.create_inr, self.device)
388
+
389
+ def create_inr(self):
390
+ """Factory function to create new INR instances."""
391
+ return INR(
392
+ hidden_features=self.hidden_features,
393
+ n_layers=self.n_layers,
394
+ in_features=self.in_features,
395
+ out_features=self.out_features
396
+ )
397
+
398
+ def create_kan(self):
399
+ return FasterKAN(layers_hidden=[self.in_features] + [self.hidden_features] * (self.n_layers) + [self.out_features],)
400
+
401
+ def get_sample_id(self, batch, idx):
402
+ """Extract unique identifier for a sample in the batch.
403
+ Override this method based on your data structure."""
404
+ # Example: if your batch contains unique IDs
405
+ if 'id' in batch:
406
+ return batch['id'][idx]
407
+ # Fallback: create hash from the sample data
408
+ sample_data = batch['audio']['array'][idx]
409
+ return hash(sample_data.cpu().numpy().tobytes())
410
+
411
+ def train_inr(self, optimizer, model, coords, values, num_iters=10, plot=False):
412
+ # pbar = tqdm(range(num_iters))
413
+ for _ in range(num_iters):
414
+ optimizer.zero_grad()
415
+ pred_values = model(coords.to(self.device)).float()
416
+ loss = self.inr_criterion(pred_values.squeeze(), values)
417
+ loss.backward()
418
+ optimizer.step()
419
+ # pbar.set_description(f'loss: {loss.item()}')
420
+ if plot:
421
+ plt.plot(values.cpu().detach().numpy())
422
+ plt.plot(pred_values.cpu().detach().numpy())
423
+ plt.title(loss.item())
424
+ plt.show()
425
+ return model, optimizer
426
+
427
+ def train_batch(self, batch, batch_idx, device):
428
+ """Train INRs for each sample in batch, persisting progress."""
429
+ coords = batch['audio']['coords'].to(device) # [B, N, 1]
430
+ fft = batch['audio']['fft_mag'].to(device) # [B, N]
431
+ audio = batch['audio']['array'].to(device) # [B, N]
432
+ y = batch['label'].to(device).float()
433
+
434
+ batch_size = coords.shape[0]
435
+
436
+ values = audio
437
+
438
+ batch_losses = []
439
+ batch_optimizers = []
440
+ batch_inrs = []
441
+ batch_weights = tuple()
442
+ batch_biases = tuple()
443
+ # Training loop
444
+ # pbar = tqdm(range(self.num_steps), desc="Training INRs")
445
+ plot = batch_idx == 0
446
+ for i in range(batch_size):
447
+ sample_id = self.get_sample_id(batch, i)
448
+ inr, optimizer = self.db.get_or_create_inr(sample_id, self.create_inr, device)
449
+ inr, optimizer = self.train_inr(optimizer, inr, coords[i], values[i])
450
+ self.db.set_inr(sample_id, inr, optimizer)
451
+ # pred_values = inr(coords[i]).squeeze()
452
+ # batch_losses.append(self.inr_criterion(pred_values, values[i]))
453
+ # batch_optimizers.append(optimizer)
454
+ state_dict = inr.state_dict()
455
+ weights = tuple(
456
+ [v.permute(1, 0).unsqueeze(-1).unsqueeze(0).to(device) for w, v in state_dict.items() if "weight" in w]
457
+ )
458
+ biases = tuple([v.unsqueeze(-1).unsqueeze(0).to(device) for w, v in state_dict.items() if "bias" in w])
459
+ if not len(batch_weights):
460
+ batch_weights = weights
461
+ else:
462
+ batch_weights = tuple(
463
+ [torch.cat((weights[i], batch_weights[i]), dim=0) for i in range(len(weights))]
464
+ )
465
+ if not len(batch_biases):
466
+ batch_biases = biases
467
+ else:
468
+ batch_biases = tuple(
469
+ [torch.cat((biases[i], batch_biases[i]), dim=0) for i in range(len(biases))]
470
+ )
471
+ # loss_preds = torch.tensor([0])
472
+ # acc = 0
473
+ y_pred = self.model(inputs=(batch_weights, batch_biases)).squeeze()
474
+ loss_preds = self.criterion(y_pred, y)
475
+ self.optimizer.zero_grad()
476
+ loss_preds.backward()
477
+ self.optimizer.step()
478
+ # for i in range(batch_size):
479
+ # batch_optimizers[i].zero_grad()
480
+ # batch_losses[i] += loss_preds
481
+ # batch_losses[i].backward()
482
+ # batch_optimizers[i].step()
483
+
484
+
485
+ if batch_idx % 10 == 0: # Adjust frequency as needed
486
+ self.db.save_state()
487
+
488
+ probs = torch.sigmoid(y_pred)
489
+ cls_pred = (probs > 0.5).float()
490
+ acc = (cls_pred == y).sum()
491
+
492
+
493
+ return loss_preds, acc, y
494
+
495
+ def eval_batch(self, batch, batch_idx, device):
496
+ """Evaluate INRs for each sample in batch."""
497
+ coords = batch['audio']['coords'].to(device)
498
+ fft = batch['audio']['fft_mag'].to(device)
499
+ audio = batch['audio']['array'].to(device)
500
+
501
+ batch_size = coords.shape[0]
502
+ # values = torch.cat((
503
+ # audio.unsqueeze(-1),
504
+ # fft.unsqueeze(-1)
505
+ # ), dim=-1)
506
+ values = audio
507
+ # Get INRs for each sample
508
+ batch_inrs = []
509
+ for i in range(batch_size):
510
+ sample_id = self.get_sample_id(batch, i)
511
+ inr, _ = self.db.get_or_create_inr(sample_id, self.create_inr, device)
512
+ batch_inrs.append(inr)
513
+
514
+ # Evaluate
515
+ with torch.no_grad():
516
+ all_preds = torch.stack([
517
+ inr(coords[i])
518
+ for i, inr in enumerate(batch_inrs)
519
+ ])
520
+
521
+ batch_losses = torch.stack([
522
+ self.criterion(all_preds[i].squeeze(), values[i])
523
+ for i in range(batch_size)
524
+ ])
525
+
526
+ avg_loss = batch_losses.mean().item()
527
+
528
+ acc = torch.zeros(self.output_dim, device=device)
529
+ y = values
530
+
531
+ return torch.tensor(avg_loss), acc, y
532
+
533
+
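The per-sample fitting performed by INRTrainer.train_inr reduces to a few Adam steps that regress signal values from their coordinates. A self-contained sketch with a stand-in MLP and a synthetic sine signal (all values assumed):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(1, 64), nn.SiLU(), nn.Linear(64, 1))   # stand-in INR
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()
    coords = torch.linspace(0, 1, 1000).unsqueeze(-1)                      # [N, 1]
    values = torch.sin(2 * torch.pi * 5 * coords.squeeze())                # [N]

    for _ in range(10):
        optimizer.zero_grad()
        loss = criterion(model(coords).squeeze(), values)
        loss.backward()
        optimizer.step()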
534
+ def verify_parallel_gradient_isolation(trainer, batch_size=4, sequence_length=1000):
535
+ """
536
+ Verify that gradients remain isolated in parallel training.
537
+ """
538
+ device = trainer.device
539
+
540
+ # Create test data
541
+ coords = torch.linspace(0, 1, sequence_length).unsqueeze(-1) # [N, 1]
542
+ coords = coords.unsqueeze(0).repeat(batch_size, 1, 1) # [B, N, 1]
543
+
544
+ # Create synthetic signals
545
+ targets = torch.stack([
546
+ torch.sin(2 * torch.pi * (i + 1) * coords.squeeze(-1))
547
+ for i in range(batch_size)
548
+ ]).to(device)
549
+
550
+ # Create batch of INRs
551
+ inrs = trainer.create_batch_inrs()
552
+
553
+ # Store initial parameters
554
+ initial_params = [{name: param.clone().detach()
555
+ for name, param in inr.named_parameters()}
556
+ for inr in inrs]
557
+
558
+ # Create mock batch
559
+ batch = {
560
+ 'audio': {
561
+ 'coords': coords.to(device),
562
+ 'fft_mag': targets.clone(),
563
+ 'array': targets.clone()
564
+ }
565
+ }
566
+
567
+ # Run one training step
568
+ trainer.train_batch(batch, 0, device)
569
+
570
+ # Verify parameter changes
571
+ isolation_verified = True
572
+ for i, inr in enumerate(inrs):
573
+ params_changed = False
574
+ for name, param in inr.named_parameters():
575
+ if not torch.allclose(param, initial_params[i][name]):
576
+ params_changed = True
577
+ # Verify that the changes are unique to this INR
578
+ for j, other_inr in enumerate(inrs):
579
+ if i != j:
580
+ other_param = dict(other_inr.named_parameters())[name]
581
+ if not torch.allclose(other_param, initial_params[j][name]):
582
+ isolation_verified = False
583
+ print(f"Warning: Parameter {name} of INR {j} changed when only INR {i} should have changed")
584
+
585
+ return isolation_verified
586
+
587
+ class XGBoostTrainer():
588
+ def __init__(self, model_args, train_ds, val_ds, test_ds):
589
+ self.train_ds = train_ds
590
+ self.test_ds = test_ds
591
+ print("creating train dataframe...")
592
+ self.x_train, self.y_train = self.create_dataframe(train_ds, save_name='train')
593
+ print("creating validation dataframe...")
594
+ self.x_val, self.y_val = self.create_dataframe(val_ds, save_name='val')
595
+ print("creating test dataframe...")
596
+ self.x_test, self.y_test = self.create_dataframe(test_ds, save_name='test')
597
+
598
+ # Convert the data to DMatrix format
599
+ self.dtrain = xgb.DMatrix(self.x_train, label=self.y_train)
600
+ self.dval = xgb.DMatrix(self.x_val, label=self.y_val)
601
+ self.dtest = xgb.DMatrix(self.x_test, label=self.y_test)
602
+
603
+ # Model initialization
604
+ self.model_args = model_args
605
+ self.model = xgb.XGBClassifier(**model_args)
606
+
607
+ def create_dataframe(self, ds, save_name='train'):
608
+ try:
609
+ df = pd.read_csv(f"tasks/utils/dfs/{save_name}.csv")
610
+ except FileNotFoundError:
611
+ data = []
612
+
613
+ # Iterate over the dataset
614
+ pbar = tqdm(enumerate(ds))
615
+ for i, batch in pbar:
616
+ label = batch['label']
617
+ features = batch['audio']['features']
618
+
619
+ # Flatten the nested dictionary structure
620
+ feature_dict = {'label': label}
621
+ for k, v in features.items():
622
+ if isinstance(v, dict):
623
+ for sub_k, sub_v in v.items():
624
+ feature_dict[f"{k}_{sub_k}"] = sub_v[0].item() # Aggregate (e.g., mean)
625
+ data.append(feature_dict)
626
+ # Convert to DataFrame
627
+ df = pd.DataFrame(data)
628
+ print(os.getcwd())
629
+ df.to_csv(f"tasks/utils/dfs/{save_name}.csv", index=False)
630
+ X = df.drop(columns=['label'])
631
+ y = df['label']
632
+ return X, y
633
+
634
+ def fit(self):
635
+ # Train using the `train` method with early stopping
636
+ params = {
637
+ 'objective': 'binary:logistic',
638
+ 'eval_metric': 'logloss',
639
+ **self.model_args
640
+ }
641
+
642
+ evals_result = {}
643
+ num_boost_round = 1000 # Set a large number of boosting rounds
644
+
645
+ # Watchlist to monitor performance on train and validation data
646
+ watchlist = [(self.dtrain, 'train'), (self.dval, 'eval')]
647
+
648
+ # Train the model
649
+ self.model = xgb.train(
650
+ params,
651
+ self.dtrain,
652
+ num_boost_round=num_boost_round,
653
+ evals=watchlist,
654
+ early_stopping_rounds=10, # Early stopping after 10 rounds with no improvement
655
+ evals_result=evals_result,
656
+ verbose_eval=True # Show evaluation results for each iteration
657
+ )
658
+
659
+ return evals_result
660
+
661
+ def train_xgboost_in_batches(self, dataloader, eval_metric="logloss"):
662
+ evals_result = {}
663
+ for i, batch in enumerate(dataloader):
664
+ # Convert batch data to NumPy arrays
665
+ X_batch = torch.cat([batch[key].view(batch[key].size(0), -1) for key in batch if key != "label"],
666
+ dim=1).numpy()
667
+ y_batch = batch["label"].numpy()
668
+
669
+ # Create DMatrix for XGBoost
670
+ dtrain = xgb.DMatrix(X_batch, label=y_batch)
671
+
672
+ # Use `train` with each batch
673
+ self.model = xgb.train(
674
+ {'objective': 'binary:logistic', 'eval_metric': eval_metric, **self.model_args},
675
+ dtrain,
676
+ num_boost_round=1000, # Use a large number of rounds
677
+ evals=[(self.dval, 'eval')],
678
+ # eval_metric is supplied via the params dict above
679
+ early_stopping_rounds=10,
680
+ evals_result=evals_result,
681
+ verbose_eval=False # Avoid printing every iteration
682
+ )
683
+
684
+ # Optionally print progress
685
+ if i % 10 == 0:
686
+ print(f"Batch {i + 1}/{len(dataloader)} processed.")
687
+
688
+ return evals_result
689
+
690
+ def predict(self):
691
+ # Predict probabilities for class 1
692
+ y_prob = self.model.predict(self.dtest, output_margin=False)
693
+
694
+ # Convert probabilities to binary labels (0 or 1) using a threshold (e.g., 0.5)
695
+ y_pred = (y_prob >= 0.5).astype(int)
696
+
697
+ # Evaluate performance
698
+ accuracy = accuracy_score(self.y_test, y_pred)
699
+ roc_auc = roc_auc_score(self.y_test, y_prob)
700
+
701
+ print(f'Accuracy: {accuracy:.4f}')
702
+ print(f'ROC AUC Score: {roc_auc:.4f}')
703
+ print(classification_report(self.y_test, y_pred))
704
+
705
+ def plot_results(self, evals_result):
706
+ train_logloss = evals_result["train"]["logloss"]
707
+ val_logloss = evals_result["eval"]["logloss"]
708
+ iterations = range(1, len(train_logloss) + 1)
709
+
710
+ # Plot
711
+ plt.figure(figsize=(8, 5))
712
+ plt.plot(iterations, train_logloss, label="Train LogLoss", color="blue")
713
+ plt.plot(iterations, val_logloss, label="Validation LogLoss", color="red")
714
+ plt.xlabel("Boosting Round (Iteration)")
715
+ plt.ylabel("Log Loss")
716
+ plt.title("Training and Validation Log Loss over Iterations")
717
+ plt.legend()
718
+ plt.grid()
719
+ plt.show()
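A hedged, self-contained sketch of the xgb.train call used by XGBoostTrainer.fit, run on synthetic tabular features; the real pipeline builds its DMatrix objects from the feature dicts produced by create_dataframe.

    import numpy as np
    import xgboost as xgb
    from sklearn.metrics import roc_auc_score

    X = np.random.randn(200, 6)                         # synthetic features
    y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)       # synthetic binary labels
    dtrain = xgb.DMatrix(X[:150], label=y[:150])
    dval = xgb.DMatrix(X[150:], label=y[150:])

    booster = xgb.train(
        {'objective': 'binary:logistic', 'eval_metric': 'logloss', 'max_depth': 3},
        dtrain,
        num_boost_round=200,
        evals=[(dtrain, 'train'), (dval, 'eval')],
        early_stopping_rounds=10,
        verbose_eval=False,
    )
    print(roc_auc_score(y[150:], booster.predict(dval)))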
tasks/utils/transforms.py ADDED
@@ -0,0 +1,272 @@
1
+ import numpy as np
2
+ import librosa
3
+ import torch
4
+ import torch.nn as nn
5
+ import pywt
6
+ from scipy import signal
7
+
8
+
9
+
10
+ def compute_cwt_power_spectrum(audio, sample_rate, num_freqs=128, f_min=20, f_max=None):
11
+ """
12
+ Compute the power spectrum of continuous wavelet transform using Morlet wavelet.
13
+
14
+ Parameters:
15
+ audio: torch.Tensor
16
+ Input audio signal
17
+ sample_rate: int
18
+ Sampling rate of the audio
19
+ num_freqs: int
20
+ Number of frequency bins for the CWT
21
+ f_min: float
22
+ Minimum frequency to analyze
23
+ f_max: float or None
24
+ Maximum frequency to analyze (defaults to Nyquist frequency)
25
+
26
+ Returns:
27
+ torch.Tensor: CWT power spectrum
28
+ """
29
+ # Convert to numpy
30
+ audio_np = audio.cpu().numpy()
31
+
32
+ # Set default f_max to Nyquist frequency if not specified
33
+ if f_max is None:
34
+ f_max = sample_rate // 2
35
+
36
+ # Generate frequency bins (logarithmically spaced)
37
+ frequencies = np.logspace(
38
+ np.log10(f_min),
39
+ np.log10(f_max),
40
+ num=num_freqs
41
+ )
42
+
43
+ # Compute the width of the wavelet (in samples)
44
+ widths = sample_rate / (2 * frequencies * np.pi)
45
+
46
+ # Compute CWT using Morlet wavelet
47
+ cwt = signal.cwt(
48
+ audio_np,
49
+ signal.morlet2,
50
+ widths,
51
+ w=5.0 # Width parameter of Morlet wavelet
52
+ )
53
+
54
+ # Compute power spectrum (magnitude squared)
55
+ power_spectrum = np.abs(cwt) ** 2
56
+
57
+ # Convert to torch tensor
58
+ power_spectrum_tensor = torch.FloatTensor(power_spectrum)
59
+
60
+ return power_spectrum_tensor
61
+
62
+ def compute_wavelet_transform(audio, wavelet, decompos_level):
63
+ """Compute wavelet decomposition of the audio signal."""
64
+ # Convert to numpy and ensure 1D
65
+ audio_np = audio.cpu().numpy()
66
+
67
+ # Perform wavelet decomposition
68
+ coeffs = pywt.wavedec(audio_np, wavelet, level=decompos_level)
69
+
70
+ # Stack coefficients into a 2D array
71
+ # First, pad all coefficient arrays to the same length
72
+ max_len = max(len(c) for c in coeffs)
73
+ padded_coeffs = []
74
+ for coeff in coeffs:
75
+ pad_len = max_len - len(coeff)
76
+ if pad_len > 0:
77
+ padded_coeff = np.pad(coeff, (0, pad_len), mode='constant')
78
+ else:
79
+ padded_coeff = coeff
80
+ padded_coeffs.append(padded_coeff)
81
+
82
+ # Stack into 2D array where each row is a different scale
83
+ wavelet_features = np.stack(padded_coeffs)
84
+
85
+ # Convert to tensor
86
+ return torch.FloatTensor(wavelet_features)
87
+
88
+
89
+ def compute_melspectrogram(audio, sample_rate):
90
+ mel_spec = librosa.feature.melspectrogram(
91
+ y=audio.cpu().numpy(),
92
+ sr=sample_rate,
93
+ n_mels=128
94
+ )
95
+ return torch.FloatTensor(librosa.power_to_db(mel_spec))
96
+
97
+
98
+ def compute_mfcc(audio, sample_rate):
99
+ mfcc = librosa.feature.mfcc(
100
+ y=audio.cpu().numpy(),
101
+ sr=sample_rate,
102
+ n_mfcc=20
103
+ )
104
+ return torch.FloatTensor(mfcc)
105
+
106
+
107
+ def compute_chroma(audio, sample_rate):
108
+ chroma = librosa.feature.chroma_stft(
109
+ y=audio.cpu().numpy(),
110
+ sr=sample_rate
111
+ )
112
+ return torch.FloatTensor(chroma)
113
+
114
+
115
+ def compute_time_domain_features(audio, sample_rate, frame_length=2048, hop_length=128):
116
+ """
117
+ Compute time-domain features from audio signal.
118
+ Returns a dictionary of features.
119
+ """
120
+ # Convert to numpy
121
+ audio_np = audio.cpu().numpy()
122
+
123
+ # Initialize dictionary for features
124
+ features = {}
125
+
126
+ # 1. Zero Crossing Rate
127
+ zcr = librosa.feature.zero_crossing_rate(
128
+ y=audio_np,
129
+ frame_length=frame_length,
130
+ hop_length=hop_length
131
+ )
132
+ features['zcr'] = torch.Tensor([zcr.sum()])
133
+
134
+ # 2. Root Mean Square Energy
135
+ rms = librosa.feature.rms(
136
+ y=audio_np,
137
+ frame_length=frame_length,
138
+ hop_length=hop_length
139
+ )
140
+ features['rms_energy'] = torch.Tensor([rms.mean()])
141
+
142
+ # 3. Temporal Statistics
143
+ frames = librosa.util.frame(audio_np, frame_length=frame_length, hop_length=hop_length)
144
+ features['mean'] = torch.Tensor([np.mean(frames, axis=0).mean()])
145
+ features['std'] = torch.Tensor([np.std(frames, axis=0).mean()])
146
+ features['max'] = torch.Tensor([np.max(frames, axis=0).mean()])
147
+
148
+ # 4. Tempo and Beat Features
149
+ onset_env = librosa.onset.onset_strength(y=audio_np, sr=sample_rate)
150
+ tempo = librosa.beat.tempo(onset_envelope=onset_env, sr=sample_rate)
151
+ features['tempo'] = torch.Tensor(tempo)
152
+
153
+ # 5. Amplitude Envelope
154
+ envelope = np.abs(librosa.stft(audio_np, n_fft=frame_length, hop_length=hop_length))
155
+ features['envelope'] = torch.Tensor([np.mean(envelope, axis=0).mean()])
156
+
157
+ return features
158
+
159
+
160
+ def compute_frequency_domain_features(audio, sample_rate, n_fft=2048, hop_length=512):
161
+ """
162
+ Compute frequency-domain features from audio signal.
163
+ Returns a dictionary of features.
164
+ """
165
+ # Convert to numpy
166
+ audio_np = audio.cpu().numpy()
167
+
168
+ # Initialize dictionary for features
169
+ features = {}
170
+
171
+ # 1. Spectral Centroid
172
+ try:
173
+ spectral_centroids = librosa.feature.spectral_centroid(
174
+ y=audio_np,
175
+ sr=sample_rate,
176
+ n_fft=n_fft,
177
+ hop_length=hop_length,
178
+
179
+ )
180
+ features['spectral_centroid'] = torch.FloatTensor([spectral_centroids.max()])
181
+ except Exception as e:
182
+ features['spectral_centroid'] = torch.FloatTensor([np.nan])
183
+
184
+ # 2. Spectral Rolloff
185
+ try:
186
+ spectral_rolloff = librosa.feature.spectral_rolloff(
187
+ y=audio_np,
188
+ sr=sample_rate,
189
+ n_fft=n_fft,
190
+ hop_length=hop_length,
191
+
192
+ )
193
+ features['spectral_rolloff'] = torch.FloatTensor([spectral_rolloff.max()])
194
+ except Exception as e:
195
+ features['spectral_rolloff'] = torch.FloatTensor([np.nan])
196
+
197
+ # 3. Spectral Bandwidth
198
+ try:
199
+ spectral_bandwidth = librosa.feature.spectral_bandwidth(
200
+ y=audio_np,
201
+ sr=sample_rate,
202
+ n_fft=n_fft,
203
+ hop_length=hop_length
204
+ )
205
+ features['spectral_bandwidth'] = torch.FloatTensor([spectral_bandwidth.max()])
206
+ except Exception as e:
207
+ features['spectral_bandwidth'] = torch.FloatTensor([np.nan])
208
+ # 4. Spectral Contrast
209
+ try:
210
+ spectral_contrast = librosa.feature.spectral_contrast(
211
+ y=audio_np,
212
+ sr=sample_rate,
213
+ n_fft=n_fft,
214
+ hop_length=hop_length,
215
+ fmin=20, # Lower minimum frequency
216
+ n_bands=4, # Reduce number of bands
217
+ quantile=0.02
218
+ )
219
+ features['spectral_contrast'] = torch.FloatTensor([spectral_contrast.mean()])
220
+ except Exception as e:
221
+ features['spectral_contrast'] = torch.FloatTensor([np.nan])
222
+
223
+ # 5. Spectral Flatness
224
+ try:
225
+ spectral_flatness = librosa.feature.spectral_flatness(
226
+ y=audio_np,
227
+ n_fft=n_fft,
228
+ hop_length=hop_length
229
+ )
230
+ features['spectral_flatness'] = torch.FloatTensor([spectral_flatness.max()])
231
+ except Exception as e:
232
+ features['spectral_flatness'] = torch.FloatTensor([np.nan])
233
+
234
+ # 6. Spectral Flux
235
+ try:
236
+ stft = np.abs(librosa.stft(audio_np, n_fft=n_fft, hop_length=hop_length))
237
+ spectral_flux = np.diff(stft, axis=1)
238
+ spectral_flux = np.pad(spectral_flux, ((0, 0), (1, 0)), mode='constant')
239
+ features['spectral_flux'] = torch.FloatTensor([np.std(spectral_flux)])
240
+ except Exception as e:
241
+ features['spectral_flux'] = torch.FloatTensor([np.nan])
242
+
243
+ return features
244
+
245
+
246
+ def compute_all_features(audio, sample_rate, wavelet='db1', decompos_level=4):
247
+ """
248
+ Compute all available features and return them in a dictionary.
249
+ """
250
+ features = {}
251
+
252
+ # Basic transformations
253
+ # features['wavelet'] = compute_wavelet_transform(audio, wavelet, decompos_level)
254
+ # features['melspectrogram'] = compute_melspectrogram(audio, sample_rate)
255
+ # features['mfcc'] = compute_mfcc(audio, sample_rate)
256
+ # features['chroma'] = compute_chroma(audio, sample_rate)
257
+
258
+ # features['cwt_power'] = compute_cwt_power_spectrum(
259
+ # audio,
260
+ # sample_rate,
261
+ # num_freqs=128, # Same as mel bands for consistency
262
+ # f_min=20, # Standard lower frequency bound
263
+ # f_max=sample_rate // 2 # Nyquist frequency
264
+ # )
265
+
266
+ # Time domain features
267
+ # features['time_domain'] = compute_time_domain_features(audio, sample_rate)
268
+
269
+ # Frequency domain features
270
+ features['frequency_domain'] = compute_frequency_domain_features(audio, sample_rate)
271
+
272
+ return features
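A hedged usage sketch of the feature extractors above on one second of synthetic audio (librosa and torch assumed installed; the 440 Hz tone is purely illustrative):

    import torch

    sample_rate = 16000
    t = torch.linspace(0, 1, sample_rate)
    audio = torch.sin(2 * torch.pi * 440 * t)
    feats = compute_frequency_domain_features(audio, sample_rate)
    print({k: round(v.item(), 2) for k, v in feats.items()})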
test ADDED
Binary file (70 kB).