Spaces:
Runtime error
Runtime error
Update `load_image_embeddings`
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ from tokenizers import Tokenizer
|
|
8 |
from torch.utils.data import Dataset
|
9 |
import albumentations as A
|
10 |
from tqdm import tqdm
|
|
|
11 |
from datasets import load_dataset
|
12 |
from fourm.vq.vqvae import VQVAE
|
13 |
from fourm.models.fm import FM
|
@@ -80,14 +81,17 @@ class HuggingFaceImageDataset(Dataset):
|
|
80 |
# Usage
|
81 |
dataset = HuggingFaceImageDataset("aroraaman/4m-21-demo")
|
82 |
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
91 |
print(image_embeddings.shape)
|
92 |
|
93 |
def get_similar_images(caption, brightness, num_items):
|
|
|
8 |
from torch.utils.data import Dataset
|
9 |
import albumentations as A
|
10 |
from tqdm import tqdm
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
from datasets import load_dataset
|
13 |
from fourm.vq.vqvae import VQVAE
|
14 |
from fourm.models.fm import FM
|
|
|
81 |
# Usage
|
82 |
dataset = HuggingFaceImageDataset("aroraaman/4m-21-demo")
|
83 |
|
84 |
+
def load_image_embeddings():
|
85 |
+
# Download the file
|
86 |
+
file_path = hf_hub_download(repo_id="aroraaman/img-tensor", filename="image_emb.pt")
|
87 |
+
|
88 |
+
# Load the tensor
|
89 |
+
image_embeddings = torch.load(file_path)
|
90 |
+
return image_embeddings
|
91 |
+
|
92 |
+
# Use the embeddings in your app
|
93 |
+
image_embeddings = load_image_embeddings()
|
94 |
+
image_embeddings.shape
|
95 |
print(image_embeddings.shape)
|
96 |
|
97 |
def get_similar_images(caption, brightness, num_items):
|