faceyacc's picture
added load_dataset
7d49d67
raw
history blame
1.24 kB
import transformers
import gradio as gr
import datasets
import torch
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from transformers import ViTFeatureExtractor, ViTForImageClassification
from datasets import load_dataset
# Local directory that both the feature extractor and the model are loaded from.
_MODEL_DIR = "saved_model_files"

# Preprocessing object and classifier used by classify() below.
model = AutoModelForImageClassification.from_pretrained(_MODEL_DIR)
extractor = AutoFeatureExtractor.from_pretrained(_MODEL_DIR)

# The dataset is loaded only to recover the class names; the order of `labels`
# is taken from the 'labels' feature of the train split.
dataset = load_dataset('beans', 'full_size')
_train_features = dataset['train'].features
labels = _train_features['labels'].names
def classify(im):
    """Return a confidence score for every label given one input image.

    Args:
        im: The image handed over by the Gradio 'image' input. It is passed
            straight to the module-level ``extractor``.
            NOTE(review): presumably a numpy array or PIL image -- confirm
            against the Gradio version in use.

    Returns:
        dict[str, float]: Maps each name in the module-level ``labels`` list
        to the softmax probability at the same index of the model output.
    """
    # Use the module-level `extractor`; no `feature_extractor` name is defined
    # in this file, so referencing it raised NameError on every request.
    features = extractor(im, return_tensors='pt')

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        # Read the logits by name rather than by position so that additional
        # model outputs (hidden states, attentions) can never be selected.
        logits = model(features["pixel_values"]).logits
        probability = torch.nn.functional.softmax(logits, dim=-1)

    # Only the first row is used: one image is classified per call.
    probs = probability[0].numpy()
    return {label: float(prob) for label, prob in zip(labels, probs)}
# Text displayed on the Gradio page; passed to gr.Interface below.
title = "Bean Leaf Health Check"
description = "Bean leaf health classification with Google's ViT"

# NOTE(review): this holds a sample of classify()'s *output*, not an input
# image, so it is intentionally not passed as `examples=` to the interface
# (whose single input is an image). Replace with image file paths before
# wiring it in -- confirm against the Gradio `examples` documentation.
examples = [["'angular_leaf_spot': 0.9999030828475952, 'bean_rust': 5.320278796716593e-05, 'healthy': 4.378804806037806e-05"]]

# Build the UI around classify(): one image in, one label widget out.
gr_interface = gr.Interface(
    classify,
    inputs='image',
    outputs='label',
    title=title,
    description=description,
)
gr_interface.launch(debug=True)