bioclip-demo / components /query_neighbor.py
smenon8's picture
Add ability to fetch a dataset of vectordb from hf hub
50f1a2f
raw
history blame
2.72 kB
import io
import os
import chromadb
import boto3
import requests
import logging
from PIL import Image
from huggingface_hub import snapshot_download
from dataclasses import dataclass
# Log line format: timestamp, level, logger name, message.
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
# NOTE(review): this configures the *root* logger at import time, which affects
# the whole application (and is a no-op if the root logger already has
# handlers). Kept for backward compatibility with existing behaviour.
logging.basicConfig(level=logging.INFO, format=log_format)
# Module-scoped logger: records are attributed to this module instead of
# "root", and still propagate to the root handlers configured above.
logger = logging.getLogger(__name__)
# S3 bucket used as the presigned-URL bucket when fetching images.
S3_BUCKET = "tol-bird-dataset-test"
@dataclass(frozen=True)
class VectorDataset:
    """Location of a prebuilt ChromaDB vector database on the Hugging Face Hub.

    Instances are immutable (``frozen=True``) so the entries held in the
    shared module-level registry cannot be modified accidentally at runtime.
    """

    # Short identifier for the dataset (e.g. "BIRD").
    dataset_name: str
    # Hugging Face Hub dataset repo id that holds the vector DB snapshot.
    hf_dataset_path: str
    # ChromaDB directory, relative to the root of the downloaded snapshot.
    relative_vector_db_path: str
# Registry of vector datasets that can be loaded, keyed by dataset name.
# Each key is taken from the entry's own ``dataset_name`` so the two cannot
# drift apart when new datasets are added.
_SUPPORTED_DATASETS = {
    entry.dataset_name: entry
    for entry in (
        VectorDataset(
            dataset_name="BIRD",
            hf_dataset_path="imageomics/bird-dataset-vector",
            relative_vector_db_path="bird_vector_db",
        ),
    )
}
class QueryNeighbor:
    """Find the nearest stored image for a query feature vector.

    On construction, the vector database registered for ``dataset_name`` is
    fetched from the Hugging Face Hub with ``snapshot_download`` (stored in
    the local cache) and opened with a ``chromadb.PersistentClient``.
    Nearest-neighbor queries run against the ChromaDB collection named after
    the dataset. Images are downloaded from the ``S3_BUCKET`` bucket through a
    presigned URL generated by a boto3 S3 client.
    """

    def __init__(self, dataset_name: str):
        """Open the vector database registered for ``dataset_name``.

        Args:
            dataset_name: Key into ``_SUPPORTED_DATASETS`` (e.g. ``"BIRD"``);
                also used as the ChromaDB collection name.

        Raises:
            ValueError: If ``dataset_name`` is not a supported dataset.
        """
        logger.info("Initializing QueryNeighbor")
        vector_dataset = _SUPPORTED_DATASETS.get(dataset_name)
        if vector_dataset is None:
            raise ValueError(f"Unsupported dataset: {dataset_name}")
        # Returns the local directory of the downloaded dataset snapshot.
        vector_db_path = snapshot_download(
            repo_id=vector_dataset.hf_dataset_path,
            repo_type="dataset",
        )
        logger.info("Vector DB cache: %s", vector_db_path)
        self._client = chromadb.PersistentClient(
            path=os.path.join(vector_db_path, vector_dataset.relative_vector_db_path)
        )
        self._collection = self._client.get_collection(name=dataset_name)
        self._s3_client = boto3.client("s3")

    def get_nearest_neighbor(self, img_features) -> str:
        """Return the id of the stored vector closest to the query.

        Args:
            img_features: Indexable batch of feature vectors; only the first
                row (``img_features[0]``) is used as the query, and it must
                support ``.tolist()`` (e.g. a NumPy array or torch tensor).

        Returns:
            The ChromaDB id (a string) of the nearest neighbor.
        """
        query_vector = img_features[0].tolist()
        # NOTE(review): only the top hit is used below; n_results=2 is kept
        # as-is but could likely be 1 — confirm with the original author.
        neighbors = self._collection.query(
            query_embeddings=[query_vector],
            n_results=2,
        )
        # ``ids`` holds one result list per query embedding; take the best
        # match for our single query.
        return neighbors["ids"][0][0]

    def get_image(self, image_key: str, *, timeout: float = 30.0):
        """Download an image from ``S3_BUCKET`` and open it with PIL.

        Args:
            image_key: S3 object key of the image.
            timeout: Seconds ``requests`` may wait while connecting or between
                received bytes before giving up.

        Returns:
            A ``PIL.Image.Image`` opened from the downloaded bytes.

        Raises:
            requests.HTTPError: If S3 answers with an error status (e.g. a
                missing key or an invalid/expired presigned URL).
            requests.Timeout: If connecting or reading stalls longer than
                ``timeout`` seconds.
        """
        img_src = self._s3_client.generate_presigned_url(
            "get_object",
            Params={"Bucket": S3_BUCKET, "Key": image_key},
        )
        # Always bound the request: without a timeout, requests can block forever.
        img_resp = requests.get(img_src, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an XML error body to PIL.
        img_resp.raise_for_status()
        return Image.open(io.BytesIO(img_resp.content))