TM-TKO-Model-UI / app.py
amandasarubbi's picture
test
cda3545
'''
Gradio UI for the TM-TKO trademark logo image classification models.
'''
#########################################################################
# imports
import os
import pathlib
import platform

from fastai.vision.all import *
import gradio as gr
from huggingface_hub import hf_hub_download
#########################################################################
# User access token for the HF model library.
# Read it from the environment (e.g. a Space secret named HF_TOKEN) instead of
# committing it to source control: a token literal previously lived here and
# was published with this file, so it should be revoked.
# If HF_TOKEN is unset this is None, and authentication is left to
# huggingface_hub's own defaults.
ACCESS_TOKEN = os.environ.get("HF_TOKEN")
#########################################################################
# Consider path separators for alternate OS.
# The learner files are unpickled by load_learner; if they contain
# pathlib.WindowsPath objects (presumably because they were exported on
# Windows -- confirm), those cannot be instantiated on POSIX systems, so alias
# WindowsPath to PosixPath before any model is loaded.
# NOTE: this patches pathlib for the whole process.
# platform.system() is called inline rather than stored in `plt`, which would
# shadow the matplotlib.pyplot alias provided by `from fastai.vision.all import *`.
if platform.system() != 'Windows':
    pathlib.WindowsPath = pathlib.PosixPath
#########################################################################
def import_model(model_name):
    """Fetch an exported fastai learner from the Hub and load it for CPU inference.

    Args:
        model_name: file name of the exported learner inside the
            ``amandasarubbi/tm-tko-models`` model repository.

    Returns:
        The learner object returned by ``load_learner``.

    There is no in-process caching here: every call loads a fresh learner.
    """
    # hf_hub_download hands back the local path of the downloaded file.
    local_file = hf_hub_download(
        repo_id='amandasarubbi/tm-tko-models',
        repo_type='model',
        filename=model_name,
        use_auth_token=ACCESS_TOKEN,
    )
    # NOTE: load_learner deserializes a pickle -- only point this at a trusted repo.
    return load_learner(local_file, cpu=True)
#########################################################################
#########################################################################
# Dropdown label -> exported learner file in the Hub model repository.
MODEL_FILES = {
    'Geometric Figures & Solids': 'geometric_model.pkl',
    'Scenery, Natural Phenomena': 'landscape_model.pkl',
    'Human & Supernatural Beings': 'human_model.pkl',
    'Colors & Characters': 'colors_model.pkl',
    'Buildings, Dwellings & Furniture': 'buildings.pkl',
    'Animals': 'animals.pkl',
}


# Function to predict outputs
def predict(img, model_name):
    """Classify an uploaded logo with the model chosen in the dropdown.

    Args:
        img: the image passed in by the Gradio image input; it is handed
            unchanged to the learner's ``predict``.
        model_name: one of the keys of ``MODEL_FILES``.

    Returns:
        The predicted label, converted to ``str``.

    Raises:
        ValueError: if ``model_name`` is not a known model (for example when
            nothing was selected and the dropdown delivers ``None``).
    """
    # Previously an unmatched name fell through every branch and crashed with
    # UnboundLocalError on `preds`; fail with an explicit message instead.
    filename = MODEL_FILES.get(model_name)
    if filename is None:
        raise ValueError(
            f"Unknown model {model_name!r}; choose one of: " + ", ".join(MODEL_FILES)
        )
    learn = import_model(filename)
    # Learner.predict returns a tuple; element 0 is the decoded prediction.
    preds = learn.predict(img)
    return str(preds[0])
#########################################################################
title = "TM-TKO Trademark Logo Image Classification Model"
# Shown under the title; it must describe the inputs actually rendered below
# (a logo image and a model choice -- there is no file-name input).
description = "Users can upload a logo image and choose a category model to get US design-code standard predictions on a trained model that utilizes the benchmark ResNet50 architecture."
# gr.Image / gr.Dropdown replace the deprecated gr.inputs.* wrappers (removed in
# Gradio 4) and match the component-style gr.Label already used for the output.
iFace = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Upload Logo Here"),
        gr.Dropdown(
            choices=[
                'Geometric Figures & Solids',
                'Scenery, Natural Phenomena',
                'Human & Supernatural Beings',
                'Colors & Characters',
                'Buildings, Dwellings & Furniture',
                'Animals',
            ],
            label='Choose a Model',
        ),
    ],
    outputs=gr.Label(label="TM-TKO Trademark Classification Model"),
    title=title,
    description=description,
)
iFace.launch()