Upload 4 files

- neuralnet/dataset.py +139 -0
- neuralnet/model.py +71 -0
- neuralnet/train.py +130 -0
- neuralnet/utils.py +42 -0
neuralnet/dataset.py
CHANGED
@@ -0,0 +1,139 @@
import os  # when loading file paths
import pandas as pd  # for lookup in annotation file
import spacy  # for tokenizer
import torch
from torch.nn.utils.rnn import pad_sequence  # pad batch
from torch.utils.data import DataLoader, Dataset
from PIL import Image  # Load img
import torchvision.transforms as transforms
import json

# Download with: python -m spacy download en_core_web_sm
spacy_eng = spacy.load("en_core_web_sm")


class Vocabulary:
    def __init__(self, freq_threshold):
        self.itos = {0: "<PAD>", 1: "<SOS>", 2: "<EOS>", 3: "<UNK>"}
        self.stoi = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3}
        self.freq_threshold = freq_threshold

    def __len__(self):
        return len(self.stoi)

    @staticmethod
    def tokenizer_eng(text):
        return [tok.text.lower() for tok in spacy_eng.tokenizer(text)]

    def build_vocabulary(self, sentence_list):
        frequencies = {}
        idx = 4

        for sentence in sentence_list:
            for word in self.tokenizer_eng(sentence):
                if word not in frequencies:
                    frequencies[word] = 1
                else:
                    frequencies[word] += 1

                # Add a word to the vocab once it reaches the frequency threshold
                if frequencies[word] == self.freq_threshold:
                    self.stoi[word] = idx
                    self.itos[idx] = word
                    idx += 1

    def numericalize(self, text):
        tokenized_text = self.tokenizer_eng(text)

        return [
            self.stoi[token] if token in self.stoi else self.stoi["<UNK>"]
            for token in tokenized_text
        ]


class FlickrDataset(Dataset):
    def __init__(self, root_dir, captions_file, transform=None, freq_threshold=5):
        self.root_dir = root_dir
        self.df = pd.read_csv(captions_file)
        self.transform = transform

        # Get img, caption columns
        self.imgs = self.df["image_name"]
        self.captions = self.df["comment"]

        # Initialize vocabulary and build vocab
        self.vocab = Vocabulary(freq_threshold)
        self.vocab.build_vocabulary(self.captions.tolist())

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        caption = self.captions[index]
        img_id = self.imgs[index]
        img = Image.open(os.path.join(self.root_dir, img_id)).convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        numericalized_caption = [self.vocab.stoi["<SOS>"]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi["<EOS>"])

        return img, torch.tensor(numericalized_caption)


class MyCollate:
    def __init__(self, pad_idx):
        self.pad_idx = pad_idx

    def __call__(self, batch):
        imgs = [item[0].unsqueeze(0) for item in batch]
        imgs = torch.cat(imgs, dim=0)
        targets = [item[1] for item in batch]
        targets = pad_sequence(targets, batch_first=False, padding_value=self.pad_idx)

        return imgs, targets


def get_loader(
    root_folder,
    annotation_file,
    transform,
    batch_size=64,
    num_workers=2,
    shuffle=True,
    pin_memory=True,
):
    dataset = FlickrDataset(root_folder, annotation_file, transform=transform)

    pad_idx = dataset.vocab.stoi["<PAD>"]

    loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=shuffle,
        pin_memory=pin_memory,
        collate_fn=MyCollate(pad_idx=pad_idx),
    )

    return loader, dataset


if __name__ == "__main__":
    transform = transforms.Compose(
        [transforms.Resize((224, 224)), transforms.ToTensor()]
    )

    loader, dataset = get_loader(
        "/home/koushik/vscode/Projects/pytorch/img2text_v1/flickr30k/flickr30k_images/",
        "/home/koushik/vscode/Projects/pytorch/img2text_v1/flickr30k/results.csv",
        transform=transform,
    )

    for idx, (imgs, captions) in enumerate(loader):
        print(imgs.shape)
        print(captions.shape)
        print(len(dataset.vocab))
        test = {"itos": dataset.vocab.itos, "stoi": dataset.vocab.stoi}
        json.dump(test, open('test.json', 'w'))
        break
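
Usage note: a minimal sketch of the Vocabulary API defined above. The sentences and threshold are illustrative only (not from the repo), and it assumes the en_core_web_sm spaCy model is installed.

    from dataset import Vocabulary

    v = Vocabulary(freq_threshold=1)
    v.build_vocabulary(["a dog runs", "a dog jumps"])
    ids = v.numericalize("a dog flies")   # "flies" was never seen, so it maps to <UNK>
    print(ids)                            # [4, 5, 3]
    print([v.itos[i] for i in ids])       # ['a', 'dog', '<UNK>']

Note that the __main__ block above serializes itos/stoi with json.dump, which turns the integer keys of itos into strings; that is why model.caption_image (below) looks tokens up with str(idx).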
neuralnet/model.py
ADDED
@@ -0,0 +1,71 @@
import torch
import torch.nn as nn
import torchvision.models as models


class InceptionEncoder(nn.Module):
    def __init__(self, embed_size, train_CNN=False):
        super(InceptionEncoder, self).__init__()
        self.train_CNN = train_CNN
        self.inception = models.inception_v3(pretrained=True, aux_logits=False)
        self.inception.fc = nn.Linear(self.inception.fc.in_features, embed_size)
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm1d(embed_size, momentum=0.01)
        self.dropout = nn.Dropout(0.5)

    def forward(self, images):
        features = self.inception(images)
        norm_features = self.bn(features)
        return self.dropout(self.relu(norm_features))


class LstmDecoder(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device='cpu'):
        super(LstmDecoder, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.device = device
        self.embed = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers=self.num_layers)
        self.linear = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)

    def forward(self, encoder_out, captions):
        h0 = torch.zeros(self.num_layers, encoder_out.shape[0], self.hidden_size).to(self.device).requires_grad_()
        c0 = torch.zeros(self.num_layers, encoder_out.shape[0], self.hidden_size).to(self.device).requires_grad_()
        embeddings = self.dropout(self.embed(captions))
        embeddings = torch.cat((encoder_out.unsqueeze(0), embeddings), dim=0)
        hiddens, (hn, cn) = self.lstm(embeddings, (h0.detach(), c0.detach()))
        outputs = self.linear(hiddens)
        return outputs


class SeqToSeq(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, num_layers, device='cpu'):
        super(SeqToSeq, self).__init__()
        self.encoder = InceptionEncoder(embed_size)
        self.decoder = LstmDecoder(embed_size, hidden_size, vocab_size, num_layers, device)

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    def caption_image(self, image, vocabulary, max_length=50):
        result_caption = []

        with torch.no_grad():
            x = self.encoder(image).unsqueeze(0)
            states = None

            for _ in range(max_length):
                hiddens, states = self.decoder.lstm(x, states)
                output = self.decoder.linear(hiddens.squeeze(0))
                predicted = output.argmax(1)
                result_caption.append(predicted.item())
                x = self.decoder.embed(predicted).unsqueeze(0)

                if vocabulary[str(predicted.item())] == "<EOS>":
                    break

        return [vocabulary[str(idx)] for idx in result_caption]
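
Usage note: a minimal greedy-decoding sketch for SeqToSeq. The sizes mirror the train.py defaults, but the vocab file name, dummy image, and everything else here are illustrative assumptions; with an untrained decoder the output is arbitrary tokens until a checkpoint is loaded.

    import json
    import torch
    from model import SeqToSeq

    vocab = json.load(open("test.json"))          # {"itos": ..., "stoi": ...} dumped by dataset.py
    model = SeqToSeq(embed_size=256, hidden_size=512,
                     vocab_size=len(vocab["stoi"]), num_layers=3, device="cpu")
    model.eval()                                  # BatchNorm needs eval mode for a single image
    image = torch.randn(1, 3, 299, 299)           # Inception v3 expects 299x299 inputs
    # caption_image indexes itos with str(idx), matching the string keys json produces
    print(" ".join(model.caption_image(image, vocab["itos"])))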
neuralnet/train.py
ADDED
@@ -0,0 +1,130 @@
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # For TensorBoard
from utils import save_checkpoint, load_checkpoint, print_examples
from dataset import get_loader
from model import SeqToSeq
from tabulate import tabulate  # To tabulate loss and epoch
import argparse
import json


def main(args):
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    train_loader, _ = get_loader(
        root_folder=args.root_dir,
        annotation_file=args.csv_file,
        transform=transform,
        batch_size=args.batch_size,
        num_workers=2,
    )
    vocab = json.load(open('vocab.json'))  # itos/stoi mapping built by dataset.py

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = False
    save_model = True
    train_CNN = False

    # Hyperparameters
    embed_size = args.embed_size
    hidden_size = args.hidden_size
    vocab_size = len(vocab['stoi'])
    num_layers = args.num_layers
    learning_rate = args.lr
    num_epochs = args.num_epochs

    # For TensorBoard
    writer = SummaryWriter(args.log_dir)
    step = 0

    model_params = {'embed_size': embed_size, 'hidden_size': hidden_size,
                    'vocab_size': vocab_size, 'num_layers': num_layers}
    # Initialize model, loss etc.
    model = SeqToSeq(**model_params, device=device).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['stoi']["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Only finetune the encoder's final FC layer unless train_CNN is True
    for name, param in model.encoder.inception.named_parameters():
        if "fc.weight" in name or "fc.bias" in name:
            param.requires_grad = True
        else:
            param.requires_grad = train_CNN

    # Load from a saved checkpoint
    if load_model:
        step = load_checkpoint(torch.load(args.save_path), model, optimizer)

    model.train()
    best_loss, best_epoch = 10, 0
    for epoch in range(num_epochs):
        print_examples(model, device, vocab['itos'])

        for idx, (imgs, captions) in tqdm(
                enumerate(train_loader), total=len(train_loader), leave=False):
            imgs = imgs.to(device)
            captions = captions.to(device)

            outputs = model(imgs, captions[:-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )

            writer.add_scalar("Training loss", loss.item(), global_step=step)
            step += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        train_loss = loss.item()
        if train_loss < best_loss:
            best_loss = train_loss
            best_epoch = epoch + 1
            if save_model:
                checkpoint = {
                    "model_params": model_params,
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step,
                }
                save_checkpoint(checkpoint, args.save_path)

        table = [["Loss:", train_loss],
                 ["Step:", step],
                 ["Epoch:", epoch + 1],
                 ["Best Loss:", best_loss],
                 ["Best Epoch:", best_epoch]]
        print(tabulate(table))


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument('--root_dir', type=str, default='./flickr30k/flickr30k_images', help='path to images folder')
    parser.add_argument('--csv_file', type=str, default='./flickr30k/results.csv', help='path to captions csv file')
    parser.add_argument('--log_dir', type=str, default='./drive/MyDrive/TensorBoard/', help='path to save tensorboard logs')
    parser.add_argument('--save_path', type=str, default='./drive/MyDrive/checkpoints/Seq2Seq.pt', help='path to save checkpoint')
    # Model Params
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--embed_size', type=int, default=256)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--num_layers', type=int, default=3, help='number of lstm layers')

    args = parser.parse_args()

    main(args)
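
Usage note: a shape walk-through of the loss above, showing why feeding captions[:-1] to the model lines up with the full captions as targets. The sizes below are arbitrary stand-ins, not from the repo.

    import torch

    T, B, V = 12, 4, 1000                    # caption length, batch size, vocab size
    captions = torch.randint(0, V, (T, B))   # (seq_len, batch), as produced by MyCollate
    decoder_input = captions[:-1]            # (T-1, B) is what model(imgs, captions[:-1]) receives
    # LstmDecoder prepends the image embedding, so its output length is (T-1) + 1 = T
    outputs = torch.randn(T, B, V)           # stand-in for the model output
    logits = outputs.reshape(-1, V)          # (T*B, V)
    targets = captions.reshape(-1)           # (T*B,)
    assert logits.shape[0] == targets.shape[0]  # CrossEntropyLoss shapes line up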
neuralnet/utils.py
ADDED
@@ -0,0 +1,42 @@
import torch
import torchvision.transforms as transforms
from PIL import Image


def print_examples(model, device, vocab):
    transform = transforms.Compose(
        [transforms.Resize((299, 299)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    model.eval()

    test_img1 = transform(Image.open("./test_examples/dog.png").convert("RGB")).unsqueeze(0)
    print("dog.png PREDICTION: " + " ".join(model.caption_image(test_img1.to(device), vocab)))

    test_img2 = transform(Image.open("./test_examples/dirt_bike.png").convert("RGB")).unsqueeze(0)
    print("dirt_bike.png PREDICTION: " + " ".join(model.caption_image(test_img2.to(device), vocab)))

    test_img3 = transform(Image.open("./test_examples/surfing.png").convert("RGB")).unsqueeze(0)
    print("surfing.png PREDICTION: " + " ".join(model.caption_image(test_img3.to(device), vocab)))

    test_img4 = transform(Image.open("./test_examples/horse.png").convert("RGB")).unsqueeze(0)
    print("horse.png PREDICTION: " + " ".join(model.caption_image(test_img4.to(device), vocab)))

    test_img5 = transform(Image.open("./test_examples/camera.png").convert("RGB")).unsqueeze(0)
    print("camera.png PREDICTION: " + " ".join(model.caption_image(test_img5.to(device), vocab)))
    model.train()


def save_checkpoint(state, filename="/content/drive/MyDrive/checkpoints/Seq2Seq.pt"):
    print("=> Saving checkpoint")
    torch.save(state, filename)


def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    step = checkpoint["step"]
    return step
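
Usage note: a round-trip sketch for the checkpoint helpers above. The path, vocab size, and learning rate are placeholders; the dict keys mirror the checkpoint built in train.py.

    import torch
    import torch.optim as optim
    from model import SeqToSeq
    from utils import save_checkpoint, load_checkpoint

    model = SeqToSeq(embed_size=256, hidden_size=512, vocab_size=5000, num_layers=3)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    checkpoint = {"state_dict": model.state_dict(),
                  "optimizer": optimizer.state_dict(),
                  "step": 0}
    save_checkpoint(checkpoint, filename="./Seq2Seq.pt")
    step = load_checkpoint(torch.load("./Seq2Seq.pt"), model, optimizer)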