Jon Solow
Refactor name to huggingface_client
32804a1
raw
history blame
572 Bytes
from .model import MODEL
from .handler import handle_file, handle_url
from .labels import CLASS_LABELS
def predict_model(img_array, n_top_guesses: int = 10, *, model=None, labels=None):
    """Return the labels of the model's highest-scoring classes.

    Runs ``model.predict`` on ``img_array`` and ranks the class scores of the
    first row of the prediction from highest to lowest.

    Args:
        img_array: Input passed unchanged to ``model.predict``. Presumably a
            batch holding a single preprocessed image; only the first row of
            the prediction is ranked (TODO confirm with callers).
        n_top_guesses: How many labels to return (default 10). If the model
            has fewer classes, all of them are returned.
        model: Object exposing a ``predict`` method. Defaults to ``MODEL``.
        labels: Indexable mapping from class index to label string.
            Defaults to ``CLASS_LABELS``.

    Returns:
        List of title-cased label strings, most probable first.

    Raises:
        ValueError: If ``n_top_guesses`` is negative.
    """
    if n_top_guesses < 0:
        # A negative slice bound would silently drop the lowest-ranked labels
        # instead of selecting the top N, so reject it explicitly.
        raise ValueError(
            f"n_top_guesses must be non-negative, got {n_top_guesses}"
        )
    if model is None:
        model = MODEL
    if labels is None:
        labels = CLASS_LABELS
    class_prob = model.predict(img_array)
    # The prediction is indexed as (batch, classes); only the first sample is
    # ranked. Negating the scores turns the ascending argsort into a
    # descending ranking.
    top_values_index = (-class_prob[0]).argsort()[:n_top_guesses]
    return [labels[i].title() for i in top_values_index]
def predict_file(file, n_top_guesses: int = 10):
    """Preprocess ``file`` with ``handle_file`` and return the model's top guesses.

    Args:
        file: Input handed to ``handle_file`` for preprocessing (presumably an
            uploaded image file; confirm against ``handle_file``).
        n_top_guesses: Number of labels to return, forwarded to
            ``predict_model``. Defaults to 10, the same as ``predict_model``.

    Returns:
        The list of label strings produced by ``predict_model``, most probable
        first.
    """
    handled_array = handle_file(file)
    return predict_model(handled_array, n_top_guesses)
def predict_url(url: str, n_top_guesses: int = 10):
    """Preprocess ``url`` with ``handle_url`` and return the model's top guesses.

    Args:
        url: Location handed to ``handle_url`` for loading and preprocessing
            (presumably an image URL; confirm against ``handle_url``).
        n_top_guesses: Number of labels to return, forwarded to
            ``predict_model``. Defaults to 10, the same as ``predict_model``.

    Returns:
        The list of label strings produced by ``predict_model``, most probable
        first.
    """
    handled_array = handle_url(url)
    return predict_model(handled_array, n_top_guesses)