Spaces:
Running
Running
import io | |
import os | |
import chromadb | |
import boto3 | |
import requests | |
import logging | |
from PIL import Image | |
from huggingface_hub import snapshot_download | |
from dataclasses import dataclass | |
# Log record layout used for every message emitted by this process.
log_format = "[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s"
# NOTE(review): basicConfig at import time configures the root logger for the
# whole process. It is kept so existing log output is unchanged, but an
# application entry point is the usual place for this.
logging.basicConfig(level=logging.INFO, format=log_format)
# Module-scoped logger rather than the root logger, so %(name)s identifies this
# module and its records can be filtered independently. Records still
# propagate to the root handlers configured above.
logger = logging.getLogger(__name__)
# S3 bucket from which neighbor images are fetched by object key.
S3_BUCKET = "tol-bird-dataset-test"
@dataclass(frozen=True)
class VectorDataset:
    """Where to find a prebuilt Chroma vector database on Hugging Face Hub.

    Instances are immutable because they are shared through a module-level
    registry of supported datasets.

    Attributes:
        dataset_name: Identifier of the dataset. Expected to match its key
            in the dataset registry.
        hf_dataset_path: Hugging Face Hub dataset repo id holding the
            database files, e.g. ``"imageomics/bird-dataset-vector"``.
        relative_vector_db_path: Directory of the Chroma database, relative
            to the root of the downloaded dataset snapshot.
    """

    dataset_name: str
    hf_dataset_path: str
    relative_vector_db_path: str
# Datasets that QueryNeighbor can load, keyed by dataset name. Each key is
# taken from the entry's own ``dataset_name`` so the two cannot drift apart.
_SUPPORTED_DATASETS = {
    entry.dataset_name: entry
    for entry in (
        VectorDataset(
            dataset_name="BIRD",
            hf_dataset_path="imageomics/bird-dataset-vector",
            relative_vector_db_path="bird_vector_db",
        ),
    )
}
class QueryNeighbor:
    """
    Find the stored image closest to a query feature vector.

    On construction, the vector database for the requested dataset is
    downloaded from Hugging Face Hub with ``snapshot_download``, or reused from
    its local cache, and opened with ``chromadb.PersistentClient``. Queries run
    against the Chroma collection named after the dataset. Matching images are
    fetched from S3 (bucket ``S3_BUCKET``) through a presigned URL.
    """

    def __init__(self, dataset_name: str):
        """
        Open the vector database for ``dataset_name`` and create an S3 client.

        Args:
            dataset_name: Key into ``_SUPPORTED_DATASETS`` (e.g. ``"BIRD"``).
                Also used as the name of the Chroma collection to open.

        Raises:
            ValueError: If ``dataset_name`` is not a supported dataset.
        """
        logger.info("Initializing QueryNeighbor")
        vector_dataset = _SUPPORTED_DATASETS.get(dataset_name)
        if vector_dataset is None:
            raise ValueError(f"Unsupported dataset: {dataset_name}")
        # Returns the local snapshot directory, downloading only what is not
        # already in the Hugging Face cache.
        vector_db_path = snapshot_download(
            repo_id=vector_dataset.hf_dataset_path,
            repo_type="dataset",
        )
        logger.info("Vector DB cache: %s", vector_db_path)
        self._client = chromadb.PersistentClient(
            path=os.path.join(vector_db_path, vector_dataset.relative_vector_db_path)
        )
        self._collection = self._client.get_collection(name=dataset_name)
        self._s3_client = boto3.client("s3")

    def get_nearest_neighbor(self, img_features) -> str:
        """
        Return the id of the stored vector closest to the query features.

        Args:
            img_features: Batch of feature vectors. Only the first row
                (``img_features[0]``) is queried, and that row must provide
                ``tolist()``. Presumably a NumPy array or torch tensor;
                confirm against callers.

        Returns:
            The Chroma id (a string) of the closest match. Presumably this is
            the S3 object key that ``get_image`` expects; confirm against how
            the collection was built.

        Raises:
            IndexError: If the query returns no neighbors, e.g. the collection
                is empty.
        """
        neighbors = self._collection.query(
            query_embeddings=[img_features[0].tolist()],
            # NOTE(review): two results are requested but only the closest one
            # is returned. Confirm whether the second was meant as a fallback.
            n_results=2,
        )
        # Chroma returns one list of ids per query embedding; one was sent.
        ids = neighbors["ids"][0]
        if not ids:
            raise IndexError("Vector DB query returned no neighbors")
        return ids[0]

    def get_image(self, image_key: str, timeout: float = 30.0):
        """
        Download the image stored under ``image_key`` in ``S3_BUCKET``.

        Args:
            image_key: S3 object key of the image.
            timeout: Timeout in seconds passed to ``requests.get``. It applies
                to establishing the connection and to each socket read.

        Returns:
            A ``PIL.Image.Image`` opened from the downloaded bytes.

        Raises:
            requests.HTTPError: If S3 responds with an error status, e.g. a
                missing key or a rejected presigned URL.
            requests.Timeout: If the server does not respond within
                ``timeout``.
        """
        img_src = self._s3_client.generate_presigned_url(
            "get_object",
            Params={"Bucket": S3_BUCKET, "Key": image_key},
        )
        img_resp = requests.get(img_src, timeout=timeout)
        # Fail on an HTTP error. Otherwise S3's XML error body reaches PIL and
        # surfaces as a misleading "cannot identify image file" error.
        img_resp.raise_for_status()
        return Image.open(io.BytesIO(img_resp.content))