Spaces:
Runtime error
Runtime error
Update dataset
Browse files
app.py
CHANGED
@@ -8,7 +8,7 @@ from tokenizers import Tokenizer
|
|
8 |
from torch.utils.data import Dataset
|
9 |
import albumentations as A
|
10 |
from tqdm import tqdm
|
11 |
-
|
12 |
from fourm.vq.vqvae import VQVAE
|
13 |
from fourm.models.fm import FM
|
14 |
from fourm.models.generate import (
|
@@ -28,7 +28,7 @@ IMG_SIZE = 224
|
|
28 |
TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
|
29 |
FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
|
30 |
VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
|
31 |
-
IMAGE_DATASET_PATH = "
|
32 |
|
33 |
# Load models
|
34 |
text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
|
@@ -61,25 +61,24 @@ schedule = build_chained_generation_schedules(
|
|
61 |
sampler = GenerationSampler(fm_model)
|
62 |
|
63 |
|
64 |
-
class
|
65 |
-
def __init__(self,
|
66 |
-
self.
|
67 |
-
self.
|
68 |
-
|
69 |
-
|
70 |
|
71 |
def __len__(self):
|
72 |
-
return len(self.
|
73 |
|
74 |
def __getitem__(self, idx):
|
75 |
-
img =
|
76 |
img = np.array(img)
|
77 |
img = self.tfms(image=img)["image"]
|
78 |
return Image.fromarray(img)
|
79 |
|
80 |
-
|
81 |
-
dataset =
|
82 |
-
|
83 |
|
84 |
@torch.no_grad()
|
85 |
def get_image_embeddings(dataset):
|
|
|
8 |
from torch.utils.data import Dataset
|
9 |
import albumentations as A
|
10 |
from tqdm import tqdm
|
11 |
+
from datasets import load_dataset
|
12 |
from fourm.vq.vqvae import VQVAE
|
13 |
from fourm.models.fm import FM
|
14 |
from fourm.models.generate import (
|
|
|
28 |
TOKENIZER_PATH = "./fourm/utils/tokenizer/trained/text_tokenizer_4m_wordpiece_30k.json"
|
29 |
FM_MODEL_PATH = "EPFL-VILAB/4M-21_L"
|
30 |
VQVAE_PATH = "EPFL-VILAB/4M_tokenizers_DINOv2-B14-global_8k_16_224"
|
31 |
+
IMAGE_DATASET_PATH = "./data"
|
32 |
|
33 |
# Load models
|
34 |
text_tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
|
|
|
61 |
sampler = GenerationSampler(fm_model)
|
62 |
|
63 |
|
64 |
+
class HuggingFaceImageDataset(Dataset):
    """Torch-style dataset serving images from a Hugging Face Hub dataset.

    Each item is taken from the wrapped dataset's ``image`` column, resized so
    that its smallest side equals ``img_sz`` (aspect ratio preserved — note the
    output is not cropped, so it may be non-square), and returned as a PIL image.
    """

    def __init__(self, dataset_name, split="train", img_sz=224):
        # Fetch/load the dataset from the Hugging Face Hub.
        self.dataset = load_dataset(dataset_name, split=split)
        # Albumentations pipeline: scale the shorter edge to img_sz only.
        self.tfms = A.Compose([A.SmallestMaxSize(img_sz)])

    def __len__(self):
        # Length is delegated to the wrapped HF dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        # HF image datasets expose decoded PIL images under the "image" key
        # — presumably; verify against the actual dataset schema.
        pil_img = self.dataset[idx]["image"]
        arr = np.array(pil_img)
        resized = self.tfms(image=arr)["image"]
        return Image.fromarray(resized)
|
79 |
|
80 |
# Usage: wrap the demo image dataset hosted on the Hugging Face Hub.
dataset = HuggingFaceImageDataset("aroraaman/4m-21-demo")
|
|
|
82 |
|
83 |
@torch.no_grad()
|
84 |
def get_image_embeddings(dataset):
|