|
import torch |
|
import wandb |
|
from tqdm.auto import tqdm |
|
|
|
API_KEY = "881252af31786a1cf813449b9b4124955f54703e" |
|
|
|
|
|
def train_loop(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, |
|
optimizer: torch.optim.Optimizer, accuracy_fn, device: torch.device): |
|
|
|
""" |
|
Function for Model Training |
|
|
|
Args: |
|
model: A pytorch model you want to train |
|
dataloader: A dataloader for intance for model training |
|
loss_fn: A loss function for calculate model loss |
|
accuracy_fn: A Accuracy function to check how model is accurate |
|
device: A device on which model run i.e.: "cuda" or "cpu" |
|
|
|
Return: |
|
list of train loss and accuracy and also model weights |
|
|
|
Example usage: |
|
train_loop(model = mymodel, dataloader = train_dataloader, loss_fn = loss_fn, |
|
accuracy_fn = accuracy_fn, device = device) |
|
""" |
|
|
|
train_loss, train_acc = 0, 0 |
|
|
|
model.train() |
|
|
|
for batch, (x_train, y_train) in enumerate(dataloader): |
|
x_train, y_train = x_train.to(device), y_train.to(device) |
|
|
|
|
|
logits = model(x_train) |
|
|
|
|
|
loss = loss_fn(logits, y_train) |
|
|
|
|
|
optimizer.zero_grad() |
|
|
|
|
|
loss.backward() |
|
|
|
|
|
optimizer.step() |
|
|
|
acc = accuracy_fn(torch.argmax(logits, dim = 1), y_train) |
|
|
|
train_acc += acc |
|
train_loss += loss |
|
|
|
train_loss /= len(dataloader) |
|
train_acc /= len(dataloader) |
|
|
|
return model, train_loss, train_acc |
|
|
|
|
|
def test_loop(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, |
|
accuracy_fn, device: torch.device): |
|
|
|
""" |
|
A Funtion to test the model after traininig |
|
|
|
Args: |
|
model: A model which you want to test on test intance |
|
dataloader: A dataloader intance on which you test model |
|
loss_fn: A loss function to calculate the model loss |
|
accuracy_fn : A accuracy function to check model accuracy on dataloader intance |
|
device: A device on whic you want to run model i.e.: "cuda" or "cpu" |
|
|
|
Return: |
|
A list of test loss and Accuracy |
|
|
|
Example Usage: |
|
test_loop(model = mymodel, dataloader = test_datloader, loss_fn = loss_fn, |
|
accuracy_fn = accuracy+fn, device = device) |
|
""" |
|
|
|
test_loss, test_acc = 0, 0 |
|
|
|
model.eval() |
|
with torch.inference_mode(): |
|
for x_test, y_test in dataloader: |
|
x_test, y_test = x_test.to(device), y_test.to(device) |
|
|
|
|
|
logits = model(x_test) |
|
|
|
|
|
test_loss += loss_fn(logits, y_test) |
|
|
|
test_acc += accuracy_fn(torch.argmax(logits, dim = 1), y_test) |
|
|
|
test_acc /= len(dataloader) |
|
test_loss /= len(dataloader) |
|
|
|
return test_loss, test_acc |
|
|
|
|
|
def log_image_table(images, predicted, labels, probs): |
|
"Log a wandb.Table with (img, pred, target, scores)" |
|
wandb.login(key=API_KEY) |
|
print("[LOG]: Login Succesfull.") |
|
|
|
table = wandb.Table(columns=["image", "pred", "target"]+[f"score_{i}" for i in range(10)]) |
|
for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")): |
|
table.add_data(wandb.Image(img[0].numpy()*255), pred, targ, *prob.numpy()) |
|
wandb.log({"predictions_table":table}, commit=False) |
|
|
|
|
|
|
|
def validation(model: torch.nn.Module, dataloader: torch.utils.data.DataLoader, loss_fn: torch.nn.Module, |
|
accuracy_fn, log_images: bool, device: torch.device, batch_idx=0): |
|
""" |
|
A Function for hyperparameter tuning |
|
|
|
Args: |
|
model: A model the you want tune its hyperparameter |
|
dataloder: Adataloader intance for hyperparameter tuning |
|
loss_fn: A loss funtion to calcualte model loss |
|
Accuracy_fn: a accuracy function to calcultae accuracy for model perforamnce |
|
device: A device on which model run i.e.: "cuda" or "cpu" |
|
|
|
Return: |
|
A list of accuracy and loss |
|
|
|
Example usage: |
|
validation(model = mymodel, dataloader = valid_dataloader, loss_fn = loss_fn, |
|
accuracy_fn = accuracy_fn, device = device) |
|
""" |
|
|
|
val_loss, val_acc = 0, 0 |
|
|
|
model.eval() |
|
with torch.inference_mode(): |
|
for i, (x_val, y_val) in enumerate(dataloader): |
|
x_val, y_val = x_val.to(device), y_val.to(device) |
|
|
|
logits = model(x_val) |
|
|
|
val_loss += loss_fn(logits, y_val) |
|
|
|
val_acc += accuracy_fn(torch.argmax(logits, dim = 1), y_val) |
|
|
|
|
|
if i==batch_idx and log_images: |
|
log_image_table(x_val, torch.max(logits.data, 1)[0], y_val, logits.softmax(dim=1)) |
|
|
|
val_loss /= len(dataloader) |
|
val_acc /= len(dataloader) |
|
|
|
return val_loss, val_acc |
|
|
|
def train(model: torch.nn.Module, train_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader, |
|
optimizer: torch.optim.Optimizer, loss_fn: torch.nn.Module, accuracy_fn, epochs: int, device: torch.device): |
|
|
|
""" |
|
A function to train and test the pytorch model |
|
|
|
Args: |
|
model: A model to train and test |
|
train_dataloader: A dataloader intance for train model |
|
test_dataloader: A dataloader intance for test model |
|
optimizer: A optimizer funtion to optimize the model |
|
loss_fn: A loss function to calculate model loss |
|
accuracy_fn: An accuracy to calculate model performance |
|
epochs: number of iteration to run the loop |
|
device: A device on which model run i.e.: "cuda" or "cpu" |
|
|
|
Return: |
|
train model, List of train and test losses and accuracy |
|
|
|
Example usage: |
|
train(model = mymodel, train_dataloader = train_dataloader, test_dataloader = test_dataloader, optimizer = optimizer, |
|
loss_fn = loss_fn, acuuracy_fn = accuracy_fn, epochs = epochs, device = device) |
|
""" |
|
|
|
train_losses, test_losses = [], [] |
|
train_accs, test_accs = [], [] |
|
|
|
for epoch in tqdm(range(epochs)): |
|
|
|
print(f"\nEpoch: {epoch+1}") |
|
|
|
train_model, train_loss, train_acc = train_loop(model = model, dataloader = train_dataloader, |
|
loss_fn = loss_fn, optimizer = optimizer, |
|
accuracy_fn = accuracy_fn, device = device) |
|
|
|
test_loss, test_acc = test_loop(model = model, dataloader = test_dataloader, loss_fn = loss_fn, |
|
accuracy_fn = accuracy_fn, device = device) |
|
|
|
print(f"Train Loss: {train_loss:.5f} | Test Loss: {test_loss:.5f} || Train Accuracy: {train_acc:.5f} | Test Accuracy: {test_acc:.5f}") |
|
|
|
train_losses.append(train_loss.item()) |
|
test_losses.append(test_loss.item()) |
|
train_accs.append(train_acc.item()) |
|
test_accs.append(test_acc.item()) |
|
|
|
return train_losses, test_losses, train_accs, test_accs, train_model |