Spaces:
Sleeping
Sleeping
from datasets import load_dataset | |
import torch | |
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification | |
from loadimg import load_img | |
# Choose the torch device once at import time. Only CUDA is probed; Apple MPS
# is deliberately not checked because this code runs as a hosted Space.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Single CLIP ViT-L/14 checkpoint shared by the processor and the model.
_CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"

# Preprocessor that turns images into the `pixel_values` tensor the model takes.
processor = AutoProcessor.from_pretrained(_CLIP_CHECKPOINT)

# CLIP model; its `get_image_features` produces the image embeddings used below.
model = AutoModelForZeroShotImageClassification.from_pretrained(
    _CLIP_CHECKPOINT,
    device_map=device,
)
class Instance: | |
def __init__(self, dataset, token=None, split="train"):
    """Load one split of a dataset and attach a FAISS index to its embeddings.

    Args:
        dataset: dataset name or path, passed straight to ``load_dataset``.
        token: optional Hugging Face access token. It is forwarded to
            ``load_dataset`` so private or gated datasets can be loaded.
        split: which split to load (default ``"train"``).
    """
    self.dataset = dataset
    self.token = token
    self.split = split
    # Forward the token; storing it without passing it on made private
    # datasets impossible to load. token=None matches the loader's default.
    self.data = load_dataset(self.dataset, split=self.split, token=self.token)
    # Assumes the split already has an "embeddings" column (presumably produced
    # with `embed`). No metric_type is given, so the datasets/FAISS defaults apply.
    self.data = self.data.add_faiss_index("embeddings")
@staticmethod
def embed(batch):
    """Embed a batch of images with CLIP and store them under ``"embeddings"``.

    Meant for precomputing embeddings of images in an external dataset
    (presumably through a batched ``Dataset.map``). Not called by the visible
    methods of this class.

    Declared as a staticmethod because it takes no ``self``. Without the
    decorator, calling it on an instance would bind ``batch`` to ``self``
    and fail. ``Instance.embed(batch)`` keeps working as before.

    Args:
        batch: mapping with an ``"image"`` entry accepted by the CLIP
            processor (presumably a list of images — confirm with caller).

    Returns:
        The same ``batch``, mutated in place with an ``"embeddings"`` entry
        holding the image-feature tensor.
    """
    pixel_values = processor(images=batch["image"], return_tensors="pt")["pixel_values"]
    pixel_values = pixel_values.to(device)
    # Inference only: skip building an autograd graph (saves memory/compute).
    with torch.no_grad():
        img_emb = model.get_image_features(pixel_values)
    batch["embeddings"] = img_emb
    return batch
def search(self, query, k: int = 3):
    """Embed a query image and return its ``k`` nearest neighbours in the index.

    Args:
        query: the image to search for, in any form the CLIP processor
            accepts (e.g. a PIL image). It is an image, not text.
        k: the number of results to return.

    Returns:
        scores: FAISS scores of the retrieved examples. The index is built
            without an explicit metric_type, so these are presumably L2
            distances (lower = closer) rather than cosine similarities.
            Confirm this if the index construction changes.
        retrieved_examples: the retrieved examples, as returned by
            ``Dataset.get_nearest_examples``.
    """
    pixel_values = processor(images=query, return_tensors="pt")["pixel_values"]
    pixel_values = pixel_values.to(device)
    # Inference only: no autograd graph is needed for the query embedding.
    with torch.no_grad():
        # Keep the first image's feature vector as a 1-D query for FAISS.
        img_emb = model.get_image_features(pixel_values)[0]
    # FAISS works on host memory, so move to CPU and convert to numpy.
    img_emb = img_emb.detach().cpu().numpy()
    scores, retrieved_examples = self.data.get_nearest_examples(
        "embeddings", img_emb, k=k
    )
    return scores, retrieved_examples
def high_level_search(self, img): | |
""" | |
High level wrapper for the search function. | |
Args: | |
img: input image (path, url, pillow or numpy) | |
Returns: | |
scores: the scores of the retrieved examples (cosine similarity i think in this case) | |
retrieved_examples: the retrieved examples | |
""" | |
image = load_img(img) | |
scores, retrieved_examples = self.search(image) |