Spaces:
Runtime error
Runtime error
import gradio as gr | |
import tensorflow as tf | |
import numpy as np | |
from PIL import Image | |
import io | |
import os | |
import requests | |
import tempfile | |
# Function to download the model from Hugging Face
def download_model_from_hf(model_path, local_dir, timeout=60):
    """Download a TensorFlow SavedModel directory from a Hugging Face repo.

    ``model_path`` must be a "resolve" URL pointing at the SavedModel
    directory inside the repository, e.g.
    ``https://huggingface.co/<owner>/<repo>/resolve/<revision>/<path>``.
    The three SavedModel files (``variables/variables.data-00000-of-00001``,
    ``variables/variables.index`` and ``saved_model.pb``) are fetched from
    that directory and written under ``local_dir``.

    Args:
        model_path: Hugging Face resolve URL of the SavedModel directory.
        local_dir: Local directory to write the model into (created if needed).
        timeout: Per-request timeout in seconds.

    Returns:
        True if every file was downloaded and written, False if the URL is
        not a recognised resolve URL or any request fails.
    """
    # URL layout after split('/'):
    #   ['https:', '', 'huggingface.co', owner, repo, 'resolve', revision, *path]
    parts = model_path.rstrip('/').split('/')
    if len(parts) < 7 or parts[5] != 'resolve':
        print(f"Unrecognized Hugging Face model URL: {model_path}")
        return False
    repo_id = f"{parts[3]}/{parts[4]}"
    revision = parts[6]
    file_path = '/'.join(parts[7:])
    prefix = f"https://huggingface.co/{repo_id}/resolve/{revision}"
    base_url = f"{prefix}/{file_path}" if file_path else prefix

    # saved_model.pb is fetched LAST: the caller treats its presence as
    # "model already downloaded", so it must only appear once the variables
    # files are in place. Otherwise a partial download would never be retried.
    relative_files = (
        "variables/variables.data-00000-of-00001",
        "variables/variables.index",
        "saved_model.pb",
    )
    os.makedirs(os.path.join(local_dir, "variables"), exist_ok=True)
    for relative in relative_files:
        url = f"{base_url}/{relative}"
        try:
            response = requests.get(url, timeout=timeout)
        except requests.RequestException as err:
            print(f"Failed to download {url}: {err}")
            return False
        if response.status_code != 200:
            print(f"Failed to download model file {url}: HTTP {response.status_code}")
            return False
        with open(os.path.join(local_dir, *relative.split('/')), "wb") as f:
            f.write(response.content)
    return True
# Hugging Face "resolve" URL of the SavedModel directory:
# https://huggingface.co/<owner>/<repo>/resolve/<revision>/<path>
MODEL_PATH = "https://huggingface.co/nivashuggingface/digit-recognition/resolve/main/saved_model"
# Local cache directory for the model, under the system temp directory.
LOCAL_MODEL_DIR = os.path.join(tempfile.gettempdir(), "digit_recognition_model")

# Download the model if it doesn't exist locally. download_model_from_hf
# reports failure through its return value; stop with a clear message here
# instead of letting tf.saved_model.load fail on a missing directory.
if not os.path.exists(os.path.join(LOCAL_MODEL_DIR, "saved_model.pb")):
    print("Downloading model from Hugging Face...")
    if not download_model_from_hf(MODEL_PATH, LOCAL_MODEL_DIR):
        raise RuntimeError(
            f"Could not download the model from {MODEL_PATH} into {LOCAL_MODEL_DIR}"
        )

# Load the model from local directory
print(f"Loading model from {LOCAL_MODEL_DIR}")
model = tf.saved_model.load(LOCAL_MODEL_DIR)
def preprocess_image(img):
    """Turn an input image into a single-item batch for the model.

    The image is converted to 8-bit grayscale ("L"), resized to 28x28,
    scaled from [0, 255] to [0, 1] as float32, and returned with shape
    (1, 28, 28, 1): a leading batch axis and a trailing channel axis.
    """
    gray = img.convert('L').resize((28, 28))
    pixels = np.asarray(gray, dtype=np.float32) / 255.0
    # Add the batch axis in front and the channel axis at the end.
    return pixels[np.newaxis, :, :, np.newaxis]
def predict_digit(img):
    """Predict which digit (0-9) an image shows.

    Args:
        img: Image from the Gradio input (PIL, per ``type="pil"``), or None
            when the user submitted without providing an image.

    Returns:
        A human-readable string with the predicted digit followed by one
        "Digit i: NN.NN%" line per class. Returns an explanatory message when
        no image was given or when prediction raised an error.
    """
    # Gradio hands over None for an empty input; that is not a model failure.
    if img is None:
        return "Please draw or upload a digit first."
    try:
        processed_img = preprocess_image(img)
        predictions = model(processed_img)
        predicted_digit = tf.argmax(predictions, axis=1).numpy()[0]
        # NOTE(review): softmax is applied to the raw model output, which
        # assumes the model returns logits. If the model already ends in a
        # softmax layer these scores get flattened. Confirm against the model.
        confidence_scores = tf.nn.softmax(predictions[0]).numpy()
        lines = [f"Predicted Digit: {predicted_digit}", "", "Confidence Scores:"]
        lines.extend(
            f"Digit {i}: {score:.2%}" for i, score in enumerate(confidence_scores)
        )
        return "\n".join(lines) + "\n"
    except Exception as e:  # UI boundary: show the failure in the textbox
        return f"Error during prediction: {str(e)}"
# Example inputs shown under the interface: examples/0.png .. examples/2.png.
# NOTE(review): these files must exist next to the app. Verify they are
# committed to the Space.
_EXAMPLE_IMAGES = [[f"examples/{digit}.png"] for digit in range(3)]

_DESCRIPTION = """
Draw a digit (0-9) in the box below. The model will predict which digit you drew.
Instructions:
1. Click and drag to draw a digit
2. Make sure the digit is clear and centered
3. The model will show the predicted digit and confidence scores
"""

# Gradio UI: a single image input (delivered to predict_digit as a PIL image)
# and a textbox holding the formatted prediction.
# NOTE(review): gr.Image may be an upload widget rather than a drawing canvas,
# depending on the Gradio version, and newer Gradio releases replace
# `allow_flagging` with `flagging_mode`. Confirm both against the pinned
# gradio version.
iface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Image(type="pil", label="Draw a digit (0-9)"),
    outputs=gr.Textbox(label="Prediction Results"),
    title="Digit Recognition with CNN",
    description=_DESCRIPTION,
    examples=_EXAMPLE_IMAGES,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Start the web server only when run as a script, not on import.
if __name__ == "__main__":
    iface.launch()