import argparse
import json

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # For TensorBoard
from tqdm import tqdm
from tabulate import tabulate  # To tabulate loss and epoch

from utils import save_checkpoint, load_checkpoint, print_examples
from dataset import get_loader
from model import SeqToSeq


def main(args):
    """Train the SeqToSeq image-captioning model, checkpointing on best loss.

    Expects a ``vocab.json`` file (with ``stoi``/``itos`` maps) in the working
    directory and image/caption data at the paths supplied via ``args``.
    Logs per-step training loss to TensorBoard under ``args.log_dir``.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    train_loader, _ = get_loader(
        root_folder=args.root_dir,
        annotation_file=args.csv_file,
        transform=transform,
        # Bug fix: batch size was hard-coded to 64, silently ignoring
        # the --batch_size command-line flag.
        batch_size=args.batch_size,
        num_workers=2,
    )

    # Bug fix: json.load(open(...)) leaked the file handle; use a context
    # manager so it is always closed.
    with open('vocab.json') as f:
        vocab = json.load(f)

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = False
    save_model = True
    train_CNN = False

    # Hyperparameters
    embed_size = args.embed_size
    hidden_size = args.hidden_size
    vocab_size = len(vocab['stoi'])
    num_layers = args.num_layers
    learning_rate = args.lr
    num_epochs = args.num_epochs

    # for tensorboard
    writer = SummaryWriter(args.log_dir)
    step = 0

    model_params = {
        'embed_size': embed_size,
        'hidden_size': hidden_size,
        'vocab_size': vocab_size,
        'num_layers': num_layers,
    }

    # initialize model, loss etc
    model = SeqToSeq(**model_params, device=device).to(device)
    # NOTE(review): ignore_index keyed on the empty string looks wrong —
    # padding is usually a token like "<PAD>"; confirm against vocab.json.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['stoi'][""])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Only finetune the CNN: train the encoder's final fc layer always,
    # and the rest of the Inception backbone only when train_CNN is True.
    for name, param in model.encoder.inception.named_parameters():
        if "fc.weight" in name or "fc.bias" in name:
            param.requires_grad = True
        else:
            param.requires_grad = train_CNN

    # load from a saved checkpoint
    if load_model:
        step = load_checkpoint(torch.load(args.save_path), model, optimizer)

    model.train()

    # Bug fix: best_loss started at the magic value 10, so no checkpoint was
    # ever saved while the loss stayed above 10; start at +inf instead.
    best_loss, best_epoch = float('inf'), 0

    for epoch in range(num_epochs):
        print_examples(model, device, vocab['itos'])

        for imgs, captions in tqdm(train_loader, total=len(train_loader), leave=False):
            imgs = imgs.to(device)
            captions = captions.to(device)

            # Feed all but the last caption token; the model learns to
            # predict each next token (teacher forcing).
            outputs = model(imgs, captions[:-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )

            writer.add_scalar("Training loss", loss.item(), global_step=step)
            step += 1

            optimizer.zero_grad()
            # Bug fix: backward()'s optional argument is a *gradient* tensor,
            # not the loss — loss.backward(loss) scaled every gradient by the
            # loss value. A scalar loss takes a plain backward().
            loss.backward()
            optimizer.step()

        # NOTE: this is the loss of the LAST batch of the epoch, not an
        # epoch average — kept to preserve the original checkpoint logic.
        train_loss = loss.item()
        if train_loss < best_loss:
            best_loss = train_loss
            best_epoch = epoch + 1
            if save_model:
                checkpoint = {
                    "model_params": model_params,
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step,
                }
                save_checkpoint(checkpoint, args.save_path)

        table = [
            ["Loss:", train_loss],
            ["Step:", step],
            ["Epoch:", epoch + 1],
            ["Best Loss:", best_loss],
            ["Best Epoch:", best_epoch],
        ]
        print(tabulate(table))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_dir', type=str, default='./flickr30k/flickr30k_images', help='path to images folder')
    parser.add_argument('--csv_file', type=str, default='./flickr30k/results.csv', help='path to captions csv file')
    parser.add_argument('--log_dir', type=str, default='./drive/MyDrive/TensorBoard/', help='path to save tensorboard logs')
    parser.add_argument('--save_path', type=str, default='./drive/MyDrive/checkpoints/Seq2Seq.pt', help='path to save checkpoint')
    # Model Params
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--embed_size', type=int, default=256)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--num_layers', type=int, default=3, help='number of lstm layers')
    args = parser.parse_args()
    main(args)