Spaces:
Runtime error
Runtime error
import torch | |
import numpy as np | |
import pandas as pd | |
import gradio as gr | |
from io import BytesIO | |
from PIL import Image as PILIMAGE | |
#from IPython.display import Image | |
#from IPython.core.display import HTML | |
from transformers import CLIPProcessor, CLIPModel, CLIPTokenizer | |
import os | |
# Run on the GPU when torch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fine-tuned CLIP weights. The processor and tokenizer are loaded from the
# base OpenAI checkpoint instead.
# NOTE(review): presumably the fine-tune kept the base vocabulary and image
# preprocessing config — confirm against the vesteinn/clip-nabirds repo.
_BASE_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained("vesteinn/clip-nabirds").to(device)
processor = CLIPProcessor.from_pretrained(_BASE_CLIP_CHECKPOINT)
tokenizer = CLIPTokenizer.from_pretrained(_BASE_CLIP_CHECKPOINT)
def load_class_names(dataset_path=''):
    """Read the ``classes.txt`` id-to-name mapping found in *dataset_path*.

    Each non-blank line is ``<class_id> <name words...>``. The id is the first
    whitespace-separated token. The name is the remaining tokens re-joined
    with single spaces.

    Args:
        dataset_path: Directory containing ``classes.txt``. The default ``''``
            resolves to a path relative to the current working directory.

    Returns:
        dict mapping each class id string to its name, in file order.

    Raises:
        OSError: If ``classes.txt`` cannot be opened.
    """
    names = {}
    path = os.path.join(dataset_path, 'classes.txt')
    # Explicit encoding, so the result does not depend on the platform locale.
    with open(path, encoding='utf-8') as f:
        for line in f:
            pieces = line.split()
            # Skip blank / whitespace-only lines (e.g. trailing empty lines)
            # instead of crashing with IndexError on the missing id.
            if not pieces:
                continue
            class_id, *name_parts = pieces
            names[class_id] = ' '.join(name_parts)
    return names
def get_labels():
    """Build one CLIP text prompt per class listed in ``./classes.txt``.

    Returns:
        list of ``"This is a photo of <name>."`` strings, in the order the
        classes appear in the file.
    """
    class_names = load_class_names(".")
    return [f"This is a photo of {name}." for name in class_names.values()]
def encode_text(text):
    """Embed a single text prompt with the CLIP text tower.

    The result is L2-normalised along the last axis, the same as
    ``encode_image``. A dot product between the two is then a cosine
    similarity rather than being skewed by per-prompt feature norms.

    Args:
        text: The prompt string to encode.

    Returns:
        numpy array of shape ``(1, projection_dim)`` (batch of one prompt),
        on the host.
    """
    # The model may live on the GPU, but the tokenizer always returns CPU
    # tensors. Move them to the model's device.
    inputs = tokenizer([text], padding=True, return_tensors="pt").to(device)
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    # .cpu() is required before .numpy() when the model runs on CUDA.
    return text_features.cpu().numpy()
# One prompt per class, in classes.txt order. Row i of LABEL_FEATURES
# corresponds to ALL_LABELS[i].
ALL_LABELS = get_labels()

# On-disk cache of the encoded label prompts. The name has no ".npy" suffix.
# That only round-trips because np.save is handed an open file object; given
# a bare path, np.save would append ".npy".
LABEL_FEATURES_CACHE = "label_features.np"


def _load_or_compute_label_features(labels, cache_path=LABEL_FEATURES_CACHE):
    """Return an L2-normalised feature matrix with one row per entry of *labels*.

    The cached matrix at *cache_path* is reused when it loads as a 2-D array
    with exactly ``len(labels)`` rows. Otherwise every prompt is re-encoded
    with ``encode_text`` and the cache is rewritten.
    """
    try:
        features = np.load(cache_path)
    except (OSError, ValueError, EOFError):
        # Missing, unreadable, empty or corrupt cache: rebuild it below.
        features = None
    cache_usable = (
        isinstance(features, np.ndarray)
        and features.ndim == 2
        and features.shape[0] == len(labels)
    )
    if not cache_usable:
        # A stale cache (e.g. classes.txt edited since it was written) would
        # silently pair rows with the wrong labels, so it is rebuilt as well.
        features = np.vstack([encode_text(label) for label in labels])
        with open(cache_path, "wb") as f:
            np.save(f, features)
    # Normalise here as well (a no-op for already-unit rows). A cache written
    # from un-normalised text features still yields cosine similarities
    # against the normalised image features.
    norms = np.linalg.norm(features, axis=-1, keepdims=True)
    return features / norms


LABEL_FEATURES = _load_or_compute_label_features(ALL_LABELS)
def encode_image(image):
    """Embed an image with the CLIP vision tower.

    Args:
        image: Image as a numpy array, as delivered by the Gradio image input
            (presumably ``(H, W)``, ``(H, W, 3)`` or ``(H, W, 4)`` — TODO
            confirm for the Gradio version in use).

    Returns:
        numpy array of shape ``(1, projection_dim)``, L2-normalised along the
        last axis, on the host.
    """
    # Let PIL infer the mode from the array shape (L / RGB / RGBA) and then
    # convert. Forcing mode 'RGB' on a grayscale or RGBA buffer misreads the
    # pixel data or fails with "not enough image data".
    pil_image = PILIMAGE.fromarray(image.astype('uint8')).convert('RGB')
    with torch.no_grad():
        pixel_values = processor(
            text=None, images=pil_image, return_tensors="pt", padding=True
        )["pixel_values"]
        image_features = model.get_image_features(pixel_values.to(device))
        image_features /= image_features.norm(dim=-1, keepdim=True)
    # .cpu() is required before .numpy() when the model runs on CUDA.
    return image_features.cpu().numpy()
def similarity(feature, label_features):
    """Score a single query feature against every row of *label_features*.

    Args:
        feature: ``(1, dim)`` array holding one query vector.
        label_features: ``(num_labels, dim)`` array, one row per label.

    Returns:
        list of ``num_labels`` scalars: the dot product of the query with
        each label row. This is a cosine similarity when both sides are
        L2-normalised.
    """
    scores = feature @ label_features.T
    row = scores.squeeze(0)
    return [score for score in row]
def find_best_matches(image):
    """Classify *image* and return the prompt of the highest-scoring class.

    Args:
        image: Image array as passed in by the Gradio input.

    Returns:
        The winning entry of ``ALL_LABELS`` (``"This is a photo of <name>."``).
        When several labels tie for the top score, the one listed first wins.
    """
    image_features = encode_image(image)
    similarities = similarity(image_features, LABEL_FEATURES)
    # np.argmax is O(n) and returns the first of several equal maxima. This
    # matches the previous stable descending sort without sorting every score.
    best_idx = int(np.argmax(similarities))
    return ALL_LABELS[best_idx]
# Sample images offered in the UI. Each inner list is one example row with a
# value for every input; this interface has a single image input.
_EXAMPLE_FILES = (
    'bj.jpg', 'duckly.jpg', 'some.jpg', 'turdus.jpg',
    'seag.jpg', 'thursh.jpg', 'woodcock.jpeg', 'dipper.jpeg',
)
examples = [[filename] for filename in _EXAMPLE_FILES]

# NOTE(review): gr.inputs / gr.outputs, string themes such as "grass" and
# enable_queue belong to the legacy Gradio API and were likely removed in
# Gradio 4. Pin a compatible gradio version or migrate to gr.Image / gr.Label
# and .queue().
demo = gr.Interface(
    fn=find_best_matches,
    inputs=[gr.inputs.Image(label="Image to classify", optional=False)],
    examples=examples,
    theme="grass",
    outputs=gr.outputs.Label(),
    enable_queue=True,
    title="North American Bird Classifier",
    description="This application can classify North American Birds.",
)
demo.launch()