aroraaman commited on
Commit
fc9fb44
1 Parent(s): 31b8277

Update `load_image_embeddings`

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -8,6 +8,7 @@ from tokenizers import Tokenizer
8
  from torch.utils.data import Dataset
9
  import albumentations as A
10
  from tqdm import tqdm
 
11
  from datasets import load_dataset
12
  from fourm.vq.vqvae import VQVAE
13
  from fourm.models.fm import FM
@@ -80,14 +81,17 @@ class HuggingFaceImageDataset(Dataset):
80
  # Usage
81
  dataset = HuggingFaceImageDataset("aroraaman/4m-21-demo")
82
 
83
- @torch.no_grad()
84
- def get_image_embeddings(dataset):
85
- cache_file = "image_emb.pt"
86
- if os.path.exists(cache_file):
87
- return torch.load(cache_file)
88
-
89
-
90
- image_embeddings = get_image_embeddings(dataset).to(DEVICE)
 
 
 
91
  print(image_embeddings.shape)
92
 
93
  def get_similar_images(caption, brightness, num_items):
 
8
  from torch.utils.data import Dataset
9
  import albumentations as A
10
  from tqdm import tqdm
11
+ from huggingface_hub import hf_hub_download
12
  from datasets import load_dataset
13
  from fourm.vq.vqvae import VQVAE
14
  from fourm.models.fm import FM
 
81
  # Usage
82
  dataset = HuggingFaceImageDataset("aroraaman/4m-21-demo")
83
 
84
+ def load_image_embeddings():
85
+ # Download the file
86
+ file_path = hf_hub_download(repo_id="aroraaman/img-tensor", filename="image_emb.pt")
87
+
88
+ # Load the tensor
89
+ image_embeddings = torch.load(file_path)
90
+ return image_embeddings
91
+
92
+ # Use the embeddings in your app
93
+ image_embeddings = load_image_embeddings()
94
+ image_embeddings.shape
95
  print(image_embeddings.shape)
96
 
97
  def get_similar_images(caption, brightness, num_items):