Spaces:
Sleeping
Sleeping
import os | |
import pandas as pd | |
from PIL import Image | |
from torchvision import transforms | |
from torch.utils.data import Dataset | |
class ImageCaptionDataset(Dataset):
    """PyTorch dataset of image-caption pairs listed in a CSV file.

    Rows are read positionally: the first CSV column holds the image file
    name (or relative path) and the second column holds the caption.

    Attributes:
        df (pandas.DataFrame): Table loaded from ``caption_file``.
        image_path (str): Directory the image names are resolved against.
        transform (callable): Transformation applied to every loaded image.
    """

    def __init__(self, caption_file: str, file_path: str, transform=None):
        """Load the caption table and configure the image pipeline.

        Args:
            caption_file (str): Path to the CSV file; each row has an image
                name in its first column and a caption in its second.
            file_path (str): Directory containing the images. A trailing
                path separator is optional.
            transform (callable, optional): Transformation applied to each
                image. When omitted (or falsy), images are resized to
                224x224 and converted to tensors.
        """
        self.df = pd.read_csv(caption_file)
        self.image_path = file_path
        self.transform = transform or transforms.Compose([
            transforms.Resize((224, 224)),  # Input size expected by ViT
            # For standard 8-bit images ToTensor also rescales pixel values
            # to [0.0, 1.0]. No mean/std normalization is applied here; pass
            # a custom ``transform`` if the model requires it.
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the total number of samples (CSV rows) in the dataset."""
        return len(self.df)

    def _resolve_image_path(self, image_name):
        """Combine ``self.image_path`` and ``image_name`` into one file path."""
        root = os.fspath(self.image_path)
        image_name = os.fspath(image_name)
        separators = ("/", os.sep)
        # When the seam already has a separator (or there is no root at all),
        # plain concatenation is what callers have always relied on.
        if (
            not root
            or root.endswith(separators)
            or image_name.startswith(separators)
        ):
            return root + image_name
        # Otherwise a bare "+" would glue the directory and file name together
        # ("imgs" + "a.jpg" -> "imgsa.jpg"), so insert the missing separator.
        return os.path.join(root, image_name)

    def __getitem__(self, idx):
        """Retrieve an image and its corresponding caption by index.

        Args:
            idx (int): Positional index of the row to load.

        Returns:
            tuple: ``(image, caption)`` where ``image`` is the output of
            ``self.transform`` applied to the RGB image and ``caption`` is
            the value stored in the second CSV column.
        """
        image_name = self.df.iloc[idx, 0]  # First column: image path
        caption = self.df.iloc[idx, 1]  # Second column: caption

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed as soon as the block exits instead of lingering until GC.
        with Image.open(self._resolve_image_path(image_name)) as source:
            image = source.convert('RGB')

        # Guard kept so that clearing ``transform`` after construction yields
        # the raw PIL image.
        if self.transform:
            image = self.transform(image)
        return image, caption