Spaces:
Runtime error
Runtime error
import io
import os
import urllib.request

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
# Load the model from Hugging Face.
# tf.saved_model.load() reads only from filesystems TensorFlow supports
# (local paths, gs://, ...); it cannot open an https:// URL directly, so the
# SavedModel files are first mirrored into a local directory and loaded from
# there.
MODEL_PATH = "https://huggingface.co/nivashuggingface/digit-recognition/resolve/main/saved_model"

# Local directory the SavedModel is mirrored into (next to this script).
LOCAL_MODEL_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), "saved_model_cache"
)

# NOTE(review): assumes the repo holds a standard single-shard SavedModel
# layout under saved_model/ (no assets/, one variables shard) — TODO confirm
# this file list against the Hugging Face repository.
SAVED_MODEL_FILES = (
    "saved_model.pb",
    "variables/variables.index",
    "variables/variables.data-00000-of-00001",
)


def _download_saved_model(base_url, dest_dir, files=SAVED_MODEL_FILES):
    """Mirror the SavedModel files found under *base_url* into *dest_dir*.

    Each entry of *files* is a '/'-separated path relative to *base_url*.
    Files that already exist locally are not downloaded again. Each file is
    written to a ``.part`` temporary first and renamed only after the
    download succeeds, so an interrupted download never leaves a truncated
    file that later runs would treat as complete.

    Returns:
        *dest_dir*, ready to pass to ``tf.saved_model.load``.
    """
    for rel_path in files:
        local_path = os.path.join(dest_dir, *rel_path.split("/"))
        if os.path.exists(local_path):
            continue
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        partial_path = local_path + ".part"
        urllib.request.urlretrieve(f"{base_url}/{rel_path}", partial_path)
        os.replace(partial_path, local_path)
    return dest_dir


model = tf.saved_model.load(_download_saved_model(MODEL_PATH, LOCAL_MODEL_DIR))
def preprocess_image(img):
    """Turn the drawn image into a single-sample batch for the model.

    The image is converted to grayscale ("L" mode), resized to 28x28, cast
    to float32 and divided by 255.0, then given a leading batch axis and a
    trailing channel axis.

    Returns:
        numpy array of shape (1, 28, 28, 1), dtype float32.
    """
    grayscale = img.convert('L').resize((28, 28))
    pixels = np.array(grayscale).astype('float32') / 255.0
    # Same as expand_dims(axis=0) followed by expand_dims(axis=-1).
    return pixels[np.newaxis, ..., np.newaxis]
def predict_digit(img):
    """Predict which digit (0-9) is shown in *img*.

    Args:
        img: The value delivered by the Gradio input component (a PIL image,
            per ``type="pil"``), or ``None`` when nothing was submitted.

    Returns:
        A human-readable report with the predicted digit and one confidence
        line per output class, or an error message. Errors are reported in
        the returned text rather than raised, so the UI can display them.
    """
    # Gradio passes None for an empty input; say so plainly instead of
    # surfacing an AttributeError from preprocessing.
    if img is None:
        return "Error during prediction: no image provided. Please draw a digit first."

    try:
        processed_img = preprocess_image(img)
        predictions = model(processed_img)
        predicted_digit = tf.argmax(predictions, axis=1).numpy()[0]
        # NOTE(review): softmax assumes the model emits raw logits; if its last
        # layer already applies softmax, this re-normalizes probabilities and
        # flattens the displayed scores — confirm against the trained model.
        confidence_scores = tf.nn.softmax(predictions[0]).numpy()
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing.
        return f"Error during prediction: {str(e)}"

    lines = [f"Predicted Digit: {predicted_digit}", "", "Confidence Scores:"]
    lines.extend(
        f"Digit {i}: {score:.2%}" for i, score in enumerate(confidence_scores)
    )
    # Trailing newline keeps the output identical to the original format.
    return "\n".join(lines) + "\n"
# Create Gradio interface
# Markdown shown under the title of the app.
_DESCRIPTION = """
Draw a digit (0-9) in the box below. The model will predict which digit you drew.
Instructions:
1. Click and drag to draw a digit
2. Make sure the digit is clear and centered
3. The model will show the predicted digit and confidence scores
"""

# One single-input example row per digit image: examples/0.png .. examples/2.png.
# NOTE(review): these files are not visible here; confirm they exist in the repo.
_EXAMPLES = [[f"examples/{digit}.png"] for digit in range(3)]

# NOTE(review): the description asks users to click and drag to draw, but in
# recent Gradio releases a plain gr.Image is an upload widget, not a canvas;
# drawing may need gr.Sketchpad (whose value format differs) — confirm
# against the pinned Gradio version.
# NOTE(review): `allow_flagging` is deprecated in newer Gradio releases in
# favor of `flagging_mode` — confirm against the pinned Gradio version.
iface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Image(type="pil", label="Draw a digit (0-9)"),
    outputs=gr.Textbox(label="Prediction Results"),
    examples=_EXAMPLES,
    title="Digit Recognition with CNN",
    description=_DESCRIPTION,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)
def main():
    """Start the Gradio server for the digit-recognition interface."""
    iface.launch()


# Launch the interface only when run as a script, not when imported.
if __name__ == "__main__":
    main()