VenkateshRoshan committed
Commit be7ebcc · 1 Parent(s): f3a635f

Initial Commit

ImageCaptioning.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
__init__.py ADDED
File without changes
__pycache__/config.cpython-310.pyc ADDED
Binary file (648 Bytes)
 
config/__pycache__/config.cpython-310.pyc ADDED
Binary file (659 Bytes)
 
config/config.py ADDED
@@ -0,0 +1,11 @@
+ import torch
+ class Config:
+     IMAGE_SIZE = (224, 224)
+     MAX_SEQ_LEN = 64
+     VIT_MODEL = 'google/vit-base-patch16-224-in21k'
+     GPT2_MODEL = 'gpt2'
+     LEARNING_RATE = 5e-5
+     EPOCHS = 10
+     DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+     AWS_S3_BUCKET = 'your-s3-bucket-name'
+     DATASET_PATH = '../Datasets/Flickr8K/'
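Note on config/config.py: Config is a plain namespace class, so its attributes (model names, hyperparameters, paths) are read directly off the class without instantiation. The sketch below is illustrative only and not part of this commit; it assumes the standard transformers ViTModel/GPT2LMHeadModel loaders for the model names that Config stores.

# Illustrative sketch, not part of this commit: consuming Config's class attributes.
from transformers import ViTModel, GPT2LMHeadModel  # assumed dependency
from config.config import Config

vit = ViTModel.from_pretrained(Config.VIT_MODEL).to(Config.DEVICE)
gpt2 = GPT2LMHeadModel.from_pretrained(Config.GPT2_MODEL).to(Config.DEVICE)
print(f'Training plan: {Config.EPOCHS} epochs, lr={Config.LEARNING_RATE}, device={Config.DEVICE}')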
data/__pycache__/dataLoader.cpython-310.pyc ADDED
Binary file (1.99 kB)
 
data/dataLoader.py ADDED
@@ -0,0 +1,52 @@
+ import numpy as np
+ import os
+ import cv2
+ from PIL import Image
+ from torchvision import transforms
+ import pandas as pd
+
+ class dataLoader:
+     def __init__(self, path):
+         self.path = path
+         self.img_path = path + 'images/'
+         self.caption_path = path + 'captions.csv'
+         self.img_list = os.listdir(self.img_path)
+         self.caption_dict = self.get_caption_dict()
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor()
+         ])
+
+     def get_caption_dict(self):
+         caption_dict = {}
+         df = pd.read_csv(self.caption_path, delimiter=',')
+         for i in range(len(df)):
+             img_name = df.iloc[i, 0]
+             caption = df.iloc[i, 1]
+             caption_dict[img_name] = caption
+         return caption_dict
+
+     def get_image(self, img_name):
+         img = Image.open(self.img_path + img_name)
+         img = self.transform(img)
+         return img
+
+     def get_caption(self, img_name):
+         return self.caption_dict[img_name]
+
+     def get_batch(self, batch_size):
+         batch = np.random.choice(self.img_list, batch_size)
+         images = []
+         captions = []
+         for img_name in batch:
+             images.append(self.get_image(img_name))
+             captions.append(self.get_caption(img_name))
+         return images, captions
+
+     def get_all(self):
+         images = []
+         captions = []
+         for img_name in self.img_list:
+             images.append(self.get_image(img_name))
+             captions.append(self.get_caption(img_name))
+         return images, captions
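Note on data/dataLoader.py: captions are read into a dict keyed by filename, so if captions.csv lists several captions per image (as Flickr8K typically does) only the last one per image is kept; each image is resized to 224x224 and converted to a tensor. A minimal usage sketch, assuming the images/ plus captions.csv layout the constructor expects; nothing here is part of the commit itself.

# Illustrative sketch, not part of this commit.
import torch
from data.dataLoader import dataLoader

dl = dataLoader('../Datasets/Flickr8K/')        # folder containing images/ and captions.csv
images, captions = dl.get_batch(batch_size=8)   # lists: 8 image tensors, 8 caption strings
batch = torch.stack(images)                     # stacks only if every image decodes to 3xHxW
print(batch.shape, captions[0])                 # e.g. torch.Size([8, 3, 224, 224])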
main.py ADDED
@@ -0,0 +1,14 @@
+ import numpy as np
+ import os
+ import cv2
+ from PIL import Image
+ from matplotlib import pyplot as plt
+
+ from config.config import Config
+ from data.dataLoader import dataLoader
+
+ if __name__ == '__main__':
+     dl = dataLoader(Config.DATASET_PATH)
+     images, captions = dl.get_all()
+     print('Number of images:', len(images))
+     print('Number of captions:', len(captions))
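Note on main.py: it currently acts as a smoke test, loading every image and caption eagerly through get_all() and printing the counts; matplotlib is imported but not yet used. A hypothetical extension, not in this commit, that visually spot-checks one random sample could look like the following.

# Hypothetical sketch, not part of this commit: visual spot-check of one sample.
from matplotlib import pyplot as plt

from config.config import Config
from data.dataLoader import dataLoader

dl = dataLoader(Config.DATASET_PATH)
images, captions = dl.get_batch(1)
plt.imshow(images[0].permute(1, 2, 0).numpy())  # transform yields C,H,W; imshow expects H,W,C
plt.title(captions[0])
plt.axis('off')
plt.show()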