aroraaman committed on
Commit
31b8277
·
1 Parent(s): 3424266

Update dataset

Browse files
Files changed (1) hide show
  1. app.py +12 -13
app.py CHANGED
@@ -8,7 +8,7 @@ from tokenizers import Tokenizer
8
  from torch.utils.data import Dataset
9
  import albumentations as A
10
  from tqdm import tqdm
11
-
12
  from fourm.vq.vqvae import VQVAE
13
  from fourm.models.fm import FM
14
  from fourm.models.generate import (
@@ -28,7 +28,7 @@ IMG_SIZE = 224
28
  TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
29
  FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
30
  VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
31
- IMAGE_DATASET_PATH = "/home/ubuntu/GIT_REPOS/ml-4m/data/custom_data/"
32
 
33
  # Load models
34
  text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
@@ -61,25 +61,24 @@ schedule = build_chained_generation_schedules(
61
  sampler = GenerationSampler(fm_model)
62
 
63
 
64
- class ImageDataset(Dataset):
65
- def __init__(self, path: str, img_sz=IMG_SIZE):
66
- self.path = Path(path)
67
- self.files = list(self.path.rglob("*"))
68
- self.tfms = A.Compose(
69
- [A.SmallestMaxSize(img_sz)])
70
 
71
  def __len__(self):
72
- return len(self.files)
73
 
74
  def __getitem__(self, idx):
75
- img = Image.open(self.files[idx]).convert("RGB")
76
  img = np.array(img)
77
  img = self.tfms(image=img)["image"]
78
  return Image.fromarray(img)
79
 
80
-
81
- dataset = ImageDataset(IMAGE_DATASET_PATH)
82
-
83
 
84
  @torch.no_grad()
85
  def get_image_embeddings(dataset):
 
8
  from torch.utils.data import Dataset
9
  import albumentations as A
10
  from tqdm import tqdm
11
+ from datasets import load_dataset
12
  from fourm.vq.vqvae import VQVAE
13
  from fourm.models.fm import FM
14
  from fourm.models.generate import (
 
28
  TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
29
  FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
30
  VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
31
+ IMAGE_DATASET_PATH = "./data"
32
 
33
  # Load models
34
  text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
 
61
  sampler = GenerationSampler(fm_model)
62
 
63
 
64
class HuggingFaceImageDataset(Dataset):
    """Torch ``Dataset`` wrapping a Hugging Face image dataset.

    Loads the named dataset split with ``datasets.load_dataset`` and, per
    item, resizes the image so its smallest side equals ``img_sz`` before
    returning it as a PIL image.

    Args:
        dataset_name: Hub dataset identifier (e.g. ``"user/name"``).
        split: Dataset split to load. Defaults to ``"train"``.
        img_sz: Target size of the image's smallest side. Defaults to 224.
    """

    def __init__(self, dataset_name, split="train", img_sz=224):
        # NOTE(review): assumes each record exposes an 'image' column holding
        # a PIL-compatible image — confirm against the dataset's schema.
        self.dataset = load_dataset(dataset_name, split=split)
        self.tfms = A.Compose([
            A.SmallestMaxSize(img_sz),
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # Force 3-channel RGB: the previous local-file implementation called
        # .convert("RGB"), and the rewrite dropped it. Hub datasets may hold
        # grayscale or RGBA images, which would produce 1-/4-channel arrays
        # and break downstream code expecting 3 channels.
        img = self.dataset[idx]['image'].convert("RGB")
        img = np.array(img)
        img = self.tfms(image=img)["image"]
        return Image.fromarray(img)
79
 
80
+ # Usage
81
+ dataset = HuggingFaceImageDataset("aroraaman/4m-21-demo")
 
82
 
83
  @torch.no_grad()
84
  def get_image_embeddings(dataset):