File size: 1,123 Bytes
237774d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from datasets import load_dataset


class ImagenDataset(Dataset):
    """Map-style dataset that yields ``(transformed_image, label_code)`` pairs.

    Each record of the wrapped collection is indexed with the keys ``"image"``
    and ``"etiqueta"``. The image is converted to RGB and passed through the
    transform; the label string is lower-cased and mapped to a code through
    ``codigo_etiquetas``.
    """

    def __init__(self, dt, transform, codigo_etiquetas):
        """Store the wrapped records, the image transform and the label map.

        Args:
            dt: Indexable, sized collection of records (in this file, a
                Hugging Face dataset split).
            transform: Callable applied to each RGB image.
            codigo_etiquetas: Mapping from lower-cased label string to the
                value returned as the label.
        """
        self.dt, self.tr, self.codigo = dt, transform, codigo_etiquetas

    def __len__(self):
        """Return the number of records in the wrapped collection."""
        return len(self.dt)

    def __getitem__(self, idx):
        """Return the transformed image and the encoded label of record ``idx``.

        Raises:
            KeyError: if the lower-cased label is not a key of
                ``codigo_etiquetas``.
        """
        registro = self.dt[idx]
        imagen_rgb = registro["image"].convert("RGB")
        # The label is lower-cased before lookup, so the keys of the label map
        # are expected to be lower-case.
        codigo = self.codigo[registro["etiqueta"].lower()]
        return self.tr(imagen_rgb), codigo


def cargar_dataset(
    codigo_etiquetas, *, batch_size=500, num_workers=None, image_size=(256, 256)
):
    """Build the test DataLoader over ``minoruskore/elementosparaevaluarclases``.

    The ``"train"`` split is loaded with ``load_dataset``, authenticating with
    the token in the ``HFKEY`` environment variable (``None`` if the variable
    is unset). Every image is resized to ``image_size`` and converted with
    ``transforms.ToTensor()``; labels are encoded through ``codigo_etiquetas``
    by ``ImagenDataset``.

    Args:
        codigo_etiquetas: Mapping from lower-cased label string to label code.
        batch_size: Samples per batch. Defaults to 500.
        num_workers: Number of DataLoader worker processes. ``None`` (the
            default) uses ``os.cpu_count()``, falling back to 0 (loading in
            the main process) when the CPU count cannot be determined.
        image_size: ``(height, width)`` passed to ``transforms.Resize``.
            Defaults to ``(256, 256)``.

    Returns:
        A ``DataLoader`` that iterates the dataset in order (no shuffling).
    """
    key = os.environ.get("HFKEY")
    dataset = load_dataset(
        "minoruskore/elementosparaevaluarclases", split="train", token=key
    )

    tr = transforms.Compose(
        [transforms.Resize(list(image_size)), transforms.ToTensor()]
    )
    test_dataset = ImagenDataset(
        dataset, transform=tr, codigo_etiquetas=codigo_etiquetas
    )
    if num_workers is None:
        # os.cpu_count() returns None when the count is undetermined, and
        # DataLoader requires an int, so fall back to in-process loading.
        num_workers = os.cpu_count() or 0
    test_dataloader = DataLoader(
        test_dataset, batch_size=batch_size, num_workers=num_workers
    )

    return test_dataloader