# ImageCaptionGenerator / data /dataLoader.py
# VenkateshRoshan
# Training script is now working
# 3138612
import os
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
class ImageCaptionDataset(Dataset):
    """
    PyTorch ``Dataset`` of (image, caption) pairs listed in a CSV file.

    Each CSV row supplies an image path in its first column and the matching
    caption in its second column. Image paths are resolved against
    ``file_path``.

    Attributes:
        df (pandas.DataFrame): Contents of the caption CSV file.
        image_path (str): Base directory (or path prefix) that the image
            paths from the CSV are resolved against.
        transform (callable): Transformation applied to every loaded image.
    """

    def __init__(self, caption_file: str, file_path: str, transform=None) -> None:
        """
        Initialize the dataset from a caption CSV file and an image root.

        Args:
            caption_file (str): Path to the CSV file where each row has an
                image path (first column) and a caption (second column).
            file_path (str): Base directory (or path prefix) prepended to
                every image path read from the CSV.
            transform (callable, optional): Transform applied to each loaded
                image. When ``None``, the image is resized to 224x224 and
                converted to a tensor.
        """
        self.df = pd.read_csv(caption_file)
        self.image_path = file_path
        # Compare against None explicitly so that a caller-supplied transform
        # that happens to be falsy (e.g. an empty container of transforms) is
        # not silently replaced by the default pipeline.
        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),  # Resize to 224x224 for ViT
                transforms.ToTensor(),  # PIL image -> tensor
                # NOTE: no mean/std normalization is applied by default.
            ])
        self.transform = transform

    def __len__(self) -> int:
        """
        Return the total number of samples (CSV rows) in the dataset.
        """
        return len(self.df)

    def _resolve_image_path(self, relative_path):
        """
        Build the on-disk location of an image listed in the CSV.

        A path separator is inserted between the base directory and
        ``relative_path`` only when neither side already supplies one, so a
        base given without a trailing slash no longer produces a fused path
        such as ``imgsa.jpg``. In every other case the two parts are
        concatenated unchanged.
        """
        base = os.fspath(self.image_path)
        separators = ("/", os.sep)
        if not base or base.endswith(separators) or relative_path.startswith(separators):
            return base + relative_path
        return os.path.join(base, relative_path)

    def __getitem__(self, idx):
        """
        Retrieve an image and its corresponding caption by index.

        Args:
            idx (int): Positional index of the CSV row.

        Returns:
            tuple: ``(image, caption)`` where ``image`` is the RGB image after
            ``self.transform`` has been applied and ``caption`` is the value
            from the second CSV column.
        """
        img_path = self.df.iloc[idx, 0]  # first column: image path
        caption = self.df.iloc[idx, 1]  # second column: caption

        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(self._resolve_image_path(img_path)) as raw_image:
            image = raw_image.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)
        return image, caption