"""Training utilities.

Contains:
  * ``Trainer``       -- generic PyTorch train/eval loop (optionally DDP-aware).
  * ``INRDatabase``   -- persistent store of per-sample INR networks + optimizers.
  * ``INRTrainer``    -- Trainer subclass that fits one INR per audio sample and
                         classifies samples from the INR weights.
  * ``XGBoostTrainer`` -- tabular-feature XGBoost baseline.
"""

import glob
import hashlib
import os
import time
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import xgboost as xgb
import yaml
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
from torch.cuda.amp import autocast
from torch.nn import ModuleList
from tqdm import tqdm

# from inr import INR
# from kan import FasterKAN


class Trainer(object):
    """
    A class that encapsulates the training loop for a PyTorch model.

    The model is assumed to be a binary classifier: ``train_batch`` /
    ``eval_batch`` apply a sigmoid to the model output and threshold at 0.5.
    """

    def __init__(self, model, optimizer, criterion, train_dataloader, device,
                 world_size=1, output_dim=2, scheduler=None, val_dataloader=None,
                 max_iter=np.inf, scaler=None, grad_clip=False, exp_num=None,
                 log_path=None, exp_name=None, plot_every=None, cos_inc=False,
                 range_update=None, accumulation_step=1, wandb_log=False,
                 num_quantiles=1, update_func=lambda x: x):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scaler = scaler
        self.grad_clip = grad_clip
        self.cos_inc = cos_inc
        self.output_dim = output_dim
        self.scheduler = scheduler
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.train_sampler = self.get_sampler_from_dataloader(train_dataloader)
        self.val_sampler = self.get_sampler_from_dataloader(val_dataloader)
        self.max_iter = max_iter
        self.device = device
        self.world_size = world_size
        self.exp_num = exp_num
        self.exp_name = exp_name
        self.log_path = log_path
        self.best_state_dict = None
        self.plot_every = plot_every
        self.logger = None
        self.range_update = range_update
        self.accumulation_step = accumulation_step
        self.wandb = wandb_log
        self.num_quantiles = num_quantiles
        self.update_func = update_func
        # if log_path is not None:
        #     self.logger = SummaryWriter(f'{self.log_path}/exp{self.exp_num}')
        #     # print(f"logger path: {self.log_path}/exp{self.exp_num}")
        #     print("logger is: ", self.logger)

    def get_sampler_from_dataloader(self, dataloader):
        """Return the sampler driving ``dataloader`` (DistributedSampler,
        a wrapped inner sampler, or the batch-sampler's sampler), else None."""
        if hasattr(dataloader, 'sampler'):
            if isinstance(dataloader.sampler, torch.utils.data.DistributedSampler):
                return dataloader.sampler
            elif hasattr(dataloader.sampler, 'sampler'):
                return dataloader.sampler.sampler
        if hasattr(dataloader, 'batch_sampler') and hasattr(dataloader.batch_sampler, 'sampler'):
            return dataloader.batch_sampler.sampler
        return None

    def fit(self, num_epochs, device, early_stopping=None, only_p=False, best='loss', conf=False):
        """
        Fits the model for the given number of epochs.

        Args:
            num_epochs: maximum number of epochs to run.
            device: torch device used for per-epoch accuracy accumulators.
            early_stopping: stop after this many epochs without improvement
                (None disables early stopping).
            only_p, conf: accepted for interface compatibility; unused here.
            best: 'loss' to checkpoint on lowest val loss, anything else to
                checkpoint on highest mean val accuracy.

        Returns:
            dict with train/val loss and accuracy histories and the lr trace.
        """
        min_loss = np.inf
        best_acc = 0
        train_loss, val_loss = [], []
        train_acc, val_acc = [], []
        lrs = []
        epochs_without_improvement = 0
        # main_proccess = (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0) or self.device == 'cpu'
        main_proccess = True  # change in a ddp setting
        print(f"Starting training for {num_epochs} epochs")
        print("is main process: ", main_proccess, flush=True)
        global_time = time.time()
        self.epoch = 0
        for epoch in range(num_epochs):
            self.epoch = epoch
            start_time = time.time()

            t_loss, t_acc = self.train_epoch(device, epoch=epoch)
            t_loss_mean = np.nanmean(t_loss)
            train_loss.extend(t_loss)
            global_train_accuracy, global_train_loss = self.process_loss(t_acc, t_loss_mean)
            if main_proccess:  # Only perform this on the master GPU
                train_acc.append(global_train_accuracy.mean().item())

            v_loss, v_acc = self.eval_epoch(device, epoch=epoch)
            v_loss_mean = np.nanmean(v_loss)
            val_loss.extend(v_loss)
            global_val_accuracy, global_val_loss = self.process_loss(v_acc, v_loss_mean)
            if main_proccess:  # Only perform this on the master GPU
                val_acc.append(global_val_accuracy.mean().item())

                # Checkpoint whenever the monitored objective improves.
                current_objective = global_val_loss if best == 'loss' else global_val_accuracy.mean()
                improved = False
                if best == 'loss':
                    if current_objective < min_loss:
                        min_loss = current_objective
                        improved = True
                else:
                    if current_objective > best_acc:
                        best_acc = current_objective
                        improved = True
                if improved:
                    model_name = f'{self.log_path}/{self.exp_num}/{self.exp_name}.pth'
                    print(f"saving model at {model_name}...")
                    torch.save(self.model.state_dict(), model_name)
                    self.best_state_dict = self.model.state_dict()
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1

            current_lr = self.optimizer.param_groups[0]['lr'] if self.scheduler is None \
                else self.scheduler.get_last_lr()[0]
            lrs.append(current_lr)
            print(f'Epoch {epoch}, lr {current_lr}, Train Loss: {global_train_loss:.6f}, Val Loss:'
                  f'{global_val_loss:.6f}, Train Acc: {global_train_accuracy.round(decimals=4).tolist()}, '
                  f'Val Acc: {global_val_accuracy.round(decimals=4).tolist()},'
                  f'Time: {time.time() - start_time:.2f}s, Total Time: {(time.time() - global_time)/3600} hr',
                  flush=True)
            if epoch % 10 == 0:
                # os.system streams nvidia-smi's output; the printed value is
                # just its exit status.
                print(os.system('nvidia-smi'))

            if epochs_without_improvement == early_stopping:
                print('early stopping!', flush=True)
                break
            # Hard wall-clock budget (~23h50m) — presumably a cluster job
            # limit; TODO confirm against the scheduler configuration.
            if time.time() - global_time > (23.83 * 3600):
                print("time limit reached")
                break

        return {"num_epochs": num_epochs, "train_loss": train_loss,
                "val_loss": val_loss, "train_acc": train_acc,
                "val_acc": val_acc, "lrs": lrs}

    def process_loss(self, acc, loss_mean):
        """Average accuracy/loss across DDP ranks (no-op without DDP).

        Returns (accuracy_tensor, loss_tensor); values are only meaningful on
        rank 0 after the reduce.
        """
        if torch.cuda.is_available() and torch.distributed.is_initialized():
            global_accuracy = torch.tensor(acc).cuda()  # Convert accuracy to a tensor on the GPU
            torch.distributed.reduce(global_accuracy, dst=0, op=torch.distributed.ReduceOp.SUM)
            global_loss = torch.tensor(loss_mean).cuda()  # Convert loss to a tensor on the GPU
            torch.distributed.reduce(global_loss, dst=0, op=torch.distributed.ReduceOp.SUM)
            # Divide both loss and accuracy by world size
            world_size = torch.distributed.get_world_size()
            global_loss /= world_size
            global_accuracy /= world_size
        else:
            global_loss = torch.tensor(loss_mean)
            global_accuracy = torch.tensor(acc)
        return global_accuracy, global_loss

    def load_best_model(self, to_ddp=True, from_ddp=True):
        """Load the newest checkpoint from the experiment directory.

        Args:
            to_ddp: if False, map tensors onto ``self.device`` while loading.
            from_ddp: if True, strip any leading 'module.' prefixes that
                DistributedDataParallel adds to parameter names.
        """
        data_dir = f'{self.log_path}/exp{self.exp_num}'
        # data_dir = f'{self.log_path}/exp29'  # for debugging
        state_dict_files = glob.glob(data_dir + '/*.pth')
        if not state_dict_files:
            raise FileNotFoundError(f"no checkpoint (*.pth) found in {data_dir}")
        # FIX: the original printed files[-1] but loaded files[0] in the
        # non-DDP branch; load the same, most recent, checkpoint in both cases.
        ckpt_path = max(state_dict_files, key=os.path.getmtime)
        print("loading model from ", ckpt_path)
        state_dict = torch.load(ckpt_path) if to_ddp \
            else torch.load(ckpt_path, map_location=self.device)
        if from_ddp:
            print("loading distributed model")
            # Remove "module." from keys
            new_state_dict = OrderedDict()
            for key, value in state_dict.items():
                if key.startswith('module.'):
                    while key.startswith('module.'):
                        key = key[7:]
                new_state_dict[key] = value
            state_dict = new_state_dict
        self.model.load_state_dict(state_dict, strict=False)

    def check_gradients(self):
        """Print the name and norm of any parameter gradient larger than 10."""
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                if grad_norm > 10:
                    print(f"Large gradient in {name}: {grad_norm}")

    def train_epoch(self, device, epoch):
        """
        Trains the model for one epoch.

        Returns:
            (list of per-batch losses, per-class correct counts / n_samples)
        """
        if self.train_sampler is not None:
            try:
                # Required so DistributedSampler reshuffles per epoch.
                self.train_sampler.set_epoch(epoch)
            except AttributeError:
                pass
        self.model.train()
        train_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.train_dl)
        for i, batch in enumerate(pbar):
            if self.optimizer is not None:
                self.optimizer.zero_grad()
            loss, acc, y = self.train_batch(batch, i, device)
            train_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"train_acc: {acc}, train_loss: {loss.item()}")
            if i > self.max_iter:
                break
        # FIX: dropped the original "number of train_accs" print — it reported
        # a dead variable that was always 0.
        return train_loss, all_accs / total

    def train_batch(self, batch, batch_idx, device):
        """Run forward/backward/step on one batch; return (loss, correct, y)."""
        x, fft, y = batch['audio']['array'], batch['audio']['fft_mag'], batch['label']
        # features = torch.stack(batch['audio']['features']).to(device).float()
        # cwt = batch['audio']['cwt_mag']
        x = x.to(device).float()
        fft = fft.to(device).float()
        # cwt = cwt.to(device).float()
        y = y.to(device).float()
        # Stack waveform and FFT magnitude as two input channels.
        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
        y_pred = self.model(x_fft).squeeze()
        loss = self.criterion(y_pred, y)
        loss.backward()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        # get predicted classes
        probs = torch.sigmoid(y_pred)
        cls_pred = (probs > 0.5).float()
        acc = (cls_pred == y).sum()
        return loss, acc, y

    def eval_epoch(self, device, epoch):
        """
        Evaluates the model for one epoch.

        Returns:
            (list of per-batch losses, per-class correct counts / n_samples)
        """
        self.model.eval()
        val_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.val_dl)
        for i, batch in enumerate(pbar):
            loss, acc, y = self.eval_batch(batch, i, device)
            val_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"val_acc: {acc}, val_loss: {loss.item()}")
            if i > self.max_iter:
                break
        return val_loss, all_accs / total

    def eval_batch(self, batch, batch_idx, device):
        """Forward pass without gradients; return (loss, correct, y)."""
        x, fft, y = batch['audio']['array'], batch['audio']['fft_mag'], batch['label']
        # features = torch.stack(batch['audio']['features']).to(device).float()
        # features = batch['audio']['features_arr'].to(device).float()
        x = x.to(device).float()
        fft = fft.to(device).float()
        x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
        y = y.to(device).float()
        with torch.no_grad():
            y_pred = self.model(x_fft).squeeze()
        loss = self.criterion(y_pred.squeeze(), y)
        probs = torch.sigmoid(y_pred)
        cls_pred = (probs > 0.5).float()
        acc = (cls_pred == y).sum()
        return loss, acc, y

    def predict(self, test_dataloader, device):
        """
        Returns the predictions of the model on the given dataset.

        Returns:
            (predicted class labels, true labels, correct / total)
        """
        self.model.eval()
        total = 0
        all_accs = 0
        predictions = []
        true_labels = []
        pbar = tqdm(test_dataloader)
        for i, batch in enumerate(pbar):
            x, fft, y = batch['audio']['array'], batch['audio']['fft_mag'], batch['label']
            # features = batch['audio']['features']
            x = x.to(device).float()
            fft = fft.to(device).float()
            x_fft = torch.cat((x.unsqueeze(dim=1), fft.unsqueeze(dim=1)), dim=1)
            y = y.to(device).float()
            with torch.no_grad():
                y_pred = self.model(x_fft).squeeze()
            loss = self.criterion(y_pred, y)
            probs = torch.sigmoid(y_pred)
            cls_pred = (probs > 0.5).float()
            acc = (cls_pred == y).sum()
            predictions.extend(cls_pred.cpu().numpy())
            true_labels.extend(y.cpu().numpy().astype(np.int64))
            all_accs += acc
            total += len(y)
            pbar.set_description("acc: {:.4f}".format(acc))
            if i > self.max_iter:
                break
        return predictions, true_labels, all_accs / total


class INRDatabase:
    """Database to store and manage INRs persistently."""

    def __init__(self, save_dir='./inr_database'):
        self.inrs = {}        # Maps sample_id -> INR
        self.optimizers = {}  # Maps sample_id -> optimizer state
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

    def get_or_create_inr(self, sample_id, create_fn, device):
        """Get existing INR or create new one if not exists."""
        if sample_id not in self.inrs:
            # Create new INR
            inr = create_fn().to(device)
            optimizer = torch.optim.Adam(inr.parameters())
            self.inrs[sample_id] = inr
            self.optimizers[sample_id] = optimizer
        return self.inrs[sample_id], self.optimizers[sample_id]

    def set_inr(self, sample_id, inr, optimizer):
        """Store (or overwrite) the INR and optimizer for ``sample_id``."""
        self.inrs[sample_id] = inr
        self.optimizers[sample_id] = optimizer

    def save_state(self):
        """Save all INRs and optimizer states to disk."""
        state = {
            'inrs': {
                sample_id: inr.state_dict()
                for sample_id, inr in self.inrs.items()
            },
            'optimizers': {
                sample_id: opt.state_dict()
                for sample_id, opt in self.optimizers.items()
            },
        }
        torch.save(state, os.path.join(self.save_dir, 'inr_database.pt'))

    def load_state(self, create_fn, device):
        """Load INRs and optimizer states from disk (no-op if none saved)."""
        path = os.path.join(self.save_dir, 'inr_database.pt')
        if os.path.exists(path):
            state = torch.load(path, map_location=device)
            # Restore INRs
            for sample_id, inr_state in state['inrs'].items():
                inr = create_fn().to(device)
                inr.load_state_dict(inr_state)
                self.inrs[sample_id] = inr
            # Restore optimizers
            for sample_id, opt_state in state['optimizers'].items():
                optimizer = torch.optim.Adam(self.inrs[sample_id].parameters())
                optimizer.load_state_dict(opt_state)
                self.optimizers[sample_id] = optimizer


class INRTrainer(Trainer):
    """Trainer that fits one INR per sample and classifies from INR weights.

    NOTE(review): depends on ``INR`` / ``FasterKAN`` whose imports are
    commented out at the top of the file — ``create_inr`` / ``create_kan``
    raise NameError until those imports are restored.
    """

    def __init__(self, hidden_features=128, n_layers=3, in_features=1,
                 out_features=1, num_steps=5000, lr=1e-3,
                 inr_criterion=torch.nn.MSELoss(), save_dir='./inr_database',
                 *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_features = hidden_features
        self.n_layers = n_layers
        self.in_features = in_features
        self.out_features = out_features
        self.num_steps = num_steps
        self.lr = lr
        self.inr_criterion = inr_criterion
        # Initialize INR database
        self.db = INRDatabase(save_dir)
        # Load existing INRs if available
        self.db.load_state(self.create_inr, self.device)

    def create_inr(self):
        """Factory function to create new INR instances."""
        return INR(
            hidden_features=self.hidden_features,
            n_layers=self.n_layers,
            in_features=self.in_features,
            out_features=self.out_features,
        )

    def create_kan(self):
        """Factory function to create a FasterKAN with matching layer sizes."""
        return FasterKAN(
            layers_hidden=[self.in_features]
            + [self.hidden_features] * (self.n_layers)
            + [self.out_features],
        )

    def get_sample_id(self, batch, idx):
        """Extract unique identifier for a sample in the batch.

        Override this method based on your data structure."""
        # Example: if your batch contains unique IDs
        if 'id' in batch:
            return batch['id'][idx]
        # Fallback: stable content hash of the sample data.
        # FIX: builtin hash() on bytes is salted per process, so IDs would
        # never match the persisted database across runs; md5 is stable.
        sample_data = batch['audio']['array'][idx]
        return hashlib.md5(sample_data.cpu().numpy().tobytes()).hexdigest()

    def train_inr(self, optimizer, model, coords, values, num_iters=10, plot=False):
        """Fit one INR to (coords, values) for ``num_iters`` Adam steps."""
        for _ in range(num_iters):
            optimizer.zero_grad()
            pred_values = model(coords.to(self.device)).float()
            loss = self.inr_criterion(pred_values.squeeze(), values)
            loss.backward()
            optimizer.step()
        if plot:
            plt.plot(values.cpu().detach().numpy())
            plt.plot(pred_values.cpu().detach().numpy())
            plt.title(loss.item())
            plt.show()
        return model, optimizer

    def train_batch(self, batch, batch_idx, device):
        """Train INRs for each sample in batch, persisting progress."""
        coords = batch['audio']['coords'].to(device)  # [B, N, 1]
        fft = batch['audio']['fft_mag'].to(device)    # [B, N] (currently unused)
        audio = batch['audio']['array'].to(device)    # [B, N]
        y = batch['label'].to(device).float()
        batch_size = coords.shape[0]
        values = audio
        batch_weights = tuple()
        batch_biases = tuple()

        for i in range(batch_size):
            sample_id = self.get_sample_id(batch, i)
            inr, optimizer = self.db.get_or_create_inr(sample_id, self.create_inr, device)
            inr, optimizer = self.train_inr(optimizer, inr, coords[i], values[i])
            self.db.set_inr(sample_id, inr, optimizer)
            state_dict = inr.state_dict()
            # Shape each layer's weights/biases as [1, ...] so per-sample
            # tensors can be concatenated along the batch dimension.
            weights = tuple(
                v.permute(1, 0).unsqueeze(-1).unsqueeze(0).to(device)
                for w, v in state_dict.items() if "weight" in w
            )
            biases = tuple(
                v.unsqueeze(-1).unsqueeze(0).to(device)
                for w, v in state_dict.items() if "bias" in w
            )
            if not len(batch_weights):
                batch_weights = weights
            else:
                # FIX: append (accumulated, new) — the original prepended the
                # new sample, reversing batch order relative to the labels y.
                batch_weights = tuple(
                    torch.cat((batch_weights[j], weights[j]), dim=0)
                    for j in range(len(weights))
                )
            if not len(batch_biases):
                batch_biases = biases
            else:
                batch_biases = tuple(
                    torch.cat((batch_biases[j], biases[j]), dim=0)
                    for j in range(len(biases))
                )

        # Classify the batch directly from the INR weight/bias tensors.
        y_pred = self.model(inputs=(batch_weights, batch_biases)).squeeze()
        loss_preds = self.criterion(y_pred, y)
        self.optimizer.zero_grad()
        loss_preds.backward()
        self.optimizer.step()

        if batch_idx % 10 == 0:  # Adjust frequency as needed
            self.db.save_state()

        probs = torch.sigmoid(y_pred)
        cls_pred = (probs > 0.5).float()
        acc = (cls_pred == y).sum()
        return loss_preds, acc, y

    def eval_batch(self, batch, batch_idx, device):
        """Evaluate INRs for each sample in batch.

        Returns the mean INR reconstruction loss; accuracy is reported as
        zeros since no classification is performed here.
        """
        coords = batch['audio']['coords'].to(device)
        fft = batch['audio']['fft_mag'].to(device)
        audio = batch['audio']['array'].to(device)
        batch_size = coords.shape[0]
        values = audio

        # Get INRs for each sample
        batch_inrs = []
        for i in range(batch_size):
            sample_id = self.get_sample_id(batch, i)
            inr, _ = self.db.get_or_create_inr(sample_id, self.create_inr, device)
            batch_inrs.append(inr)

        # Evaluate reconstruction quality without gradients.
        with torch.no_grad():
            all_preds = torch.stack([
                inr(coords[i]) for i, inr in enumerate(batch_inrs)
            ])
            batch_losses = torch.stack([
                self.criterion(all_preds[i].squeeze(), values[i])
                for i in range(batch_size)
            ])
        avg_loss = batch_losses.mean().item()
        acc = torch.zeros(self.output_dim, device=device)
        y = values
        return torch.tensor(avg_loss), acc, y


def verify_parallel_gradient_isolation(trainer, batch_size=4, sequence_length=1000):
    """
    Verify that gradients remain isolated in parallel training.

    NOTE(review): relies on ``trainer.create_batch_inrs()``, which is not
    defined on ``INRTrainer`` in this file — confirm which trainer class this
    helper targets before use.
    """
    device = trainer.device
    # Create test data
    coords = torch.linspace(0, 1, sequence_length).unsqueeze(-1)   # [N, 1]
    coords = coords.unsqueeze(0).repeat(batch_size, 1, 1)          # [B, N, 1]
    # Create synthetic signals: one sinusoid of distinct frequency per sample.
    targets = torch.stack([
        torch.sin(2 * torch.pi * (i + 1) * coords.squeeze(-1))
        for i in range(batch_size)
    ]).to(device)
    # Create batch of INRs
    inrs = trainer.create_batch_inrs()
    # Store initial parameters
    initial_params = [{name: param.clone().detach()
                       for name, param in inr.named_parameters()}
                      for inr in inrs]
    # Create mock batch
    batch = {
        'audio': {
            'coords': coords.to(device),
            'fft_mag': targets.clone(),
            'array': targets.clone(),
        }
    }
    # Run one training step
    trainer.train_batch(batch, 0, device)
    # Verify parameter changes
    isolation_verified = True
    for i, inr in enumerate(inrs):
        for name, param in inr.named_parameters():
            if not torch.allclose(param, initial_params[i][name]):
                # Verify that the changes are unique to this INR
                for j, other_inr in enumerate(inrs):
                    if i != j:
                        other_param = dict(other_inr.named_parameters())[name]
                        if not torch.allclose(other_param, initial_params[j][name]):
                            isolation_verified = False
                            print(f"Warning: Parameter {name} of INR {j} changed "
                                  f"when only INR {i} should have changed")
    return isolation_verified


class XGBoostTrainer():
    """XGBoost baseline over pre-extracted tabular audio features."""

    def __init__(self, model_args, train_ds, val_ds, test_ds):
        self.train_ds = train_ds
        self.test_ds = test_ds
        print("creating train dataframe...")
        self.x_train, self.y_train = self.create_dataframe(train_ds, save_name='train')
        print("creating validation dataframe...")
        self.x_val, self.y_val = self.create_dataframe(val_ds, save_name='val')
        print("creating test dataframe...")
        self.x_test, self.y_test = self.create_dataframe(test_ds, save_name='test')
        # Convert the data to DMatrix format
        self.dtrain = xgb.DMatrix(self.x_train, label=self.y_train)
        self.dval = xgb.DMatrix(self.x_val, label=self.y_val)
        self.dtest = xgb.DMatrix(self.x_test, label=self.y_test)
        # Model initialization
        self.model_args = model_args
        self.model = xgb.XGBClassifier(**model_args)

    def create_dataframe(self, ds, save_name='train'):
        """Flatten the dataset's feature dicts into a (X, y) pair.

        Results are cached at tasks/utils/dfs/{save_name}.csv.
        """
        csv_path = f"tasks/utils/dfs/{save_name}.csv"
        try:
            df = pd.read_csv(csv_path)
        except FileNotFoundError:
            data = []
            # Iterate over the dataset
            pbar = tqdm(enumerate(ds))
            for i, batch in pbar:
                label = batch['label']
                features = batch['audio']['features']
                # Flatten the nested dictionary structure
                feature_dict = {'label': label}
                for k, v in features.items():
                    if isinstance(v, dict):
                        for sub_k, sub_v in v.items():
                            feature_dict[f"{k}_{sub_k}"] = sub_v[0].item()  # Aggregate (e.g., mean)
                data.append(feature_dict)
            # Convert to DataFrame
            df = pd.DataFrame(data)
            print(os.getcwd())
            # FIX: ensure the cache directory exists before writing.
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            df.to_csv(csv_path, index=False)
        X = df.drop(columns=['label'])
        y = df['label']
        return X, y

    def fit(self):
        """Train with xgb.train + early stopping; return the eval history."""
        params = {
            'objective': 'binary:logistic',
            'eval_metric': 'logloss',
            **self.model_args,
        }
        evals_result = {}
        num_boost_round = 1000  # Set a large number of boosting rounds
        # Watchlist to monitor performance on train and validation data
        watchlist = [(self.dtrain, 'train'), (self.dval, 'eval')]
        # Train the model
        self.model = xgb.train(
            params,
            self.dtrain,
            num_boost_round=num_boost_round,
            evals=watchlist,
            early_stopping_rounds=10,  # Early stopping after 10 rounds with no improvement
            evals_result=evals_result,
            verbose_eval=True,  # Show evaluation results for each iteration
        )
        return evals_result

    def train_xgboost_in_batches(self, dataloader, eval_metric="logloss"):
        """Train a model per dataloader batch; return the last eval history.

        NOTE(review): each batch currently retrains from scratch (no
        ``xgb_model`` continuation) — confirm that is the intended behavior.
        """
        # FIX: the original referenced an undefined `params` (NameError) and
        # passed eval_metric= to xgb.train, which is not a valid kwarg — the
        # metric belongs in the params dict.
        params = {
            'objective': 'binary:logistic',
            'eval_metric': eval_metric,
            **self.model_args,
        }
        evals_result = {}
        for i, batch in enumerate(dataloader):
            # Convert batch data to NumPy arrays
            X_batch = torch.cat(
                [batch[key].view(batch[key].size(0), -1)
                 for key in batch if key != "label"],
                dim=1,
            ).numpy()
            y_batch = batch["label"].numpy()
            # Create DMatrix for XGBoost
            dtrain = xgb.DMatrix(X_batch, label=y_batch)
            # Use `train` with each batch
            self.model = xgb.train(
                params,
                dtrain,
                num_boost_round=1000,  # Use a large number of rounds
                evals=[(self.dval, 'eval')],
                early_stopping_rounds=10,
                evals_result=evals_result,
                verbose_eval=False,  # Avoid printing every iteration
            )
            # Optionally print progress
            if i % 10 == 0:
                print(f"Batch {i + 1}/{len(dataloader)} processed.")
        return evals_result

    def predict(self):
        """Evaluate on the held-out test set; print accuracy/AUC/report."""
        # Predict probabilities for class 1
        y_prob = self.model.predict(self.dtest, output_margin=False)
        # Convert probabilities to binary labels (0 or 1) using a threshold (e.g., 0.5)
        y_pred = (y_prob >= 0.5).astype(int)
        # Evaluate performance
        accuracy = accuracy_score(self.y_test, y_pred)
        roc_auc = roc_auc_score(self.y_test, y_prob)
        print(f'Accuracy: {accuracy:.4f}')
        print(f'ROC AUC Score: {roc_auc:.4f}')
        print(classification_report(self.y_test, y_pred))

    def plot_results(self, evals_result):
        """Plot train/validation logloss curves from an evals_result dict."""
        train_logloss = evals_result["train"]["logloss"]
        val_logloss = evals_result["eval"]["logloss"]
        iterations = range(1, len(train_logloss) + 1)
        # Plot
        plt.figure(figsize=(8, 5))
        plt.plot(iterations, train_logloss, label="Train LogLoss", color="blue")
        plt.plot(iterations, val_logloss, label="Validation LogLoss", color="red")
        plt.xlabel("Boosting Round (Iteration)")
        plt.ylabel("Log Loss")
        plt.title("Training and Validation Log Loss over Iterations")
        plt.legend()
        plt.grid()
        plt.show()