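# Training utilities: a generic PyTorch Trainer, an INR-based trainer with a persistent INR database, and an XGBoost trainer.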
import copy
import glob
import hashlib
import os
import time
from collections import OrderedDict

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import xgboost as xgb
import yaml
from matplotlib import pyplot as plt
from sklearn.metrics import accuracy_score, classification_report, roc_auc_score
from torch.cuda.amp import autocast
from torch.nn import ModuleList
from tqdm import tqdm

# from inr import INR
# from kan import FasterKAN
class Trainer(object):
    """
    A class that encapsulates the training loop for a PyTorch model.

    Batches are read as nested dicts: ``batch['audio']['array']``,
    ``batch['audio']['fft_mag']`` and ``batch['label']``. The two audio
    tensors are stacked along a new dim 1 and fed to the model; predictions
    are treated as binary logits (sigmoid, threshold 0.5).
    """

    # Wall-clock budget (hours) after which ``fit`` stops on its own.
    TIME_LIMIT_HOURS = 23.83

    def __init__(self, model, optimizer, criterion, train_dataloader, device, world_size=1, output_dim=2,
                 scheduler=None, val_dataloader=None, max_iter=np.inf, scaler=None,
                 grad_clip=False, exp_num=None, log_path=None, exp_name=None, plot_every=None,
                 cos_inc=False, range_update=None, accumulation_step=1, wandb_log=False, num_quantiles=1,
                 update_func=lambda x: x):
        """Store the training components and configuration.

        ``max_iter`` caps the number of batches processed per epoch.
        ``log_path`` / ``exp_num`` / ``exp_name`` determine where ``fit``
        writes checkpoints. Several options (``scaler``, ``grad_clip``,
        ``cos_inc``, ``range_update``, ``accumulation_step``, ``wandb_log``,
        ``num_quantiles``, ``update_func``, ``plot_every``) are only stored
        here; nothing in this class reads them.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.scaler = scaler
        self.grad_clip = grad_clip
        self.cos_inc = cos_inc
        self.output_dim = output_dim
        self.scheduler = scheduler
        self.train_dl = train_dataloader
        self.val_dl = val_dataloader
        self.train_sampler = self.get_sampler_from_dataloader(train_dataloader)
        self.val_sampler = self.get_sampler_from_dataloader(val_dataloader)
        self.max_iter = max_iter
        self.device = device
        self.world_size = world_size
        self.exp_num = exp_num
        self.exp_name = exp_name
        self.log_path = log_path
        self.best_state_dict = None
        self.plot_every = plot_every
        self.logger = None
        self.range_update = range_update
        self.accumulation_step = accumulation_step
        self.wandb = wandb_log
        self.num_quantiles = num_quantiles
        self.update_func = update_func

    def get_sampler_from_dataloader(self, dataloader):
        """Return the sampler behind ``dataloader``, or ``None`` if there is none.

        A ``DistributedSampler`` is returned directly; a wrapped sampler
        (``sampler.sampler``) is unwrapped; otherwise the batch sampler's
        inner sampler is used.
        """
        sampler = getattr(dataloader, 'sampler', None)
        if sampler is not None:
            if isinstance(sampler, torch.utils.data.DistributedSampler):
                return sampler
            if hasattr(sampler, 'sampler'):
                return sampler.sampler
        batch_sampler = getattr(dataloader, 'batch_sampler', None)
        if batch_sampler is not None and hasattr(batch_sampler, 'sampler'):
            return batch_sampler.sampler
        return None

    def fit(self, num_epochs, device, early_stopping=None, only_p=False, best='loss', conf=False):
        """
        Fits the model for the given number of epochs.

        After every epoch the model is evaluated; whenever the monitored
        metric improves (validation loss if ``best == 'loss'``, otherwise
        mean validation accuracy) the weights are written to
        ``{log_path}/{exp_num}/{exp_name}.pth`` and snapshotted in
        ``self.best_state_dict``. Training stops after ``early_stopping``
        epochs without improvement (if given) or once the wall-clock budget
        is exhausted. ``only_p`` and ``conf`` are accepted but unused.

        Returns a dict with per-batch ``train_loss`` / ``val_loss``,
        per-epoch ``train_acc`` / ``val_acc`` / ``lrs`` and ``num_epochs``.
        """
        min_loss = np.inf
        best_acc = 0
        train_loss, val_loss = [], []
        train_acc, val_acc = [], []
        lrs = []
        epochs_without_improvement = 0
        main_proccess = True  # change in a ddp setting
        print(f"Starting training for {num_epochs} epochs")
        print("is main process: ", main_proccess, flush=True)
        global_time = time.time()
        self.epoch = 0
        for epoch in range(num_epochs):
            self.epoch = epoch
            start_time = time.time()

            t_loss, t_acc = self.train_epoch(device, epoch=epoch)
            t_loss_mean = np.nanmean(t_loss)
            train_loss.extend(t_loss)
            global_train_accuracy, global_train_loss = self.process_loss(t_acc, t_loss_mean)
            if main_proccess:  # Only perform this on the master GPU
                train_acc.append(global_train_accuracy.mean().item())

            v_loss, v_acc = self.eval_epoch(device, epoch=epoch)
            v_loss_mean = np.nanmean(v_loss)
            val_loss.extend(v_loss)
            global_val_accuracy, global_val_loss = self.process_loss(v_acc, v_loss_mean)

            if main_proccess:  # Only perform this on the master GPU
                val_acc.append(global_val_accuracy.mean().item())
                if best == 'loss':
                    improved = global_val_loss < min_loss
                    if improved:
                        min_loss = global_val_loss
                else:
                    current_acc = global_val_accuracy.mean()
                    improved = current_acc > best_acc
                    if improved:
                        best_acc = current_acc

                if improved:
                    model_name = f'{self.log_path}/{self.exp_num}/{self.exp_name}.pth'
                    print(f"saving model at {model_name}...")
                    # torch.save does not create missing directories.
                    os.makedirs(os.path.dirname(model_name), exist_ok=True)
                    torch.save(self.model.state_dict(), model_name)
                    # state_dict() returns references to the live parameters;
                    # deep-copy so later optimizer steps do not overwrite the
                    # "best" snapshot.
                    self.best_state_dict = copy.deepcopy(self.model.state_dict())
                    epochs_without_improvement = 0
                else:
                    epochs_without_improvement += 1

                current_lr = self.optimizer.param_groups[0]['lr'] if self.scheduler is None \
                    else self.scheduler.get_last_lr()[0]
                lrs.append(current_lr)
                print(f'Epoch {epoch}, lr {current_lr}, Train Loss: {global_train_loss:.6f}, Val Loss:'\
                      f'{global_val_loss:.6f}, Train Acc: {global_train_accuracy.round(decimals=4).tolist()}, '\
                      f'Val Acc: {global_val_accuracy.round(decimals=4).tolist()},'\
                      f'Time: {time.time() - start_time:.2f}s, Total Time: {(time.time() - global_time)/3600} hr', flush=True)
                if epoch % 10 == 0:
                    print(os.system('nvidia-smi'))
                if early_stopping is not None and epochs_without_improvement >= early_stopping:
                    print('early stopping!', flush=True)
                    break
            if time.time() - global_time > (self.TIME_LIMIT_HOURS * 3600):
                print("time limit reached")
                break
        return {"num_epochs": num_epochs, "train_loss": train_loss,
                "val_loss": val_loss, "train_acc": train_acc, "val_acc": val_acc, "lrs": lrs}

    def process_loss(self, acc, loss_mean):
        """Return ``(accuracy, loss)`` as tensors, averaged across ranks.

        When CUDA and ``torch.distributed`` are both active, the values are
        summed over all ranks and divided by the world size. ``all_reduce``
        is used (rather than a reduce to rank 0) because every rank goes on
        to use the result for checkpointing and early stopping.
        """
        def _as_new_tensor(value):
            # torch.tensor(existing_tensor) warns; clone explicitly instead.
            if isinstance(value, torch.Tensor):
                return value.detach().clone()
            return torch.tensor(value)

        global_accuracy = _as_new_tensor(acc)
        global_loss = _as_new_tensor(loss_mean)
        if torch.cuda.is_available() and torch.distributed.is_initialized():
            global_accuracy = global_accuracy.cuda()
            global_loss = global_loss.cuda()
            torch.distributed.all_reduce(global_accuracy, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.all_reduce(global_loss, op=torch.distributed.ReduceOp.SUM)
            world_size = torch.distributed.get_world_size()
            global_loss /= world_size
            global_accuracy /= world_size
        return global_accuracy, global_loss

    def load_best_model(self, to_ddp=True, from_ddp=True):
        """Load a saved checkpoint into ``self.model`` (non-strict).

        Checkpoints are looked up in ``{log_path}/exp{exp_num}`` first and
        then in ``{log_path}/{exp_num}``, which is where ``fit`` writes.
        ``{exp_name}.pth`` is preferred; otherwise the most recently modified
        ``.pth`` file is used. With ``to_ddp=False`` tensors are mapped onto
        ``self.device``. With ``from_ddp=True`` leading ``module.`` prefixes
        are stripped from the keys.

        Raises:
            FileNotFoundError: if no ``.pth`` file exists in either directory.
        """
        candidate_dirs = [f'{self.log_path}/exp{self.exp_num}', f'{self.log_path}/{self.exp_num}']
        state_dict_files = []
        for data_dir in candidate_dirs:
            # glob order is arbitrary; sort by mtime so [-1] is the newest.
            state_dict_files = sorted(glob.glob(data_dir + '/*.pth'), key=os.path.getmtime)
            if state_dict_files:
                break
        if not state_dict_files:
            raise FileNotFoundError(f"no .pth checkpoint found in any of {candidate_dirs}")

        preferred = [path for path in state_dict_files
                     if os.path.basename(path) == f'{self.exp_name}.pth']
        checkpoint_path = preferred[-1] if preferred else state_dict_files[-1]
        print("loading model from ", checkpoint_path)
        if to_ddp:
            state_dict = torch.load(checkpoint_path)
        else:
            state_dict = torch.load(checkpoint_path, map_location=self.device)

        if from_ddp:
            print("loading distributed model")
            # Remove "module." from keys
            prefix = 'module.'
            new_state_dict = OrderedDict()
            for key, value in state_dict.items():
                while key.startswith(prefix):
                    key = key[len(prefix):]
                new_state_dict[key] = value
            state_dict = new_state_dict

        result = self.model.load_state_dict(state_dict, strict=False)
        # strict=False hides mismatches; surface them so a wrong checkpoint
        # does not go unnoticed.
        if result.missing_keys or result.unexpected_keys:
            print(f"load_state_dict mismatches - missing: {result.missing_keys}, "
                  f"unexpected: {result.unexpected_keys}")

    def check_gradients(self):
        """Print every parameter whose gradient norm exceeds 10."""
        for name, param in self.model.named_parameters():
            if param.grad is None:
                continue
            grad_norm = param.grad.norm().item()
            if grad_norm > 10:
                print(f"Large gradient in {name}: {grad_norm}")

    @staticmethod
    def _prepare_inputs(batch, device):
        """Move a batch to ``device`` and stack audio and FFT along dim 1."""
        x = batch['audio']['array'].to(device).float()
        fft = batch['audio']['fft_mag'].to(device).float()
        y = batch['label'].to(device).float()
        return torch.stack((x, fft), dim=1), y

    def _forward(self, x_fft, y):
        """Run the model and make the logits match the label shape.

        A bare ``squeeze()`` turns a batch of one into a 0-d tensor, which
        no longer matches ``y``; reshape when only the shape differs.
        """
        y_pred = self.model(x_fft).squeeze()
        if y_pred.shape != y.shape and y_pred.numel() == y.numel():
            y_pred = y_pred.reshape(y.shape)
        return y_pred

    @staticmethod
    def _classify(y_pred, y):
        """Threshold sigmoid(logits) at 0.5; return (classes, number correct)."""
        cls_pred = (torch.sigmoid(y_pred) > 0.5).float()
        return cls_pred, (cls_pred == y).sum()

    def train_epoch(self, device, epoch):
        """
        Trains the model for one epoch.

        Returns the list of per-batch losses and the number of correct
        predictions divided by the number of samples seen (broadcast over a
        vector of length ``output_dim``). At most ``max_iter`` batches are
        processed.
        """
        set_epoch = getattr(self.train_sampler, 'set_epoch', None)
        if callable(set_epoch):
            set_epoch(epoch)
        self.model.train()
        train_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.train_dl)
        for i, batch in enumerate(pbar):
            if self.optimizer is not None:
                self.optimizer.zero_grad()
            loss, acc, y = self.train_batch(batch, i, device)
            train_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"train_acc: {acc}, train_loss: {loss.item()}")
            if i + 1 >= self.max_iter:
                break
        return train_loss, all_accs / max(total, 1)

    def train_batch(self, batch, batch_idx, device):
        """Run one optimisation step; return ``(loss, num_correct, labels)``.

        The scheduler, if any, is stepped once per batch.
        """
        x_fft, y = self._prepare_inputs(batch, device)
        y_pred = self._forward(x_fft, y)
        loss = self.criterion(y_pred, y)
        loss.backward()
        if self.optimizer is not None:
            self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step()
        _, acc = self._classify(y_pred, y)
        return loss, acc, y

    def eval_epoch(self, device, epoch):
        """
        Evaluates the model for one epoch.

        Returns the list of per-batch losses and the number of correct
        predictions divided by the number of samples seen.
        """
        self.model.eval()
        val_loss = []
        total = 0
        all_accs = torch.zeros(self.output_dim, device=device)
        pbar = tqdm(self.val_dl)
        for i, batch in enumerate(pbar):
            loss, acc, y = self.eval_batch(batch, i, device)
            val_loss.append(loss.item())
            all_accs = all_accs + acc
            total += len(y)
            pbar.set_description(f"val_acc: {acc}, val_loss: {loss.item()}")
            if i + 1 >= self.max_iter:
                break
        return val_loss, all_accs / max(total, 1)

    def eval_batch(self, batch, batch_idx, device):
        """Evaluate one batch without gradients; return ``(loss, num_correct, labels)``."""
        x_fft, y = self._prepare_inputs(batch, device)
        with torch.no_grad():
            y_pred = self._forward(x_fft, y)
        loss = self.criterion(y_pred, y)
        _, acc = self._classify(y_pred, y)
        return loss, acc, y

    def predict(self, test_dataloader, device):
        """
        Returns the predictions of the model on the given dataset.

        Returns ``(predictions, true_labels, accuracy)`` where the first two
        are flat lists over all processed samples.
        """
        self.model.eval()
        total = 0
        all_accs = 0
        predictions = []
        true_labels = []
        pbar = tqdm(test_dataloader)
        for i, batch in enumerate(pbar):
            x_fft, y = self._prepare_inputs(batch, device)
            with torch.no_grad():
                y_pred = self._forward(x_fft, y)
            cls_pred, acc = self._classify(y_pred, y)
            predictions.extend(cls_pred.cpu().numpy())
            true_labels.extend(y.cpu().numpy().astype(np.int64))
            all_accs += acc
            total += len(y)
            pbar.set_description("acc: {:.4f}".format(acc))
            if i + 1 >= self.max_iter:
                break
        return predictions, true_labels, all_accs / max(total, 1)
class INRDatabase:
    """Database to store and manage INRs persistently.

    Keeps one model and one Adam optimizer per sample id in memory and can
    serialise all of their state dicts to ``<save_dir>/inr_database.pt``.
    """

    STATE_FILENAME = 'inr_database.pt'

    def __init__(self, save_dir='./inr_database'):
        self.inrs = {}  # Maps sample_id -> INR
        self.optimizers = {}  # Maps sample_id -> optimizer
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

    def get_or_create_inr(self, sample_id, create_fn, device):
        """Get existing INR or create new one if not exists.

        ``create_fn`` is called with no arguments to build a new model, which
        is moved to ``device`` and paired with a default-configured Adam
        optimizer. Returns ``(inr, optimizer)``.
        """
        if sample_id not in self.inrs:
            inr = create_fn().to(device)
            self.inrs[sample_id] = inr
            self.optimizers[sample_id] = torch.optim.Adam(inr.parameters())
        return self.inrs[sample_id], self.optimizers[sample_id]

    def set_inr(self, sample_id, inr, optimizer):
        """Store (or overwrite) the model and optimizer for ``sample_id``."""
        self.inrs[sample_id] = inr
        self.optimizers[sample_id] = optimizer

    def save_state(self):
        """Save all INRs and optimizer states to disk.

        The file is written to a temporary name and then moved into place,
        so an interrupted save cannot leave a truncated database behind.
        """
        state = {
            'inrs': {sample_id: inr.state_dict() for sample_id, inr in self.inrs.items()},
            'optimizers': {sample_id: opt.state_dict() for sample_id, opt in self.optimizers.items()},
        }
        os.makedirs(self.save_dir, exist_ok=True)
        path = os.path.join(self.save_dir, self.STATE_FILENAME)
        tmp_path = path + '.tmp'
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)

    def load_state(self, create_fn, device):
        """Load INRs and optimizer states from disk.

        Does nothing if no database file exists. Optimizer entries without a
        matching INR entry are skipped.
        """
        path = os.path.join(self.save_dir, self.STATE_FILENAME)
        if not os.path.exists(path):
            return
        state = torch.load(path, map_location=device)

        # Restore INRs
        for sample_id, inr_state in state['inrs'].items():
            inr = create_fn().to(device)
            inr.load_state_dict(inr_state)
            self.inrs[sample_id] = inr

        # Restore optimizers
        for sample_id, opt_state in state['optimizers'].items():
            inr = self.inrs.get(sample_id)
            if inr is None:
                # An optimizer cannot be rebuilt without its parameters.
                continue
            optimizer = torch.optim.Adam(inr.parameters())
            optimizer.load_state_dict(opt_state)
            self.optimizers[sample_id] = optimizer
class INRTrainer(Trainer):
    """Trainer that fits one INR per sample and classifies from the INR weights.

    For every sample in a batch an INR is fetched from (or created in) a
    persistent ``INRDatabase``, fitted for a few steps to map
    ``batch['audio']['coords']`` to ``batch['audio']['array']``, and its
    weight and bias tensors are batched and passed to ``self.model`` as
    ``inputs=(weights, biases)``.

    NOTE: the INR-specific parameters come *before* ``*args`` in the
    signature, so the base ``Trainer`` arguments should be passed by keyword.
    """

    def __init__(self, hidden_features=128, n_layers=3, in_features=1, out_features=1,
                 num_steps=5000, lr=1e-3, inr_criterion=torch.nn.MSELoss(), save_dir='./inr_database', *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.hidden_features = hidden_features
        self.n_layers = n_layers
        self.in_features = in_features
        self.out_features = out_features
        self.num_steps = num_steps  # stored only; train_inr uses its own num_iters
        self.lr = lr
        self.inr_criterion = inr_criterion
        # Initialize INR database
        self.db = INRDatabase(save_dir)
        # Load existing INRs if available
        self.db.load_state(self.create_inr, self.device)

    def create_inr(self):
        """Factory function to create new INR instances."""
        # The module-level `from inr import INR` is commented out in this
        # file, so resolve the class here.
        # NOTE(review): assumes the commented-out module path is still correct.
        from inr import INR
        return INR(
            hidden_features=self.hidden_features,
            n_layers=self.n_layers,
            in_features=self.in_features,
            out_features=self.out_features
        )

    def create_kan(self):
        """Factory function to create a FasterKAN with the configured layer sizes."""
        # NOTE(review): module-level `from kan import FasterKAN` is commented
        # out; assumes that module path is still correct.
        from kan import FasterKAN
        return FasterKAN(layers_hidden=[self.in_features] + [self.hidden_features] * (self.n_layers) + [self.out_features],)

    def get_sample_id(self, batch, idx):
        """Extract unique identifier for a sample in the batch.
        Override this method based on your data structure.

        Tensor ids are converted to plain Python values: tensors hash by
        object identity, so they would never match an existing database key.
        Without an ``'id'`` entry, a SHA-1 digest of the raw sample bytes is
        used. The builtin ``hash`` of bytes is salted per process and would
        not match the persisted database in a later run.
        """
        if 'id' in batch:
            sample_id = batch['id'][idx]
            if isinstance(sample_id, torch.Tensor):
                if sample_id.numel() == 1:
                    sample_id = sample_id.item()
                else:
                    sample_id = tuple(sample_id.flatten().tolist())
            return sample_id
        # Fallback: content hash of the sample data
        sample_data = batch['audio']['array'][idx]
        raw = sample_data.detach().cpu().contiguous().numpy().tobytes()
        return hashlib.sha1(raw).hexdigest()

    def train_inr(self, optimizer, model, coords, values, num_iters=10, plot=False):
        """Fit ``model`` to map ``coords`` to ``values`` for ``num_iters`` steps.

        Returns ``(model, optimizer)``. With ``plot=True`` the target and the
        last prediction are plotted, titled with the last loss.
        """
        coords = coords.to(self.device).float()
        values = values.to(self.device).float()
        pred_values = None
        loss = None
        for _ in range(num_iters):
            optimizer.zero_grad()
            pred_values = model(coords).float()
            loss = self.inr_criterion(pred_values.squeeze(), values)
            loss.backward()
            optimizer.step()
        if plot and loss is not None:
            plt.plot(values.cpu().detach().numpy())
            plt.plot(pred_values.cpu().detach().numpy())
            plt.title(loss.item())
            plt.show()
        return model, optimizer

    def train_batch(self, batch, batch_idx, device):
        """Train INRs for each sample in batch, persisting progress.

        Returns ``(classification_loss, num_correct, labels)``. The database
        is saved every 10 batches.
        """
        coords = batch['audio']['coords'].to(device)  # [B, N, 1]
        values = batch['audio']['array'].to(device)  # [B, N]
        y = batch['label'].to(device).float()
        batch_size = coords.shape[0]

        per_sample_weights = []
        per_sample_biases = []
        for i in range(batch_size):
            sample_id = self.get_sample_id(batch, i)
            inr, optimizer = self.db.get_or_create_inr(sample_id, self.create_inr, device)
            # The database builds Adam with its default lr; apply the
            # configured one.
            for group in optimizer.param_groups:
                group['lr'] = self.lr
            inr, optimizer = self.train_inr(optimizer, inr, coords[i], values[i])
            self.db.set_inr(sample_id, inr, optimizer)

            state_dict = inr.state_dict()
            per_sample_weights.append([
                v.permute(1, 0).unsqueeze(-1).unsqueeze(0).to(device)
                for name, v in state_dict.items() if "weight" in name
            ])
            per_sample_biases.append([
                v.unsqueeze(-1).unsqueeze(0).to(device)
                for name, v in state_dict.items() if "bias" in name
            ])

        # Concatenate layer-wise in sample order so that row i of every
        # tensor belongs to label y[i].
        batch_weights = tuple(torch.cat(layer, dim=0) for layer in zip(*per_sample_weights))
        batch_biases = tuple(torch.cat(layer, dim=0) for layer in zip(*per_sample_biases))

        y_pred = self.model(inputs=(batch_weights, batch_biases)).squeeze()
        if y_pred.shape != y.shape and y_pred.numel() == y.numel():
            # squeeze() collapses a batch of one to 0-d; restore label shape.
            y_pred = y_pred.reshape(y.shape)
        loss_preds = self.criterion(y_pred, y)
        self.optimizer.zero_grad()
        loss_preds.backward()
        self.optimizer.step()

        if batch_idx % 10 == 0:  # Adjust frequency as needed
            self.db.save_state()

        probs = torch.sigmoid(y_pred)
        cls_pred = (probs > 0.5).float()
        acc = (cls_pred == y).sum()
        return loss_preds, acc, y

    def eval_batch(self, batch, batch_idx, device):
        """Evaluate INRs for each sample in batch.

        Returns the mean reconstruction loss (``inr_criterion``) of the
        stored INRs, a zero accuracy vector of length ``output_dim`` (no
        classification is performed here) and the target values.

        NOTE: samples not yet in the database get a fresh, untrained INR,
        which is also stored.
        """
        coords = batch['audio']['coords'].to(device).float()
        values = batch['audio']['array'].to(device).float()
        batch_size = coords.shape[0]

        # Get INRs for each sample
        batch_inrs = [
            self.db.get_or_create_inr(self.get_sample_id(batch, i), self.create_inr, device)[0]
            for i in range(batch_size)
        ]

        # Evaluate
        with torch.no_grad():
            batch_losses = torch.stack([
                self.inr_criterion(inr(coords[i]).squeeze(), values[i])
                for i, inr in enumerate(batch_inrs)
            ])
            avg_loss = batch_losses.mean().item()
        acc = torch.zeros(self.output_dim, device=device)
        y = values
        return torch.tensor(avg_loss), acc, y
def verify_parallel_gradient_isolation(trainer, batch_size=4, sequence_length=1000):
    """
    Verify that gradients remain isolated in parallel training.

    Builds ``batch_size`` synthetic sine signals, runs one
    ``trainer.train_batch`` step on them, and then checks the INRs returned
    by ``trainer.create_batch_inrs()``: whenever a parameter of one INR
    changed, the same-named parameter of every other INR must be unchanged.

    NOTE(review): ``create_batch_inrs`` is not defined on any trainer visible
    in this file; the trainer passed in must provide it.

    Returns:
        bool: ``True`` if no cross-INR parameter change was detected.
    """
    device = trainer.device

    # Create test data
    coords = torch.linspace(0, 1, sequence_length).unsqueeze(-1)  # [N, 1]
    coords = coords.unsqueeze(0).repeat(batch_size, 1, 1)  # [B, N, 1]

    # Create synthetic signals: one sine per sample, frequency i + 1.
    # Index coords[i] so the stacked result is [B, N] rather than [B, B, N].
    targets = torch.stack([
        torch.sin(2 * torch.pi * (i + 1) * coords[i].squeeze(-1))
        for i in range(batch_size)
    ]).to(device)

    # Create batch of INRs
    inrs = trainer.create_batch_inrs()

    # Store initial parameters
    initial_params = [
        {name: param.clone().detach() for name, param in inr.named_parameters()}
        for inr in inrs
    ]

    # Create mock batch
    batch = {
        'audio': {
            'coords': coords.to(device),
            'fft_mag': targets.clone(),
            'array': targets.clone()
        },
        # train_batch implementations in this file read batch['label'].
        'label': torch.zeros(batch_size, device=device),
    }

    # Run one training step
    trainer.train_batch(batch, 0, device)

    # Verify parameter changes
    current_params = [dict(inr.named_parameters()) for inr in inrs]
    isolation_verified = True
    for i, params in enumerate(current_params):
        for name, param in params.items():
            if torch.allclose(param, initial_params[i][name]):
                continue
            # Verify that the changes are unique to this INR
            for j, other_params in enumerate(current_params):
                if i == j:
                    continue
                if not torch.allclose(other_params[name], initial_params[j][name]):
                    isolation_verified = False
                    print(f"Warning: Parameter {name} of INR {j} changed when only INR {i} should have changed")
    return isolation_verified
class XGBoostTrainer():
    """Train and evaluate an XGBoost binary classifier on tabular audio features.

    Each dataset is flattened into a dataframe (cached as CSV under
    ``tasks/utils/dfs/``) and wrapped in an ``xgb.DMatrix``.
    """

    def __init__(self, model_args, train_ds, val_ds, test_ds):
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.test_ds = test_ds
        print("creating train dataframe...")
        self.x_train, self.y_train = self.create_dataframe(train_ds, save_name='train')
        print("creating validation dataframe...")
        self.x_val, self.y_val = self.create_dataframe(val_ds, save_name='val')
        print("creating test dataframe...")
        self.x_test, self.y_test = self.create_dataframe(test_ds, save_name='test')
        # Convert the data to DMatrix format
        self.dtrain = xgb.DMatrix(self.x_train, label=self.y_train)
        self.dval = xgb.DMatrix(self.x_val, label=self.y_val)
        self.dtest = xgb.DMatrix(self.x_test, label=self.y_test)
        # Model initialization; replaced by a Booster once fit() has run.
        self.model_args = model_args
        self.model = xgb.XGBClassifier(**model_args)

    def create_dataframe(self, ds, save_name='train'):
        """Return ``(X, y)`` for ``ds``, using a CSV cache when available.

        If ``tasks/utils/dfs/{save_name}.csv`` exists it is loaded and ``ds``
        is not touched. Otherwise every item of ``ds`` is flattened: for each
        dict-valued entry of ``item['audio']['features']`` the first element
        of each sub-value becomes column ``"{key}_{sub_key}"`` (entries that
        are not dicts are skipped). The result is written to the cache.
        """
        csv_path = f"tasks/utils/dfs/{save_name}.csv"
        try:
            df = pd.read_csv(csv_path)
        except FileNotFoundError:
            rows = []
            # Wrap the dataset itself so tqdm can show a total when it has one.
            for item in tqdm(ds):
                label = item['label']
                if isinstance(label, torch.Tensor) and label.numel() == 1:
                    # Store a plain number so the CSV round-trips cleanly.
                    label = label.item()
                row = {'label': label}
                # Flatten the nested dictionary structure
                for k, v in item['audio']['features'].items():
                    if isinstance(v, dict):
                        for sub_k, sub_v in v.items():
                            row[f"{k}_{sub_k}"] = sub_v[0].item()  # first element only
                rows.append(row)
            # Convert to DataFrame
            df = pd.DataFrame(rows)
            print(os.getcwd())
            os.makedirs(os.path.dirname(csv_path), exist_ok=True)
            df.to_csv(csv_path, index=False)
        X = df.drop(columns=['label'])
        y = df['label']
        return X, y

    def fit(self):
        """Train with early stopping on the validation set.

        Uses up to 1000 boosting rounds and stops after 10 rounds without
        improvement on the last watchlist entry (validation). ``self.model``
        becomes the trained ``Booster``. Returns the evaluation history.
        """
        params = {
            'objective': 'binary:logistic',
            'eval_metric': 'logloss',
            **self.model_args
        }
        evals_result = {}
        num_boost_round = 1000  # Set a large number of boosting rounds
        # Watchlist to monitor performance on train and validation data
        watchlist = [(self.dtrain, 'train'), (self.dval, 'eval')]
        # Train the model
        self.model = xgb.train(
            params,
            self.dtrain,
            num_boost_round=num_boost_round,
            evals=watchlist,
            early_stopping_rounds=10,  # Early stopping after 10 rounds with no improvement
            evals_result=evals_result,
            verbose_eval=True  # Show evaluation results for each iteration
        )
        return evals_result

    def train_xgboost_in_batches(self, dataloader, eval_metric="logloss"):
        """Train incrementally, one ``xgb.train`` call per dataloader batch.

        Every non-label entry of a batch is flattened to ``[B, -1]`` and
        concatenated into the feature matrix. Each call continues from the
        booster produced by the previous one. Returns the evaluation history
        of the last call.

        NOTE(review): the batch feature layout must match the columns of the
        validation DMatrix used for evaluation; confirm against the caller.
        """
        # The metric is a training parameter, not a keyword of xgb.train.
        params = {
            'objective': 'binary:logistic',
            **self.model_args,
            'eval_metric': eval_metric,
        }
        evals_result = {}
        # Continue from an already trained booster if there is one.
        booster = self.model if isinstance(self.model, xgb.Booster) else None
        for i, batch in enumerate(dataloader):
            # Convert batch data to NumPy arrays
            X_batch = torch.cat([batch[key].view(batch[key].size(0), -1) for key in batch if key != "label"],
                                dim=1).numpy()
            y_batch = batch["label"].numpy()
            # Create DMatrix for XGBoost
            dtrain = xgb.DMatrix(X_batch, label=y_batch)
            # Use `train` with each batch
            booster = xgb.train(
                params,
                dtrain,
                num_boost_round=1000,  # Use a large number of rounds
                evals=[(self.dval, 'eval')],
                early_stopping_rounds=10,
                evals_result=evals_result,
                verbose_eval=False,  # Avoid printing every iteration
                xgb_model=booster,
            )
            self.model = booster
            # Optionally print progress
            if i % 10 == 0:
                print(f"Batch {i + 1}/{len(dataloader)} processed.")
        return evals_result

    def predict(self):
        """Evaluate the trained booster on the test set.

        Prints accuracy, ROC AUC and a classification report, and returns
        them together with the predictions in a dict.

        Raises:
            RuntimeError: if no booster has been trained yet.
        """
        if not isinstance(self.model, xgb.Booster):
            raise RuntimeError("call fit() or train_xgboost_in_batches() before predict()")
        # Predict probabilities for class 1
        y_prob = self.model.predict(self.dtest, output_margin=False)
        # Convert probabilities to binary labels (0 or 1) using a threshold (e.g., 0.5)
        y_pred = (y_prob >= 0.5).astype(int)
        # Evaluate performance
        accuracy = accuracy_score(self.y_test, y_pred)
        roc_auc = roc_auc_score(self.y_test, y_prob)
        print(f'Accuracy: {accuracy:.4f}')
        print(f'ROC AUC Score: {roc_auc:.4f}')
        print(classification_report(self.y_test, y_pred))
        return {"accuracy": accuracy, "roc_auc": roc_auc, "y_prob": y_prob, "y_pred": y_pred}

    def plot_results(self, evals_result):
        """Plot train and validation log loss per boosting round.

        Expects the history returned by ``fit`` (keys ``"train"`` and
        ``"eval"``, each with a ``"logloss"`` series).
        """
        train_logloss = evals_result["train"]["logloss"]
        val_logloss = evals_result["eval"]["logloss"]
        iterations = range(1, len(train_logloss) + 1)
        # Plot
        plt.figure(figsize=(8, 5))
        plt.plot(iterations, train_logloss, label="Train LogLoss", color="blue")
        plt.plot(iterations, val_logloss, label="Validation LogLoss", color="red")
        plt.xlabel("Boosting Round (Iteration)")
        plt.ylabel("Log Loss")
        plt.title("Training and Validation Log Loss over Iterations")
        plt.legend()
        plt.grid()
        plt.show()