nivashuggingface's picture
Upload app.py with huggingface_hub
ff06826 verified
raw
history blame
4.96 kB
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import os
import requests
import tempfile
# Function to download the model from Hugging Face
def download_model_from_hf(model_path, local_dir, timeout=60):
    """Download a TensorFlow SavedModel directory from a Hugging Face repo.

    ``model_path`` must be a "resolve" URL pointing at the SavedModel folder,
    e.g. ``https://huggingface.co/<owner>/<repo>/resolve/<revision>/<path>``.
    The files ``variables/variables.data-00000-of-00001``,
    ``variables/variables.index`` and ``saved_model.pb`` are fetched into
    ``local_dir`` (which is created if needed).

    Args:
        model_path: Hugging Face "resolve" URL of the SavedModel directory.
        local_dir: Local directory the model files are written into.
        timeout: Per-request timeout in seconds passed to ``requests.get``.

    Returns:
        True if every file was downloaded, False otherwise (a message
        describing the failure is printed).
    """
    from urllib.parse import urlsplit

    os.makedirs(local_dir, exist_ok=True)

    # Expected path layout: /<owner>/<repo>/resolve/<revision>/<path in repo>
    parsed = urlsplit(model_path)
    segments = [segment for segment in parsed.path.split('/') if segment]
    if not parsed.netloc or len(segments) < 4 or segments[2] != 'resolve':
        print(f"Failed to download model: unsupported URL {model_path!r}")
        return False
    repo_id = f"{segments[0]}/{segments[1]}"
    revision = segments[3]
    file_path = '/'.join(segments[4:])
    base_url = f"{parsed.scheme}://{parsed.netloc}/{repo_id}/resolve/{revision}"
    if file_path:
        base_url = f"{base_url}/{file_path}"

    # saved_model.pb is fetched LAST: the caller treats its presence as
    # "model already downloaded", so it must only appear once the variables
    # files are in place.
    relative_files = (
        "variables/variables.data-00000-of-00001",
        "variables/variables.index",
        "saved_model.pb",
    )
    os.makedirs(os.path.join(local_dir, "variables"), exist_ok=True)

    for relative in relative_files:
        url = f"{base_url}/{relative}"
        destination = os.path.join(local_dir, *relative.split('/'))
        try:
            response = requests.get(url, timeout=timeout)
        except requests.RequestException as err:
            print(f"Failed to download model: {url}: {err}")
            return False
        if response.status_code != 200:
            print(f"Failed to download model: {response.status_code} ({url})")
            return False
        # Write to a side file and rename, so an interrupted write never
        # leaves a truncated file under the final name.
        partial = destination + ".part"
        with open(partial, "wb") as f:
            f.write(response.content)
        os.replace(partial, destination)
    return True
# Remote SavedModel location and the local cache directory it is copied into
MODEL_PATH = "https://huggingface.co/nivashuggingface/digit-recognition/resolve/main/saved_model"
LOCAL_MODEL_DIR = os.path.join(tempfile.gettempdir(), "digit_recognition_model")
# Files that must all be present for the cached copy to be usable
_REQUIRED_MODEL_FILES = (
    "saved_model.pb",
    os.path.join("variables", "variables.data-00000-of-00001"),
    os.path.join("variables", "variables.index"),
)
# Download the model unless a complete copy is already cached. Checking every
# file (not only saved_model.pb) avoids loading a copy left incomplete by an
# earlier failed download.
if not all(os.path.exists(os.path.join(LOCAL_MODEL_DIR, name)) for name in _REQUIRED_MODEL_FILES):
    print("Downloading model from Hugging Face...")
    if not download_model_from_hf(MODEL_PATH, LOCAL_MODEL_DIR):
        # Fail fast with a clear message instead of a confusing load error.
        raise RuntimeError(f"Failed to download the model from {MODEL_PATH}")
# Load the model from local directory
print(f"Loading model from {LOCAL_MODEL_DIR}")
model = tf.saved_model.load(LOCAL_MODEL_DIR)
def preprocess_image(img):
    """Turn an image into a single-sample batch for the model.

    The image is converted to 8-bit grayscale ("L"), resized to 28x28 and
    scaled from [0, 255] to [0.0, 1.0].

    Returns:
        A float32 NumPy array with a leading batch axis and a trailing
        channel axis, i.e. shape (1, 28, 28, 1) for a grayscale input.
    """
    grayscale = img.convert('L').resize((28, 28))
    pixels = np.asarray(grayscale, dtype=np.float32) / 255.0
    # Prepend the batch axis and append the channel axis in one step.
    return pixels[np.newaxis, ..., np.newaxis]
def predict_digit(img):
    """Predict the digit in an image and format the result for display.

    Args:
        img: PIL image from the Gradio ``Image`` input, or None when no
            image was provided.

    Returns:
        A human-readable string with the predicted digit and a confidence
        score per class, or an error message if prediction fails.
    """
    # An empty submission arrives as None; answer directly instead of
    # surfacing an AttributeError from preprocessing.
    if img is None:
        return "No image provided. Please provide an image of a digit (0-9)."
    try:
        processed_img = preprocess_image(img)
        predictions = model(processed_img)
        predicted_digit = tf.argmax(predictions, axis=1).numpy()[0]
        # NOTE(review): softmax assumes the model outputs logits; if its last
        # layer already applies softmax, this re-normalizes probabilities and
        # flattens the scores -- confirm against the training code.
        confidence_scores = tf.nn.softmax(predictions[0]).numpy()
    except Exception as e:  # UI boundary: report any failure to the user
        import traceback
        traceback.print_exc()
        return f"Error during prediction: {str(e)}"
    score_lines = "".join(
        f"Digit {i}: {score:.2%}\n" for i, score in enumerate(confidence_scores)
    )
    return f"Predicted Digit: {predicted_digit}\n\nConfidence Scores:\n" + score_lines
# Example images offered under the input. Only files that actually exist are
# passed, so a missing example file cannot break interface creation.
_EXAMPLE_FILES = ["examples/0.png", "examples/1.png", "examples/2.png"]
_missing_examples = [path for path in _EXAMPLE_FILES if not os.path.exists(path)]
if _missing_examples:
    print(f"Skipping missing example files: {_missing_examples}")
_examples = [[path] for path in _EXAMPLE_FILES if os.path.exists(path)] or None

# Create Gradio interface
# NOTE(review): in recent Gradio versions gr.Image is an upload/webcam input,
# not a drawing canvas; the "draw" wording below assumes a canvas -- confirm
# against the installed Gradio version.
iface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Image(type="pil", label="Draw a digit (0-9)"),
    outputs=gr.Textbox(label="Prediction Results"),
    title="Digit Recognition with CNN",
    description="""
Draw a digit (0-9) in the box below. The model will predict which digit you drew.
Instructions:
1. Click and drag to draw a digit
2. Make sure the digit is clear and centered
3. The model will show the predicted digit and confidence scores
""",
    examples=_examples,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Launch the interface only when run as a script
if __name__ == "__main__":
    iface.launch()