import io
import logging
import os
from dataclasses import dataclass

import boto3
import chromadb
import requests
from huggingface_hub import snapshot_download
from PIL import Image

log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger(__name__)

# S3 bucket from which get_image() downloads images.
S3_BUCKET = "tol-bird-dataset-test"

# Default connect/read timeout (seconds) for downloading an image from S3.
_IMAGE_DOWNLOAD_TIMEOUT_S = 30.0


@dataclass
class VectorDataset:
    """Location of a pre-built Chroma vector database on the Hugging Face Hub."""

    # Dataset identifier; also used as the Chroma collection name.
    dataset_name: str
    # Hugging Face Hub repo id of the dataset repo that holds the DB files.
    hf_dataset_path: str
    # Chroma persistence directory, relative to the downloaded snapshot root.
    relative_vector_db_path: str


_SUPPORTED_DATASETS = {
    "BIRD": VectorDataset(
        dataset_name="BIRD",
        hf_dataset_path="imageomics/bird-dataset-vector",
        relative_vector_db_path="bird_vector_db",
    ),
}


class QueryNeighbor:
    """Find the nearest stored embedding for an image and fetch images from S3.

    On construction, the vector database for the requested dataset is
    downloaded (or reused from the local Hugging Face cache) with
    ``snapshot_download`` and opened with ``chromadb.PersistentClient``.
    Images are fetched from ``S3_BUCKET`` through a boto3 presigned URL.
    """

    def __init__(self, dataset_name: str):
        """Open the vector database for ``dataset_name``.

        Args:
            dataset_name: Key into ``_SUPPORTED_DATASETS`` (e.g. ``"BIRD"``).

        Raises:
            ValueError: If ``dataset_name`` is not a supported dataset.
        """
        logger.info("Initializing QueryNeighbor")
        vector_dataset = _SUPPORTED_DATASETS.get(dataset_name)
        if vector_dataset is None:
            raise ValueError(f"Unsupported dataset: {dataset_name}")

        vector_db_path = snapshot_download(
            repo_id=vector_dataset.hf_dataset_path,
            repo_type="dataset",
        )
        logger.info("Vector DB cache: %s", vector_db_path)

        self._client = chromadb.PersistentClient(
            path=os.path.join(vector_db_path, vector_dataset.relative_vector_db_path)
        )
        self._collection = self._client.get_collection(
            name=vector_dataset.dataset_name
        )
        self._s3_client = boto3.client("s3")

    def get_nearest_neighbor(self, img_features) -> str:
        """Return the ID of the single nearest neighbor of ``img_features``.

        Args:
            img_features: Feature array whose first row (``img_features[0]``)
                is used as the query embedding; that row must support
                ``.tolist()``. NOTE(review): presumably a (batch, dim) NumPy
                array or tensor — confirm against callers.

        Returns:
            The Chroma ID of the closest stored embedding.

        Raises:
            IndexError: If the query returns no results (e.g. empty collection).
        """
        # Only the top hit is used, so request exactly one result.
        results = self._collection.query(
            query_embeddings=[img_features[0].tolist()],
            n_results=1,
        )
        # "ids" holds one list of IDs per query embedding; we sent one query.
        return results["ids"][0][0]

    def get_image(
        self, image_key: str, timeout: float = _IMAGE_DOWNLOAD_TIMEOUT_S
    ) -> Image.Image:
        """Download and decode the image stored under ``image_key`` in ``S3_BUCKET``.

        Args:
            image_key: S3 object key of the image.
            timeout: Connect/read timeout in seconds passed to ``requests.get``.

        Returns:
            The decoded PIL image.

        Raises:
            requests.HTTPError: If S3 answers with an error status (e.g. missing
                key or expired URL) instead of image bytes.
            requests.Timeout: If connecting, or a read, stalls longer than
                ``timeout``.
        """
        img_src = self._s3_client.generate_presigned_url(
            "get_object",
            Params={"Bucket": S3_BUCKET, "Key": image_key},
        )
        # Without a timeout, requests can block forever on a stalled connection.
        img_resp = requests.get(img_src, timeout=timeout)
        # Fail loudly on 4xx/5xx rather than handing an XML error body to PIL.
        img_resp.raise_for_status()
        return Image.open(io.BytesIO(img_resp.content))