# Author: nivashuggingface
# Commit ff06826 (verified): "Upload app.py with huggingface_hub"
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import os
import requests
import tempfile
# Function to download the model from Hugging Face
def download_model_from_hf(model_path, local_dir, timeout=60):
    """Download a TensorFlow SavedModel directory from a Hugging Face "resolve" URL.

    ``model_path`` must have the form
    ``https://huggingface.co/<owner>/<repo>/resolve/<revision>/<path>``, e.g.
    ``https://huggingface.co/nivashuggingface/digit-recognition/resolve/main/saved_model``.
    The files ``saved_model.pb``, ``variables/variables.data-00000-of-00001`` and
    ``variables/variables.index`` found under ``<path>`` are written into
    ``local_dir``, keeping the ``variables/`` sub-directory.

    Args:
        model_path: Hugging Face ``resolve`` URL of the SavedModel directory.
        local_dir: Local directory to write the model into (created if missing).
        timeout: Per-request timeout in seconds passed to ``requests.get``.

    Returns:
        True if every file was downloaded, False otherwise (the reason is printed).
    """
    parts = model_path.rstrip('/').split('/')
    # Expected: ['https:', '', 'huggingface.co', owner, repo, 'resolve', revision, *path]
    if len(parts) < 7 or parts[5] != 'resolve':
        print(f"Unrecognized Hugging Face model URL: {model_path}")
        return False
    repo_id = f"{parts[3]}/{parts[4]}"
    revision = parts[6]
    # Everything after the revision is the path inside the repo (index 7 on;
    # index 6 is the revision itself and must not be part of the path).
    file_path = '/'.join(parts[7:])
    base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}"
    if file_path:
        base_url = f"{base_url}/{file_path}"

    os.makedirs(os.path.join(local_dir, "variables"), exist_ok=True)

    # saved_model.pb is fetched last: its presence is used elsewhere as the
    # "already downloaded" marker, so it should only appear once the variables
    # files are in place.
    files = (
        "variables/variables.data-00000-of-00001",
        "variables/variables.index",
        "saved_model.pb",
    )
    for relative in files:
        url = f"{base_url}/{relative}"
        destination = os.path.join(local_dir, *relative.split('/'))
        partial = destination + ".part"
        try:
            # Stream to disk instead of holding the whole file in memory.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()
                with open(partial, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1 << 20):
                        f.write(chunk)
            # Atomic rename: an interrupted download never leaves a truncated
            # file under its final name.
            os.replace(partial, destination)
        except (requests.RequestException, OSError) as e:
            print(f"Failed to download model file {url}: {e}")
            if os.path.exists(partial):
                os.remove(partial)
            return False
    return True
# Hugging Face location of the SavedModel and the local cache directory for it
MODEL_PATH = "https://huggingface.co/nivashuggingface/digit-recognition/resolve/main/saved_model"
LOCAL_MODEL_DIR = os.path.join(tempfile.gettempdir(), "digit_recognition_model")

# Download the model unless a complete copy is already cached. Every file is
# checked (not just saved_model.pb) so a half-finished earlier download is
# fetched again instead of being handed to the loader.
if not all(
    os.path.exists(os.path.join(LOCAL_MODEL_DIR, name))
    for name in (
        "saved_model.pb",
        os.path.join("variables", "variables.index"),
        os.path.join("variables", "variables.data-00000-of-00001"),
    )
):
    print("Downloading model from Hugging Face...")
    # Fail fast with a clear message rather than an opaque error from the loader.
    if not download_model_from_hf(MODEL_PATH, LOCAL_MODEL_DIR):
        raise RuntimeError(
            f"Could not download the model from {MODEL_PATH} into {LOCAL_MODEL_DIR}"
        )

# Load the model from the local directory
print(f"Loading model from {LOCAL_MODEL_DIR}")
model = tf.saved_model.load(LOCAL_MODEL_DIR)
def preprocess_image(img):
    """Turn a PIL image into a model-ready batch containing one image.

    The image is reduced to single-channel grayscale (PIL mode "L"), resized
    to 28x28 pixels, and its 0-255 intensities are mapped to float32 values
    in [0, 1]. The returned array has shape (1, 28, 28, 1): a leading batch
    axis and a trailing channel axis around the 28x28 pixel grid.
    """
    small_gray = img.convert('L').resize((28, 28))
    pixels = np.asarray(small_gray).astype('float32') / 255.0
    # A single reshape adds both the batch axis and the channel axis.
    return pixels.reshape(1, 28, 28, 1)
def predict_digit(img):
    """Classify the digit in ``img`` and format the result for display.

    Args:
        img: PIL image from the Gradio input, or None when nothing was provided.

    Returns:
        A multi-line string with the predicted digit followed by one
        "Digit i: p%" line per class, or a human-readable error message.
        Exceptions are not propagated to Gradio; they are reported in the
        returned string.
    """
    # Gradio passes None for an empty submission; answer with a clear message
    # instead of surfacing an AttributeError from preprocess_image.
    if img is None:
        return "Please provide an image of a digit first."
    try:
        processed_img = preprocess_image(img)
        predictions = model(processed_img)
        predicted_digit = tf.argmax(predictions, axis=1).numpy()[0]
        # NOTE(review): softmax is applied on top of the model output. If the
        # SavedModel already ends in a softmax layer, this double-normalizes
        # and flattens the displayed confidences — confirm which it returns.
        confidence_scores = tf.nn.softmax(predictions[0]).numpy()
        lines = [f"Predicted Digit: {predicted_digit}", "", "Confidence Scores:"]
        lines.extend(
            f"Digit {i}: {score:.2%}" for i, score in enumerate(confidence_scores)
        )
        # Trailing newline kept so the output matches the original format.
        return "\n".join(lines) + "\n"
    except Exception as e:
        # UI boundary: report any failure in the output box rather than crash.
        return f"Error during prediction: {str(e)}"
# Create Gradio interface.
# gr.Image without a drawing source is an upload widget, so the label and
# instructions describe uploading an image rather than drawing on a canvas.
iface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Image(type="pil", label="Upload an image of a digit (0-9)"),
    outputs=gr.Textbox(label="Prediction Results"),
    title="Digit Recognition with CNN",
    description="""
Upload an image of a single digit (0-9) in the box below. The model will predict which digit it shows.
Instructions:
1. Upload an image containing one handwritten digit
2. Make sure the digit is clear and centered
3. The model will show the predicted digit and confidence scores
""",
    # Only pass example files that actually exist, so a checkout without the
    # examples/ folder still starts; None means "no examples section".
    examples=[
        [path]
        for path in ("examples/0.png", "examples/1.png", "examples/2.png")
        if os.path.exists(path)
    ] or None,
    theme=gr.themes.Soft(),
    # NOTE(review): newer Gradio releases rename this option to
    # flagging_mode — confirm against the Gradio version the Space pins.
    allow_flagging="never",
)
def _main():
    """Start the Gradio web server for the digit-recognition demo."""
    iface.launch()


# Launch the interface only when this file is executed as a script.
if __name__ == "__main__":
    _main()