import gradio as gr
from datasets import load_dataset
import numpy as np
import torch
import torch.nn as nn
from torchvision.transforms import Compose, ColorJitter, ToTensor, RandomPerspective
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
with open("dogs_labels.txt", "r") as f: | |
labels = f.read().split('\n') | |
num_labels = len(labels) | |
id2label = {str(i): c for i, c in enumerate(labels)} | |
def classify_image(inp):
    # Load the fine-tuned ViT dog classifier
    model = AutoModelForImageClassification.from_pretrained("asusevski/vit-dog-classifier")

    # Feature extractor from the base ViT checkpoint handles resizing and normalization
    model_preprocessor_name = "google/vit-base-patch16-224"
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_preprocessor_name)

    # Augmentations applied to the input image before feature extraction
    augs = Compose(
        [
            ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.4),
            RandomPerspective(distortion_scale=0.3),
            ToTensor(),
        ]
    )
    inp = np.array(augs(inp.convert('RGB')))

    # Convert to the model's expected pixel values
    inp = torch.tensor(feature_extractor(images=inp)['pixel_values'])

    # Forward pass (no gradients needed at inference time), then softmax over the logits
    with torch.no_grad():
        preds = model(inp)['logits']
    preds = torch.flatten(preds)
    preds = nn.functional.softmax(preds, dim=0)

    # Map each breed label to its predicted probability
    confidences = {labels[i]: preds[i].item() for i in range(num_labels)}
    return confidences
# Build and launch the Gradio demo: a PIL image in, top-3 class probabilities out
gr.Interface(fn=classify_image,
             inputs=gr.Image(type="pil"),
             outputs=gr.Label(num_top_classes=3)
             ).launch()