# Hugging Face Spaces page status at capture time: "Runtime error"
import os

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import models
from dataset import ImageDataset
from engine import train, validate
matplotlib.style.use('ggplot')

# Directory that receives the checkpoint and the loss plot. Create it up front
# so a missing directory is handled before training starts, instead of
# torch.save / plt.savefig raising FileNotFoundError after the last epoch.
output_dir = '../outputs'
os.makedirs(output_dir, exist_ok=True)

# initialize the computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# initialize the model; `requires_grad=False` is forwarded to models.model
# (presumably freezes the pretrained backbone -- confirm in models.py)
model = models.model(pretrained=True, requires_grad=False).to(device)

# learning parameters
lr = 0.0001
epochs = 10
batch_size = 32
optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE(review): BCELoss expects probabilities in [0, 1], so this assumes the
# model's output layer already applies a sigmoid -- confirm in models.py.
criterion = nn.BCELoss()

# read the training csv file
train_csv = pd.read_csv('../input/movie-classifier/Multi_Label_dataset/train.csv')
# Both datasets are built from the same train.csv; ImageDataset presumably
# selects the train vs. validation rows from the `train` flag -- verify in
# dataset.py.
train_data = ImageDataset(train_csv, train=True, test=False)
valid_data = ImageDataset(train_csv, train=False, test=False)

# data loaders: only the training set is shuffled
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False)

# start the training and validation; per-epoch losses are kept for the plot
train_loss = []
valid_loss = []
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = train(
        model, train_loader, optimizer, criterion, train_data, device
    )
    valid_epoch_loss = validate(
        model, valid_loader, criterion, valid_data, device
    )
    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f'Val Loss: {valid_epoch_loss:.4f}')

# save the trained model to disk.
# NOTE: 'loss' stores the criterion module object itself, not a numeric loss
# value; the key is kept unchanged for compatibility with code that reads
# this checkpoint.
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': criterion,
}, os.path.join(output_dir, 'model.pth'))

# plot and save the train and validation line graphs
plt.figure(figsize=(10, 7))
plt.plot(train_loss, color='orange', label='train loss')
plt.plot(valid_loss, color='red', label='validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.savefig(os.path.join(output_dir, 'loss.png'))
plt.show()