import os

import gradio as gr
import numpy as np
import torch
from PIL import Image as PILIMAGE
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer

# Run on GPU when available; the model weights are moved to the chosen device.
device = "cuda" if torch.cuda.is_available() else "cpu"

# CLIP fine-tuned on the NABirds dataset; the processor and tokenizer come
# from the base OpenAI checkpoint the fine-tune is derived from.
model = CLIPModel.from_pretrained("vesteinn/clip-nabirds").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")
def load_class_names(dataset_path=''):
    """Read class names from classes.txt, mapping class id -> species name."""
    names = {}
    with open(os.path.join(dataset_path, 'classes.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = ' '.join(pieces[1:])
    return names
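# Assumed classes.txt layout (one "<id> <name>" pair per line, as in the
# NABirds metadata files); the ids and names below are illustrative only:
#   295 Black-capped Chickadee
#   296 Mountain Chickadee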
def get_labels():
    """Build one CLIP text prompt per class, e.g. "This is a photo of <name>."."""
    labels = []
    class_names = load_class_names(".")
    for _, name in class_names.items():
        labels.append(f"This is a photo of {name}.")
    return labels
def encode_text(text):
    """Encode a single prompt into a unit-norm CLIP text embedding."""
    with torch.no_grad():
        # Move the tokenized inputs to the same device as the model; the
        # original code would fail here when running on GPU.
        inputs = tokenizer([text], padding=True, return_tensors="pt").to(device)
        text_encoded = model.get_text_features(**inputs)
        # Unit-normalise so the dot product in similarity() is a cosine score,
        # matching the normalisation applied to the image features below.
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    return text_encoded.cpu().numpy()
ALL_LABELS = get_labels()

# Encode every label prompt once and cache the matrix on disk, so the text
# encoder does not have to re-run over all classes on every restart.
try:
    LABEL_FEATURES = np.load("label_features.np")
except FileNotFoundError:
    LABEL_FEATURES = np.vstack([encode_text(label) for label in ALL_LABELS])
    # Save through an explicit file handle: np.save(path, ...) would append
    # ".npy" to the filename and break the np.load() call above.
    with open("label_features.np", "wb") as f:
        np.save(f, LABEL_FEATURES)
def encode_image(image):
    """Encode an HxWx3 uint8 array into a unit-norm CLIP image embedding."""
    image = PILIMAGE.fromarray(image.astype('uint8'), 'RGB')
    with torch.no_grad():
        pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
        image_features = model.get_image_features(pixel_values.to(device))
        # Unit-normalise so the dot product in similarity() is a cosine score.
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy()
def similarity(feature, label_features):
    """Cosine similarity of one image embedding against every label embedding."""
    return list((feature @ label_features.T).squeeze(0))
def find_best_matches(image):
    """Return the label prompt whose embedding best matches the input image."""
    image_features = encode_image(image)
    similarities = similarity(image_features, LABEL_FEATURES)
    best_idx = int(np.argmax(similarities))
    return ALL_LABELS[best_idx]
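# Quick sanity check outside Gradio (illustrative; assumes one of the bundled
# example images is on disk, and the predicted label shown is only a guess):
#   img = np.asarray(PILIMAGE.open("duckly.jpg").convert("RGB"))
#   print(find_best_matches(img))  # e.g. "This is a photo of Wood Duck."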
examples = [["bj.jpg"], ["duckly.jpg"], ["some.jpg"], ["turdus.jpg"],
            ["seag.jpg"], ["thursh.jpg"], ["woodcock.jpeg"], ["dipper.jpeg"]]

# Gradio 2.x-style interface, kept as in the original Space.
gr.Interface(
    fn=find_best_matches,
    inputs=[gr.inputs.Image(label="Image to classify", optional=False)],
    outputs=gr.outputs.Label(),
    examples=examples,
    theme="grass",
    enable_queue=True,
    title="North American Bird Classifier",
    description="This application can classify North American birds.",
).launch()