import os

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from datasets import load_dataset


class ImagenDataset(Dataset):
    """Map-style dataset yielding ``(transformed_image, encoded_label)`` pairs.

    Wraps an indexable collection of rows (in this file, a Hugging Face
    ``datasets`` split). Each row must provide:

    * ``row["image"]``: an object with a ``convert("RGB")`` method
      (presumably a PIL image -- TODO confirm against the dataset schema);
    * ``row["etiqueta"]``: a string label, matched case-insensitively.

    Args:
        dt: Indexable, sized collection of rows.
        transform: Callable applied to the RGB image before it is returned.
        codigo_etiquetas: Mapping from lowercased label string to the value
            returned as the label (presumably an integer class id -- verify
            against callers).
    """

    def __init__(self, dt, transform, codigo_etiquetas):
        self.dt = dt
        self.tr = transform
        self.codigo = codigo_etiquetas

    def __len__(self) -> int:
        """Return the number of rows in the wrapped collection."""
        return len(self.dt)

    def __getitem__(self, idx):
        """Return the transformed RGB image and encoded label for row ``idx``.

        Raises:
            KeyError: if the row's lowercased label is not a key of
                ``codigo_etiquetas``.
        """
        row = self.dt[idx]
        # Force 3 channels so grayscale/RGBA images yield uniform tensors.
        imagen = row["image"].convert("RGB")
        label = row["etiqueta"].lower()
        try:
            label = self.codigo[label]
        except KeyError as err:
            # Same exception type as before, but with enough context to find
            # the offending sample.
            raise KeyError(
                f"unknown label {label!r} at index {idx}; "
                f"not present in codigo_etiquetas"
            ) from err
        imagen = self.tr(imagen)
        return imagen, label


def cargar_dataset(codigo_etiquetas, *, batch_size=500, num_workers=None):
    """Load the evaluation split and wrap it in a ``DataLoader``.

    Downloads the ``"train"`` split of
    ``minoruskore/elementosparaevaluarclases`` from the Hugging Face Hub.
    The access token is read from the ``HFKEY`` environment variable; if
    it is unset, ``None`` is passed and no explicit token is supplied.
    Images are resized to 256x256 and converted to tensors.

    Args:
        codigo_etiquetas: Mapping from lowercased label string to encoded
            label, forwarded to ``ImagenDataset``.
        batch_size: Samples per batch (default 500, as before).
        num_workers: Worker processes for the ``DataLoader``. ``None``
            (default) uses every CPU reported by ``os.cpu_count()``.

    Returns:
        A ``DataLoader`` over ``ImagenDataset``, with the default
        (non-shuffled) sampler.
    """
    key = os.environ.get("HFKEY")
    dataset = load_dataset(
        "minoruskore/elementosparaevaluarclases", split="train", token=key
    )
    tr = transforms.Compose([transforms.Resize([256, 256]), transforms.ToTensor()])
    test_dataset = ImagenDataset(
        dataset, transform=tr, codigo_etiquetas=codigo_etiquetas
    )
    if num_workers is None:
        # os.cpu_count() may return None when the count is undeterminable;
        # DataLoader rejects None, so fall back to main-process loading (0).
        num_workers = os.cpu_count() or 0
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, num_workers=num_workers
    )
    return test_dataloader