IlayMalinyak committed
Commit 72a8e1c
Parent: 2f54ec8

remove pwvt

Files changed (3)
  1. tasks/run_inr.py +29 -29
  2. tasks/test +0 -0
  3. tasks/utils/transforms.py +26 -26
tasks/run_inr.py CHANGED
@@ -58,11 +58,11 @@ local_rank = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 login(api_key)
 dataset = load_dataset("rfcx/frugalai", streaming=True)
 
-train_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=True)
+train_ds = SplitDataset(AudioINRDataset(FFTDataset(dataset["train"])), is_train=True)
 
 train_dl = DataLoader(train_ds, batch_size=data_args.batch_size)
 
-val_ds = SplitDataset(FFTDataset(dataset["train"]), is_train=False)
+val_ds = SplitDataset(AudioINRDataset(FFTDataset(dataset["train"])), is_train=False)
 
 val_dl = DataLoader(val_ds, batch_size=data_args.batch_size)
 
@@ -110,34 +110,34 @@ test_dl = DataLoader(test_ds, batch_size=data_args.batch_size)
 loss_fn = torch.nn.BCEWithLogitsLoss()
 inr_criterion = torch.nn.MSELoss()
 
-# for i, batch in enumerate(train_ds):
-#     coords, fft, audio = batch['audio']['coords'], batch['audio']['fft_mag'], batch['audio']['array']
-#     coords = coords.to(local_rank)
-#     fft = fft.to(local_rank)
-#     audio = audio.to(local_rank)
-#     values = torch.cat((audio.unsqueeze(-1), fft.unsqueeze(-1)), dim=-1)
-#     # model = INR(hidden_features=128, n_layers=3,
-#     #             in_features=1,
-#     #             out_features=1).to(local_rank)
-#     model = FasterKAN(**kan_args.get_dict()).to(local_rank)
-#     optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
-#     pbar = tqdm(range(200))
-#     losses = []
-#     print(coords.shape)
-#     for t in pbar:
-#         optimizer.zero_grad()
-#         pred_values = model(coords.to(local_rank)).float()
-#         loss = inr_criterion(pred_values, values)
-#         loss.backward()
-#         optimizer.step()
-#         pbar.set_description(f'loss: {loss.item()}')
-#         losses.append(loss.item())
-#     state_dict = model.state_dict()
-#     torch.save(state_dict, 'test')
-#     # print(f'Sample {i+offset} label {label} saved in {inr_path}')
-#     plot_results(1, i, fft, losses, pred_values)
+for i, batch in enumerate(train_ds):
+    coords, fft, audio = batch['audio']['coords'], batch['audio']['fft_mag'], batch['audio']['array']
+    coords = coords.to(local_rank)
+    fft = fft.to(local_rank)
+    audio = audio.to(local_rank)
+    # values = torch.cat((audio.unsqueeze(-1), fft.unsqueeze(-1)), dim=-1)
+    model = INR(hidden_features=128, n_layers=4,
+                in_features=1,
+                out_features=1).to(local_rank)
+    # model = FasterKAN(layers_hidden=[1,16,16,1]).to(local_rank)
+    optimizer = torch.optim.Adam([{'params': model.parameters()}], lr=1e-3)
+    pbar = tqdm(range(200))
+    losses = []
+    print(coords.shape)
+    for t in pbar:
+        optimizer.zero_grad()
+        pred_values = model(coords.to(local_rank)).float()
+        loss = inr_criterion(pred_values, fft)
+        loss.backward()
+        optimizer.step()
+        pbar.set_description(f'loss: {loss.item()}')
+        losses.append(loss.item())
+    state_dict = model.state_dict()
+    torch.save(state_dict, 'test')
+    # print(f'Sample {i+offset} label {label} saved in {inr_path}')
+    plot_results(1, i, fft, losses, pred_values)
 # #
-# exit()
+exit()
 
 
 # missing, unexpected = model.load_state_dict(torch.load(model_args.checkpoint_path))
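
Note on the new loop: it overfits one small implicit neural representation (INR) per audio sample, regressing the FFT magnitude from 1-D coordinates with MSE. The INR class itself is defined elsewhere in the repo and does not appear in this commit; the stand-in below is only a minimal sketch matching the constructor call above, assuming a SIREN-style coordinate MLP with sine activations. The real implementation may differ.

# Hypothetical stand-in for the INR used above; not part of this commit.
import torch
import torch.nn as nn

class INR(nn.Module):
    def __init__(self, hidden_features=128, n_layers=4,
                 in_features=1, out_features=1, omega_0=30.0):
        super().__init__()
        dims = [in_features] + [hidden_features] * n_layers
        # n_layers hidden linear layers, each followed by a sine activation
        self.hidden = nn.ModuleList(
            nn.Linear(dims[i], dims[i + 1]) for i in range(n_layers))
        self.out = nn.Linear(hidden_features, out_features)
        self.omega_0 = omega_0  # frequency scale, as in the SIREN paper

    def forward(self, coords):
        x = coords
        for layer in self.hidden:
            x = torch.sin(self.omega_0 * layer(x))  # periodic activation
        return self.out(x)

With in_features=1, coords would be passed with a trailing feature dimension, e.g. shape (n_points, 1) and typically normalized to [-1, 1]; the MSE target (fft above) must then be broadcast-compatible with the (n_points, 1) output.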
tasks/test ADDED
Binary file (136 kB).
 
tasks/utils/transforms.py CHANGED
@@ -2,7 +2,7 @@ import numpy as np
 import librosa
 import torch
 import torch.nn as nn
-import pywt
+# import pywt
 from scipy import signal
 
 
@@ -59,31 +59,31 @@ def compute_cwt_power_spectrum(audio, sample_rate, num_freqs=128, f_min=20, f_ma
 
     return power_spectrum_tensor
 
-def compute_wavelet_transform(audio, wavelet, decompos_level):
-    """Compute wavelet decomposition of the audio signal."""
-    # Convert to numpy and ensure 1D
-    audio_np = audio.cpu().numpy()
-
-    # Perform wavelet decomposition
-    coeffs = pywt.wavedec(audio_np, wavelet, level=decompos_level)
-
-    # Stack coefficients into a 2D array
-    # First, pad all coefficient arrays to the same length
-    max_len = max(len(c) for c in coeffs)
-    padded_coeffs = []
-    for coeff in coeffs:
-        pad_len = max_len - len(coeff)
-        if pad_len > 0:
-            padded_coeff = np.pad(coeff, (0, pad_len), mode='constant')
-        else:
-            padded_coeff = coeff
-        padded_coeffs.append(padded_coeff)
-
-    # Stack into 2D array where each row is a different scale
-    wavelet_features = np.stack(padded_coeffs)
-
-    # Convert to tensor
-    return torch.FloatTensor(wavelet_features)
+# def compute_wavelet_transform(audio, wavelet, decompos_level):
+#     """Compute wavelet decomposition of the audio signal."""
+#     # Convert to numpy and ensure 1D
+#     audio_np = audio.cpu().numpy()
+#
+#     # Perform wavelet decomposition
+#     coeffs = pywt.wavedec(audio_np, wavelet, level=decompos_level)
+#
+#     # Stack coefficients into a 2D array
+#     # First, pad all coefficient arrays to the same length
+#     max_len = max(len(c) for c in coeffs)
+#     padded_coeffs = []
+#     for coeff in coeffs:
+#         pad_len = max_len - len(coeff)
+#         if pad_len > 0:
+#             padded_coeff = np.pad(coeff, (0, pad_len), mode='constant')
+#         else:
+#             padded_coeff = coeff
+#         padded_coeffs.append(padded_coeff)
+#
+#     # Stack into 2D array where each row is a different scale
+#     wavelet_features = np.stack(padded_coeffs)
+#
+#     # Convert to tensor
+#     return torch.FloatTensor(wavelet_features)
 
 
 def compute_melspectrogram(audio, sample_rate):
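
For reference, the helper disabled here padded each coefficient array returned by pywt.wavedec to a common length and stacked them into a single 2-D tensor, one row per scale. Below is a self-contained sketch of that behavior, runnable with PyWavelets installed; the 'db4' wavelet, level=3, and the 8000-sample input are illustrative placeholders, not values taken from this repo.

import numpy as np
import pywt
import torch

def wavedec_stack(audio_np, wavelet='db4', level=3):
    # Multi-level DWT: returns [cA_level, cD_level, ..., cD_1], where each
    # successive detail band is roughly half the length of the previous one.
    coeffs = pywt.wavedec(audio_np, wavelet, level=level)
    # Zero-pad every band to the longest one (the level-1 detail band).
    max_len = max(len(c) for c in coeffs)
    padded = [np.pad(c, (0, max_len - len(c)), mode='constant') for c in coeffs]
    return torch.FloatTensor(np.stack(padded))  # shape: (level + 1, max_len)

features = wavedec_stack(np.random.randn(8000))
print(features.shape)  # torch.Size([4, 4003]) for db4, level=3, 8000 samples

Because the coarser bands are much shorter than the first detail band, most entries in their rows are padding zeros: the rows align in index but not in time-frequency resolution.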