import math

import torch
from tqdm import tqdm


def _progress_total(dataloader, dataset):
    """Return how many batches `dataloader` will yield, or None if unknown.

    ``len(dataloader)`` is preferred because it already accounts for a partial
    final batch and for ``drop_last``. When the loader has no length (e.g. an
    iterable-style dataset), fall back to ceil(len(dataset) / batch_size);
    return None (tqdm then shows no total) if that cannot be computed either.
    """
    try:
        return len(dataloader)
    except TypeError:
        pass
    batch_size = getattr(dataloader, "batch_size", None)
    if not batch_size:
        return None
    try:
        return math.ceil(len(dataset) / batch_size)
    except TypeError:
        return None


def _mean_batch_loss(running_loss, num_batches, phase):
    """Average the accumulated loss over the batches seen.

    Raises:
        ValueError: if the dataloader yielded no batches (instead of an
            opaque ZeroDivisionError).
    """
    if num_batches == 0:
        raise ValueError(f"{phase}: dataloader yielded no batches")
    return running_loss / num_batches


# training function
def train(model, dataloader, optimizer, criterion, train_data, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: the network; put into train mode for the epoch.
        dataloader: iterable of batches, each a mapping with 'image' and
            'label' entries that support ``.to(device)``.
        optimizer: optimizer whose parameters are updated after every batch.
        criterion: loss called as ``criterion(sigmoid(model(image)), label)``.
        train_data: the underlying dataset; only used to size the progress
            bar when ``len(dataloader)`` is unavailable.
        device: device the batch tensors are moved to.

    Returns:
        float: sum of per-batch ``loss.item()`` divided by the number of batches.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    print('Training')
    model.train()
    num_batches = 0
    running_loss = 0.0
    for batch in tqdm(dataloader, total=_progress_total(dataloader, train_data)):
        num_batches += 1
        images = batch['image'].to(device)
        targets = batch['label'].to(device)
        optimizer.zero_grad()
        # apply sigmoid activation to get all the outputs between 0 and 1.
        # NOTE(review): the criterion therefore receives probabilities, which
        # suits e.g. nn.BCELoss; with nn.BCEWithLogitsLoss the sigmoid would be
        # applied twice -- confirm against the caller.
        outputs = torch.sigmoid(model(images))
        loss = criterion(outputs, targets)
        running_loss += loss.item()
        # backpropagation
        loss.backward()
        # update optimizer parameters
        optimizer.step()
    return _mean_batch_loss(running_loss, num_batches, 'train')


# validation function
def validate(model, dataloader, criterion, val_data, device):
    """Evaluate the model for one pass and return the mean per-batch loss.

    Runs under ``torch.no_grad()`` with the model in eval mode; no parameters
    are updated.

    Args:
        model: the network; put into eval mode.
        dataloader: iterable of batches, each a mapping with 'image' and
            'label' entries that support ``.to(device)``.
        criterion: loss called as ``criterion(sigmoid(model(image)), label)``.
        val_data: the underlying dataset; only used to size the progress bar
            when ``len(dataloader)`` is unavailable.
        device: device the batch tensors are moved to.

    Returns:
        float: sum of per-batch ``loss.item()`` divided by the number of batches.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    print('Validating')
    model.eval()
    num_batches = 0
    running_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(dataloader, total=_progress_total(dataloader, val_data)):
            num_batches += 1
            images = batch['image'].to(device)
            targets = batch['label'].to(device)
            # apply sigmoid activation to get all the outputs between 0 and 1
            # (same criterion assumption as in `train` -- confirm).
            outputs = torch.sigmoid(model(images))
            loss = criterion(outputs, targets)
            running_loss += loss.item()
    return _mean_batch_loss(running_loss, num_batches, 'validate')