image-retriever / utils /dataset_rag.py
npbm's picture
i cant use git for the life of me. might need more testing
7dc7c5c verified
raw
history blame
2.6 kB
from datasets import load_dataset
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from loadimg import load_img
# Pick the fastest available backend for CLIP inference: CUDA first, then
# Apple-silicon MPS, otherwise the CPU. The getattr guard keeps this working on
# torch builds that predate the `torch.backends.mps` module.
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Single source of truth for the checkpoint so the processor and the model can
# never drift apart.
CLIP_MODEL_ID = "openai/clip-vit-large-patch14"

# Shared CLIP preprocessing pipeline and model, loaded once at import time and
# reused by every `Instance`; `device_map` places the model on `device`.
processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
model = AutoModelForZeroShotImageClassification.from_pretrained(CLIP_MODEL_ID, device_map=device)
class Instance:
def __init__(self, dataset, token=None, split="train"):
    """Load one split of a dataset and index its "embeddings" column with FAISS.

    Args:
        dataset: dataset name/path, passed as the first argument of
            `datasets.load_dataset`.
        token: optional Hugging Face access token, forwarded to
            `load_dataset` so private or gated datasets can be loaded.
            `None` keeps the library's default authentication behaviour.
        split: which split to load (default "train").
    """
    self.dataset = dataset
    self.token = token
    self.split = split
    # Forward the token; without it, private/gated datasets cannot be fetched.
    self.data = load_dataset(self.dataset, split=self.split, token=self.token)
    # The dataset must already contain an "embeddings" column whose vectors
    # match the size of the CLIP image features that `search` queries with.
    # `add_faiss_index` builds the index under that column name.
    self.data = self.data.add_faiss_index("embeddings")
@staticmethod
def embed(batch):
    """Compute CLIP image embeddings for a batch of images (currently unused).

    Meant for pre-computing the "embeddings" column of an external dataset,
    presumably via `Dataset.map(Instance.embed, batched=True)`. It is a
    staticmethod so it can be called both as `Instance.embed(batch)` and as
    `inst.embed(batch)`; without the decorator the latter would bind `batch`
    to the instance and fail.

    Args:
        batch: mapping with an "image" entry holding the images to embed.

    Returns:
        The same mapping with an "embeddings" entry added: the tensor returned
        by `model.get_image_features`, on `device`.
    """
    # Inference only: disable autograd so no computation graph is retained.
    with torch.no_grad():
        pixel_values = processor(images=batch["image"], return_tensors="pt")["pixel_values"]
        batch["embeddings"] = model.get_image_features(pixel_values.to(device))
    return batch
def search(self, query, k: int = 3):
    """Embed a query image with CLIP and return its `k` nearest dataset rows.

    Args:
        query: the image to search for, in any form the CLIP processor accepts
            (e.g. a PIL image). If several images are passed, only the first
            one's embedding is used for the search.
        k: the number of results to return.

    Returns:
        scores: the FAISS scores of the retrieved rows. NOTE(review): the index
            is created by `add_faiss_index` with default settings, which per
            the `datasets`/FAISS docs is a flat L2 index. These are then
            (squared) L2 distances, where lower means more similar, rather
            than cosine similarities. Confirm before relying on the ordering.
        retrieved_examples: the retrieved rows, as returned by
            `Dataset.get_nearest_examples`.
    """
    # Inference only: disable autograd so no computation graph is kept around.
    with torch.no_grad():
        pixel_values = processor(images=query, return_tensors="pt")["pixel_values"]
        # Keep only the first image's embedding (a single query vector).
        img_emb = model.get_image_features(pixel_values.to(device))[0]
    # FAISS needs a CPU numpy vector.
    img_emb = img_emb.cpu().numpy()
    scores, retrieved_examples = self.data.get_nearest_examples("embeddings", img_emb, k=k)
    return scores, retrieved_examples
def high_level_search(self, img):
"""
High level wrapper for the search function.
Args:
img: input image (path, url, pillow or numpy)
Returns:
scores: the scores of the retrieved examples (cosine similarity i think in this case)
retrieved_examples: the retrieved examples
"""
image = load_img(img)
scores, retrieved_examples = self.search(image)