# Provenance (from the Hugging Face file-view page this was copied from):
# author Tort73, commit "Initial commit" (a9c8ea3), original size 485 bytes.
from pathlib import Path

import gradio as gr
from fastai.learner import load_learner
from fastai.vision.all import *
# Load the exported fastai learner that sits next to this script. Resolving the
# path relative to __file__ (rather than the current working directory) lets the
# app be launched from any directory.
# NOTE: load_learner unpickles the file, which can execute arbitrary code --
# only ever point this at a trusted, self-produced export.
learner_inf = load_learner(Path(__file__).resolve().parent / 'banana.pkl')

# Class vocabulary of the learner; predict() pairs these entries, in order,
# with the probability vector returned by learner_inf.predict().
labels = learner_inf.dls.vocab
def predict(img):
    """Classify an image and return the probability of every class.

    Args:
        img: The image handed over by the Gradio image input (presumably a
            NumPy array -- anything ``PILImage.create`` accepts works).

    Returns:
        A dict mapping each label in ``labels`` to its predicted probability
        as a plain Python float, which is the format ``Label`` outputs expect.
    """
    img = PILImage.create(img)
    # learner.predict returns (decoded prediction, class index, probabilities);
    # only the full probability vector is needed here.
    _, _, probs = learner_inf.predict(img)
    # float() converts the tensor scalars so the result is JSON-serialisable.
    return {label: float(prob) for label, prob in zip(labels, probs)}
# Web UI: an image goes in, the top-3 class confidences come out.
# NOTE(review): gr.inputs / gr.outputs are Gradio's legacy component namespaces
# (deprecated in Gradio 3, removed in Gradio 4); migrate to gr.Image / gr.Label
# if the pinned Gradio version is ever upgraded.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(shape=(192, 192)),
    outputs=gr.outputs.Label(num_top_classes=3),
)

if __name__ == "__main__":
    # Launch only when run as a script, so importing this module (e.g. from
    # tests or Gradio's reload tooling) does not start a server.
    # share=True additionally requests a public share link.
    demo.launch(share=True)