Spaces:
Runtime error
Runtime error
'''
Gradio UI for the TM-TKO trademark logo image classification models.
'''
######################################################################### | |
# imports
import functools
import os
import pathlib
import platform

from fastai.vision.all import *
import gradio as gr
from huggingface_hub import hf_hub_download
######################################################################### | |
# User access token for the HF model library.
# Read it from the environment (e.g. a Space secret named HF_TOKEN) rather than
# hard-coding it: a literal token committed to source is readable by anyone who
# can see the repo. Any token previously committed here should be revoked.
# If HF_TOKEN is unset this is None.
ACCESS_TOKEN = os.environ.get("HF_TOKEN")
######################################################################### | |
# Consider path separators for alternate OSes: the exported learners may contain
# pathlib.WindowsPath objects (presumably because they were exported on Windows
# — confirm), which cannot be instantiated on POSIX. Alias WindowsPath to
# PosixPath so unpickling them works here. The result of platform.system() is
# not bound to `plt`, which would clobber matplotlib's `plt` brought in by the
# fastai star import.
if platform.system() != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
######################################################################### | |
# Memoised per model file. Only a handful of fixed file names are ever passed in
# (see predict), so an unbounded cache stays small. Without it every request
# would re-download-check and re-unpickle the learner.
# NOTE(review): the cached Learner is shared across requests; confirm Gradio's
# per-event concurrency keeps concurrent predict() calls from overlapping.
@functools.lru_cache(maxsize=None)
def import_model(model_name):
    """Fetch an exported fastai learner from the Hub and load it on CPU.

    Args:
        model_name: File name of the exported learner inside the
            ``amandasarubbi/tm-tko-models`` model repository,
            e.g. ``'animals.pkl'``.

    Returns:
        The learner returned by ``load_learner(path, cpu=True)``. Repeated
        calls with the same ``model_name`` return the same cached object.
    """
    path = hf_hub_download(repo_id='amandasarubbi/tm-tko-models',
                           filename=model_name,
                           use_auth_token=ACCESS_TOKEN,
                           repo_type='model')
    return load_learner(path, cpu=True)
######################################################################### | |
######################################################################### | |
# Dropdown display name -> exported learner file in the model repository.
_MODEL_FILES = {
    'Geometric Figures & Solids': 'geometric_model.pkl',
    'Scenery, Natural Phenomena': 'landscape_model.pkl',
    'Human & Supernatural Beings': 'human_model.pkl',
    'Colors & Characters': 'colors_model.pkl',
    'Buildings, Dwellings & Furniture': 'buildings.pkl',
    'Animals': 'animals.pkl',
}


# Function to predict outputs
def predict(img, model_name):
    """Classify an uploaded logo with the model selected in the dropdown.

    Args:
        img: The uploaded image, passed straight to the learner's ``predict``.
        model_name: One of the dropdown display names (keys of ``_MODEL_FILES``).

    Returns:
        The predicted label (first element of the prediction result) as a str.

    Raises:
        ValueError: If ``model_name`` is not a known model, e.g. ``None`` when
            nothing was selected. Previously this crashed with an
            UnboundLocalError on ``preds``.
    """
    model_file = _MODEL_FILES.get(model_name)
    if model_file is None:
        raise ValueError(
            f"Unknown model {model_name!r}; choose one of: "
            f"{', '.join(_MODEL_FILES)}"
        )
    preds = import_model(model_file).predict(img)
    # Only the decoded label (first element) is displayed.
    return str(preds[0])
######################################################################### | |
# Build and launch the Gradio UI: a logo image plus a model choice in,
# a predicted design-code label out.
title = "TM-TKO Trademark Logo Image Classification Model"
description = "Users can upload an image and choose a design-category model to get US design-code standard predictions on a trained model that utilizes the benchmark ResNet50 architecture."
model_choices = [
    'Geometric Figures & Solids',
    'Scenery, Natural Phenomena',
    'Human & Supernatural Beings',
    'Colors & Characters',
    'Buildings, Dwellings & Furniture',
    'Animals',
]
# gr.inputs.* was removed in Gradio 4; gr.Image / gr.Dropdown are the
# replacements. The dropdown gets a default so predict() never sees None.
iFace = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Upload Logo Here"),
        gr.Dropdown(choices=model_choices, value=model_choices[0],
                    label='Choose a Model'),
    ],
    outputs=gr.Label(label="TM-TKO Trademark Classification Model"),
    title=title,
    description=description,
)
iFace.launch()