# (Repository page residue, kept as comments so the file parses as Python.)
# Dabs's picture
# first commit
# cb8043e
import os

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import models
from dataset import ImageDataset
from engine import train, validate
matplotlib.style.use('ggplot')  # ggplot look for the loss figure drawn at the end

# Prefer the GPU whenever CUDA is usable; fall back to the CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device)

# Build the network via the project-local `models` module, then move it onto
# `device`. (requires_grad=False presumably freezes the pretrained weights —
# confirm in models.py.)
model = models.model(pretrained=True, requires_grad=False)
model = model.to(device)
# ---- learning parameters ----
lr, epochs, batch_size = 0.0001, 10, 32

# Adam is handed every parameter of the model; parameters that never receive
# a gradient are simply skipped by its update step. BCELoss expects inputs in
# [0, 1], so the model is assumed to end in a sigmoid — confirm in models.py.
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.BCELoss()
# Read the annotation CSV once. Both the training and validation splits are
# derived from it; ImageDataset presumably selects the split from the
# `train`/`test` flags (confirm in dataset.py).
train_csv = pd.read_csv('../input/movie-classifier/Multi_Label_dataset/train.csv')
train_data = ImageDataset(train_csv, train=True, test=False)
valid_data = ImageDataset(train_csv, train=False, test=False)

# Batch both splits with the shared batch size. Only the training batches are
# shuffled, so the validation pass sees the same order every epoch.
train_loader, valid_loader = (
    DataLoader(train_data, batch_size=batch_size, shuffle=True),
    DataLoader(valid_data, batch_size=batch_size, shuffle=False),
)
# Per-epoch loss history, filled in below and plotted at the end of the script.
train_loss, valid_loss = [], []

for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")

    # One pass over the training split, then one evaluation pass over the
    # validation split. Both helpers come from the project-local `engine`
    # module and are assumed to return a scalar epoch loss (confirm there).
    train_epoch_loss = train(model, train_loader, optimizer, criterion, train_data, device)
    valid_epoch_loss = validate(model, valid_loader, criterion, valid_data, device)

    train_loss += [train_epoch_loss]
    valid_loss += [valid_epoch_loss]

    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f'Val Loss: {valid_epoch_loss:.4f}')
# Persist the final checkpoint. torch.save does not create missing parent
# directories, so make sure ../outputs exists first; otherwise the run would
# crash here and the freshly trained weights would be lost.
os.makedirs('../outputs', exist_ok=True)
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    # NOTE(review): this pickles the nn.BCELoss module itself, not a loss
    # value. It is kept unchanged for checkpoint compatibility. Recent PyTorch
    # releases default torch.load to weights_only=True, which may refuse to
    # unpickle this entry, so loaders may need weights_only=False
    # (trusted files only).
    'loss': criterion,
}, '../outputs/model.pth')

# Plot the per-epoch train/validation loss curves and save the figure next to
# the checkpoint. The x-axis is 1-based to match the "Epoch k of N" lines
# printed during training.
epoch_axis = list(range(1, len(train_loss) + 1))
plt.figure(figsize=(10, 7))
plt.plot(epoch_axis, train_loss, color='orange', label='train loss')
plt.plot(epoch_axis, valid_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig('../outputs/loss.png')
plt.show()