Dabs's picture
first commit
cb8043e
import torch
from tqdm import tqdm
# training function
def train(model, dataloader, optimizer, criterion, train_data, device):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    The model is switched to training mode. For every batch the ``'image'``
    and ``'label'`` entries are moved to ``device``, gradients are cleared,
    the forward pass is run, a sigmoid is applied to the raw outputs, the
    loss is computed, backpropagated, and the optimizer is stepped.

    Args:
        model: Module being trained; invoked as ``model(images)``.
        dataloader: Iterable of batches. Each batch is indexed with the keys
            ``'image'`` and ``'label'``.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        criterion: Callable invoked as ``criterion(outputs, target)`` on the
            post-sigmoid outputs; its result must support ``.item()`` and
            ``.backward()``.
            NOTE(review): since it receives probabilities rather than
            logits, this is presumably ``nn.BCELoss`` or similar -- confirm.
        train_data: No longer used. Kept in the signature so existing
            callers keep working; it was only used to estimate the number
            of batches for the progress bar.
        device: Target device passed to ``Tensor.to``.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    print('Training')
    model.train()

    num_batches = 0
    running_loss = 0.0

    # tqdm takes the bar length from len(dataloader) when it is defined.
    # The previous int(len(train_data) / batch_size) estimate floored away a
    # partial final batch and failed when dataloader.batch_size was None.
    for batch in tqdm(dataloader):
        num_batches += 1
        images = batch['image'].to(device)
        target = batch['label'].to(device)

        optimizer.zero_grad()

        # Squash the raw model outputs into the (0, 1) range before the loss.
        outputs = torch.sigmoid(model(images))
        loss = criterion(outputs, target)
        running_loss += loss.item()

        # Backpropagate, then update the model parameters.
        loss.backward()
        optimizer.step()

    return running_loss / num_batches
# validation function
def validate(model, dataloader, criterion, val_data, device):
    """Evaluate ``model`` over ``dataloader`` and return the mean batch loss.

    The model is switched to evaluation mode and the whole pass runs under
    ``torch.no_grad()``. For every batch the ``'image'`` and ``'label'``
    entries are moved to ``device``, the forward pass is run, a sigmoid is
    applied to the raw outputs, and the loss is accumulated.

    Args:
        model: Module being evaluated; invoked as ``model(images)``.
        dataloader: Iterable of batches. Each batch is indexed with the keys
            ``'image'`` and ``'label'``.
        criterion: Callable invoked as ``criterion(outputs, target)`` on the
            post-sigmoid outputs; its result must support ``.item()``.
            NOTE(review): since it receives probabilities rather than
            logits, this is presumably ``nn.BCELoss`` or similar -- confirm.
        val_data: No longer used. Kept in the signature so existing callers
            keep working; it was only used to estimate the number of
            batches for the progress bar.
        device: Target device passed to ``Tensor.to``.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    print('Validating')
    model.eval()

    num_batches = 0
    running_loss = 0.0

    with torch.no_grad():
        # tqdm takes the bar length from len(dataloader) when it is defined.
        # The previous int(len(val_data) / batch_size) estimate floored away
        # a partial final batch and failed when batch_size was None.
        for batch in tqdm(dataloader):
            num_batches += 1
            images = batch['image'].to(device)
            target = batch['label'].to(device)

            # Squash the raw model outputs into the (0, 1) range before the loss.
            outputs = torch.sigmoid(model(images))
            loss = criterion(outputs, target)
            running_loss += loss.item()

    return running_loss / num_batches