File size: 4,577 Bytes
4e527a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter # For TensorBoard
from utils import save_checkpoint, load_checkpoint, print_examples
from dataset import get_loader
from model import SeqToSeq
from tabulate import tabulate # To tabulate loss and epoch
import argparse
import json

def _build_transform():
    """Return the training-time image preprocessing pipeline.

    Images are resized to 356x356, randomly cropped to 299x299 (the extra
    margin gives positional jitter), converted to tensors in [0, 1] and then
    normalized per channel with mean 0.5 / std 0.5, i.e. mapped to [-1, 1].
    """
    return transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )


def _train_one_epoch(model, loader, criterion, optimizer, writer, device, step):
    """Run one optimisation pass over ``loader``.

    Every batch's loss is logged to TensorBoard as "Training loss" at the
    running global ``step``.

    Args:
        model: captioning model called as ``model(imgs, captions[:-1])``.
        loader: iterable of ``(imgs, captions)`` batches.
        criterion: loss applied to flattened logits vs. flattened captions.
        optimizer: optimizer stepping ``model``'s parameters.
        writer: TensorBoard ``SummaryWriter`` for per-step logging.
        device: device the batch tensors are moved to.
        step: global step counter at the start of the epoch.

    Returns:
        tuple: ``(mean_loss, step)``. ``mean_loss`` is the mean batch loss
        over the epoch (``nan`` if the loader yielded nothing). ``step`` is
        the updated global step.
    """
    total_loss, num_batches = 0.0, 0
    for imgs, captions in tqdm(loader, total=len(loader), leave=False):
        imgs = imgs.to(device)
        captions = captions.to(device)

        # The model receives every caption position except the last (slicing
        # dim 0), but the loss is taken against the full caption. The model's
        # output is therefore expected to be as long as ``captions`` along
        # dim 0 — presumably the image features fill the first step
        # (TODO confirm against model.SeqToSeq).
        outputs = model(imgs, captions[:-1])
        loss = criterion(
            outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
        )

        batch_loss = loss.item()
        writer.add_scalar("Training loss", batch_loss, global_step=step)
        step += 1

        optimizer.zero_grad()
        # Plain backward on the scalar loss. Passing ``loss`` as the
        # ``gradient`` argument (as before) multiplied every gradient by the
        # current loss value.
        loss.backward()
        optimizer.step()

        total_loss += batch_loss
        num_batches += 1

    mean_loss = total_loss / num_batches if num_batches else float("nan")
    return mean_loss, step


def main(args):
    """Train the ``SeqToSeq`` image-captioning model.

    Builds the training loader, vocabulary, model, loss and optimizer, then
    trains for ``args.num_epochs`` epochs. Per-step loss is logged to
    TensorBoard in ``args.log_dir``. Whenever an epoch's mean loss improves
    on the best seen so far, a checkpoint is written to ``args.save_path``
    (if ``save_model`` is enabled). A summary table is printed after each
    epoch.

    Args:
        args: parsed command-line namespace providing ``root_dir``,
            ``csv_file``, ``log_dir``, ``save_path``, ``batch_size``,
            ``num_epochs``, ``embed_size``, ``hidden_size``, ``lr`` and
            ``num_layers``.

    Raises:
        ValueError: if the training loader contains no batches.
    """
    train_loader, _ = get_loader(
        root_folder=args.root_dir,
        annotation_file=args.csv_file,
        transform=_build_transform(),
        batch_size=args.batch_size,  # honour --batch_size (was hard-coded 64)
        num_workers=2,
    )
    if len(train_loader) == 0:
        raise ValueError(f"no training batches found for {args.csv_file!r}")

    # Use a context manager so the vocabulary file handle is closed promptly.
    with open("vocab.json", encoding="utf-8") as vocab_file:
        vocab = json.load(vocab_file)

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = False
    save_model = True
    train_CNN = False

    # Hyperparameters
    embed_size = args.embed_size
    hidden_size = args.hidden_size
    vocab_size = len(vocab["stoi"])
    num_layers = args.num_layers
    learning_rate = args.lr
    num_epochs = args.num_epochs

    model_params = {
        "embed_size": embed_size,
        "hidden_size": hidden_size,
        "vocab_size": vocab_size,
        "num_layers": num_layers,
    }
    model = SeqToSeq(**model_params, device=device).to(device)
    # Padding positions must not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab["stoi"]["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Freeze the encoder's Inception backbone except its final fc layer.
    # Setting train_CNN = True makes every backbone layer trainable.
    for name, param in model.encoder.inception.named_parameters():
        if "fc.weight" in name or "fc.bias" in name:
            param.requires_grad = True
        else:
            param.requires_grad = train_CNN

    step = 0
    if load_model:
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        step = load_checkpoint(
            torch.load(args.save_path, map_location=device), model, optimizer
        )

    best_loss, best_epoch = float("inf"), 0
    writer = SummaryWriter(args.log_dir)
    try:
        for epoch in range(num_epochs):
            print_examples(model, device, vocab["itos"])
            # print_examples may switch the model to eval mode (TODO confirm
            # in utils); ensure the epoch itself runs in training mode.
            model.train()

            # Track the epoch's mean loss: the last minibatch alone is too
            # noisy to choose the best checkpoint.
            train_loss, step = _train_one_epoch(
                model, train_loader, criterion, optimizer, writer, device, step
            )

            if train_loss < best_loss:
                best_loss = train_loss
                best_epoch = epoch + 1
                if save_model:
                    checkpoint = {
                        "model_params": model_params,
                        "state_dict": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "step": step,
                    }
                    save_checkpoint(checkpoint, args.save_path)

            table = [
                ["Loss:", train_loss],
                ["Step:", step],
                ["Epoch:", epoch + 1],
                ["Best Loss:", best_loss],
                ["Best Epoch:", best_epoch],
            ]
            print(tabulate(table))
    finally:
        # Flush pending TensorBoard events even if training is interrupted.
        writer.close()
	
	
if __name__ == "__main__":
    # Script entry point: read the training options from the command line
    # and hand them to main().
    cli_parser = argparse.ArgumentParser()

    # Filesystem locations: dataset inputs, TensorBoard logs, checkpoint file.
    cli_parser.add_argument(
        "--root_dir",
        type=str,
        default="./flickr30k/flickr30k_images",
        help="path to images folder",
    )
    cli_parser.add_argument(
        "--csv_file",
        type=str,
        default="./flickr30k/results.csv",
        help="path to captions csv file",
    )
    cli_parser.add_argument(
        "--log_dir",
        type=str,
        default="./drive/MyDrive/TensorBoard/",
        help="path to save tensorboard logs",
    )
    cli_parser.add_argument(
        "--save_path",
        type=str,
        default="./drive/MyDrive/checkpoints/Seq2Seq.pt",
        help="path to save checkpoint",
    )

    # Training and model hyperparameters.
    cli_parser.add_argument("--batch_size", type=int, default=64)
    cli_parser.add_argument("--num_epochs", type=int, default=100)
    cli_parser.add_argument("--embed_size", type=int, default=256)
    cli_parser.add_argument("--hidden_size", type=int, default=512)
    cli_parser.add_argument("--lr", type=float, default=0.001)
    cli_parser.add_argument(
        "--num_layers",
        type=int,
        default=3,
        help="number of lstm layers",
    )

    args = cli_parser.parse_args()
    main(args)