Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +9 -2
  3. components/query_neighbor.py +75 -0
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  .venv/
2
  __pycache__/
 
 
1
  .venv/
2
  __pycache__/
3
+ .gradio/
app.py CHANGED
@@ -14,6 +14,7 @@ from torchvision import transforms
14
 
15
  from templates import openai_imagenet_template
16
  from components.query import get_sample
 
17
 
18
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
19
  logging.basicConfig(level=logging.INFO, format=log_format)
@@ -90,6 +91,8 @@ zero_shot_examples = [
90
  ],
91
  ]
92
 
 
 
93
 
94
  def indexed(lst, indices):
95
  return [lst[i] for i in indices]
@@ -146,6 +149,10 @@ def open_domain_classification(img, rank: int, return_all=False):
146
  logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
147
  probs = F.softmax(logits, dim=0)
148
 
 
 
 
 
149
  if rank + 1 == len(ranks):
150
  topk = probs.topk(k)
151
  prediction_dict = {
@@ -154,9 +161,9 @@ def open_domain_classification(img, rank: int, return_all=False):
154
  logger.info(f"Top K predictions: {prediction_dict}")
155
  top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
156
  logger.info(f"Top prediction name: {top_prediction_name}")
157
- sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
158
  if return_all:
159
- return prediction_dict, sample_img, taxon_url
160
  return prediction_dict
161
 
162
  output = collections.defaultdict(float)
 
14
 
15
  from templates import openai_imagenet_template
16
  from components.query import get_sample
17
+ from components.query_neighbor import QueryNeighbor
18
 
19
  log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
20
  logging.basicConfig(level=logging.INFO, format=log_format)
 
91
  ],
92
  ]
93
 
94
+ query_neighbor = QueryNeighbor(dataset_name = "BIRD")
95
+
96
 
97
  def indexed(lst, indices):
98
  return [lst[i] for i in indices]
 
149
  logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
150
  probs = F.softmax(logits, dim=0)
151
 
152
+ neighbor = str(query_neighbor.get_nearest_neighbor(img_features))
153
+ neighbor_image = query_neighbor.get_image(neighbor)
154
+ logger.info(f"Nearest neighbor: {neighbor}")
155
+
156
  if rank + 1 == len(ranks):
157
  topk = probs.topk(k)
158
  prediction_dict = {
 
161
  logger.info(f"Top K predictions: {prediction_dict}")
162
  top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
163
  logger.info(f"Top prediction name: {top_prediction_name}")
164
+ _, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
165
  if return_all:
166
+ return prediction_dict, neighbor_image, taxon_url
167
  return prediction_dict
168
 
169
  output = collections.defaultdict(float)
components/query_neighbor.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import chromadb
4
+ import boto3
5
+ import requests
6
+ import logging
7
+
8
+ from PIL import Image
9
+ from huggingface_hub import snapshot_download
10
+ from dataclasses import dataclass
11
+
12
+ log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
13
+ logging.basicConfig(level=logging.INFO, format=log_format)
14
+ logger = logging.getLogger()
15
+
16
+ S3_BUCKET = "tol-bird-dataset-test"
17
+
18
+ @dataclass
19
+ class VectorDataset:
20
+ dataset_name: str
21
+ hf_dataset_path: str
22
+ relative_vector_db_path: str
23
+
24
+ _SUPPORTED_DATASETS = {
25
+ "BIRD": VectorDataset(
26
+ dataset_name="BIRD",
27
+ hf_dataset_path="imageomics/bird-dataset-vector",
28
+ relative_vector_db_path="bird_vector_db"
29
+ ),
30
+ }
31
+
32
+
33
class QueryNeighbor:
    """Query the nearest neighbor for a given image feature vector.

    A chromadb vector database (downloaded from the Hugging Face Hub into
    the local snapshot cache) is used to find the nearest stored embedding;
    the matching image is then fetched from S3 via a boto3 presigned URL.
    """

    def __init__(self, dataset_name: str):
        """Download the vector DB snapshot and open its collection.

        Args:
            dataset_name: Key into ``_SUPPORTED_DATASETS`` (e.g. ``"BIRD"``);
                also the name of the chromadb collection to open.

        Raises:
            ValueError: If ``dataset_name`` is not a supported dataset.
        """
        logger.info("Initializing QueryNeighbor")
        vector_dataset = _SUPPORTED_DATASETS.get(dataset_name)
        if vector_dataset is None:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

        # snapshot_download caches locally, so repeated startups reuse
        # the already-downloaded DB.
        vector_db_path = snapshot_download(
            repo_id=vector_dataset.hf_dataset_path,
            repo_type="dataset"
        )
        logger.info(f"Vector DB cache: {vector_db_path}")
        self._client = chromadb.PersistentClient(
            path=os.path.join(vector_db_path,
                              vector_dataset.relative_vector_db_path))
        self._collection = self._client.get_collection(
            name=dataset_name
        )
        # Credentials/region are resolved from the standard boto3
        # environment configuration.
        self._s3_client = boto3.client("s3")

    def get_nearest_neighbor(self, img_features) -> str:
        """Return the id (S3 object key) of the nearest stored embedding.

        Args:
            img_features: 2-D array-like of image embeddings; only the first
                row is used as the query vector (assumes ``img_features[0]``
                supports ``.tolist()``, e.g. a torch.Tensor or numpy array —
                TODO confirm against caller).

        Returns:
            The chromadb id of the closest match. NOTE: chromadb ids are
            strings, so the annotation is ``str`` (the previous ``int``
            annotation was wrong — the caller already wraps the result in
            ``str(...)``).
        """
        neighbors = self._collection.query(
            query_embeddings=[img_features[0].tolist()],
            n_results=2)
        # query() groups ids per query embedding; take the closest match
        # for the single query issued.
        return neighbors["ids"][0][0]

    def get_image(self, image_key: str):
        """Fetch and decode the image stored under ``image_key`` in S3.

        Args:
            image_key: S3 object key of the image.

        Returns:
            PIL.Image.Image: The decoded image.

        Raises:
            requests.HTTPError: If the presigned-URL download fails.
        """
        img_src = self._s3_client.generate_presigned_url(
            'get_object',
            Params={'Bucket': S3_BUCKET,
                    'Key': image_key}
        )
        # Bound the download so a stalled S3 request cannot hang the app,
        # and surface HTTP failures instead of handing an error body to PIL.
        img_resp = requests.get(img_src, timeout=30)
        img_resp.raise_for_status()
        return Image.open(io.BytesIO(img_resp.content))