Sketch / app.py
Jangai's picture
Update app.py
a57d0bc verified
raw
history blame
2.1 kB
import gradio as gr
import tensorflow as tf
from transformers import AutoFeatureExtractor
from PIL import Image
import numpy as np
import logging
# Configure logging (DEBUG so the per-request conversion steps are visible).
logging.basicConfig(level=logging.DEBUG)

# Load the pre-trained model and feature extractor
model_name = "hoangthan/image-classification"
logging.info("Loading image processor and model...")
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# `load_model` reads from a local filesystem path; handing it an https:// URL
# fails at startup. Download the file once (Keras caches it on disk) and load
# the local copy instead.
# NOTE(review): this assumes tf_model.h5 is a full Keras model save
# (architecture + weights), which is what `load_model` requires — confirm.
model_file_path = tf.keras.utils.get_file(
    fname="hoangthan_image_classification_tf_model.h5",
    origin="https://huggingface.co/hoangthan/image-classification/resolve/main/tf_model.h5",
)
model = tf.keras.models.load_model(model_file_path)
# Define the prediction function
def predict(image):
    """Classify a drawing and return its most likely labels.

    Args:
        image: The sketchpad payload. Either a dict carrying the final
            drawing under the ``'composite'`` key, or any object that
            ``np.array`` can turn into an image array.

    Returns:
        A dict mapping up to three labels to their softmax probabilities.
        If any step fails, the exception message is returned as a string
        instead (the error is logged with its traceback).
    """
    try:
        logging.info("Received image of type: %s", type(image))
        logging.debug("Image content: %s", image)

        # Use the 'composite' key to get the final image
        if isinstance(image, dict):
            image = image['composite']

        logging.debug("Converting to NumPy array...")
        pixels = np.array(image).astype('uint8')

        logging.debug("Converting NumPy array to PIL image...")
        # Let Pillow infer the mode from the array shape rather than forcing
        # 'RGBA': a 4-channel array is still read as RGBA, while RGB or
        # grayscale input no longer fails or gets misinterpreted.
        rgb_image = Image.fromarray(pixels).convert('RGB')
        logging.debug("Image converted successfully.")

        # Process the image for the model
        inputs = feature_extractor(images=rgb_image, return_tensors="np")
        pixel_values = inputs['pixel_values'][0]

        # Predict using the model (re-add the batch axis removed above)
        preds = model.predict(np.expand_dims(pixel_values, axis=0))
        probs = tf.nn.softmax(preds[0]).numpy()

        # Indices of the highest-probability classes, best first. Slicing
        # keeps this safe when the model has fewer than three classes.
        top_idxs = np.argsort(-probs)[:3]

        # NOTE(review): assumes `model` exposes a Hugging Face-style
        # `config.id2label` mapping keyed by integer class index — confirm
        # against how `model` is loaded.
        id2label = model.config.id2label

        # Pair every label with the probability at the SAME class index;
        # indexing the probabilities by rank (0, 1, 2) would report the
        # scores of the first three classes instead of the top three.
        result = {id2label[int(idx)]: float(probs[idx]) for idx in top_idxs}

        logging.info("Prediction successful.")
        return result
    except Exception as e:
        # Boundary handler for the UI callback: keep the traceback in the
        # log and surface only the message to the caller.
        logging.exception("Error during prediction: %s", e)
        return str(e)
# Build the Gradio UI: a sketchpad whose drawing is passed to `predict`,
# with the returned value rendered as JSON.
sketch_input = gr.Sketchpad()
json_output = gr.JSON()
iface = gr.Interface(
    fn=predict,
    inputs=sketch_input,
    outputs=json_output,
    title="Drawing Classifier",
    description="Draw something and the model will try to identify it!",
)

# Start serving the interface
iface.launch()