import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import io
import os
import requests
import tempfile

# Files making up the SavedModel directory, relative to that directory.
# saved_model.pb is deliberately fetched LAST: the existence of saved_model.pb
# is used as the "model already downloaded" marker, so it must only appear on
# disk once every other file has been written successfully.
_SAVED_MODEL_FILES = (
    "variables/variables.data-00000-of-00001",
    "variables/variables.index",
    "saved_model.pb",
)

# Seconds to wait on a single HTTP request before giving up
# (requests waits forever by default).
_DOWNLOAD_TIMEOUT = 60


def _parse_hf_resolve_url(model_path):
    """Split a Hugging Face ``resolve`` URL into its components.

    Expects ``https://huggingface.co/<owner>/<repo>/resolve/<revision>[/<path>]``,
    e.g. ``https://huggingface.co/nivashuggingface/digit-recognition/resolve/main/saved_model``
    -> ``("nivashuggingface/digit-recognition", "main", "saved_model")``.

    Returns:
        ``(repo_id, revision, path_in_repo)``; ``path_in_repo`` is ``""`` when the
        URL points at the repository root. ``None`` if the URL has another shape.
    """
    parts = model_path.rstrip("/").split("/")
    # parts: ["https:", "", "huggingface.co", owner, repo, "resolve", revision, *path]
    if len(parts) < 7 or parts[5] != "resolve":
        return None
    return f"{parts[3]}/{parts[4]}", parts[6], "/".join(parts[7:])


def _download_file(url, dest_path):
    """Download *url* to *dest_path*; return True only on HTTP 200 and a completed write."""
    try:
        response = requests.get(url, timeout=_DOWNLOAD_TIMEOUT)
    except requests.RequestException as exc:
        print(f"Failed to download {url}: {exc}")
        return False
    if response.status_code != 200:
        print(f"Failed to download {url}: {response.status_code}")
        return False
    # Write to a sibling temp file and rename, so an interrupted write never
    # leaves a truncated file under the final name.
    tmp_path = dest_path + ".part"
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    os.replace(tmp_path, dest_path)
    return True


def download_model_from_hf(model_path, local_dir):
    """Download a TensorFlow SavedModel directory from Hugging Face into *local_dir*.

    Fetches ``variables/variables.data-00000-of-00001``, ``variables/variables.index``
    and finally ``saved_model.pb`` from the repository path named by *model_path*.

    Args:
        model_path: Hub ``resolve`` URL of the SavedModel directory, e.g.
            ``https://huggingface.co/<owner>/<repo>/resolve/<revision>/saved_model``.
        local_dir: Local directory to write the model into (created if missing).

    Returns:
        True if every file was downloaded, False otherwise (a message is printed).
        On failure ``saved_model.pb`` is not written by this call.
    """
    parsed = _parse_hf_resolve_url(model_path)
    if parsed is None:
        print(f"Failed to download model: unrecognised model URL {model_path!r}")
        return False
    repo_id, revision, file_path = parsed

    os.makedirs(os.path.join(local_dir, "variables"), exist_ok=True)

    base_url = f"https://huggingface.co/{repo_id}/resolve/{revision}"
    if file_path:
        base_url = f"{base_url}/{file_path}"

    for rel_path in _SAVED_MODEL_FILES:
        dest_path = os.path.join(local_dir, *rel_path.split("/"))
        if not _download_file(f"{base_url}/{rel_path}", dest_path):
            print(f"Failed to download model: could not fetch {rel_path}")
            return False
    return True
# Hugging Face "resolve" URL of the SavedModel directory, and the local cache
# directory (under the system temp dir) it is downloaded into.
MODEL_PATH = "https://huggingface.co/nivashuggingface/digit-recognition/resolve/main/saved_model"
LOCAL_MODEL_DIR = os.path.join(tempfile.gettempdir(), "digit_recognition_model")

# Download the model if it doesn't exist locally; saved_model.pb is used as the
# "already cached" marker.
if not os.path.exists(os.path.join(LOCAL_MODEL_DIR, "saved_model.pb")):
    print("Downloading model from Hugging Face...")
    if not download_model_from_hf(MODEL_PATH, LOCAL_MODEL_DIR):
        # Fail fast with a clear cause instead of letting tf.saved_model.load
        # raise an opaque error on a missing or partial model directory.
        raise RuntimeError(
            f"Could not download the model from {MODEL_PATH} into {LOCAL_MODEL_DIR}"
        )

# Load the model from local directory
print(f"Loading model from {LOCAL_MODEL_DIR}")
model = tf.saved_model.load(LOCAL_MODEL_DIR)


def preprocess_image(img):
    """Convert a PIL image into a model input batch.

    The image is converted to 8-bit grayscale, resized to 28x28, scaled to
    [0, 1], and given batch and channel axes.

    Args:
        img: A ``PIL.Image.Image``.

    Returns:
        A float32 numpy array of shape ``(1, 28, 28, 1)``.
    """
    # NOTE(review): if the model was trained on MNIST-style data (light digit on
    # a dark background), dark-on-light drawings may need inverting
    # (1.0 - x) — confirm against the training pipeline.
    img = img.convert('L')
    img = img.resize((28, 28))

    img_array = np.array(img).astype('float32') / 255.0

    # (28, 28) -> (1, 28, 28) batch axis -> (1, 28, 28, 1) channel axis
    img_array = np.expand_dims(img_array, axis=0)
    img_array = np.expand_dims(img_array, axis=-1)
    return img_array


def predict_digit(img):
    """Predict the digit in *img* and format a human-readable report.

    Args:
        img: A ``PIL.Image.Image`` from the Gradio input, or ``None`` when the
            user submitted without providing an image.

    Returns:
        A string with the predicted digit and a per-class confidence line for
        each model output, or an error message. This function never raises.
    """
    if img is None:
        return "Error during prediction: no image was provided. Please draw or upload a digit."
    try:
        processed_img = preprocess_image(img)

        predictions = model(processed_img)
        predicted_digit = tf.argmax(predictions, axis=1).numpy()[0]

        # NOTE(review): softmax assumes the model outputs raw logits. If its last
        # layer already applies softmax, this second softmax flattens the
        # scores (argmax is unaffected) — confirm against the model definition.
        confidence_scores = tf.nn.softmax(predictions[0]).numpy()

        score_lines = "".join(
            f"Digit {i}: {score:.2%}\n" for i, score in enumerate(confidence_scores)
        )
        return f"Predicted Digit: {predicted_digit}\n\nConfidence Scores:\n" + score_lines
    except Exception as e:
        # UI boundary: report any failure in the output box instead of crashing the app.
        return f"Error during prediction: {str(e)}"


# Example images, relative to the current working directory. Only files that
# actually exist are offered, since a missing example path may make Gradio fail
# while building the interface.
_EXAMPLE_PATHS = ["examples/0.png", "examples/1.png", "examples/2.png"]
_EXAMPLES = [[path] for path in _EXAMPLE_PATHS if os.path.isfile(path)] or None

# Create Gradio interface
# NOTE(review): depending on the installed Gradio version, gr.Image may accept
# uploads only rather than freehand drawing (gr.Sketchpad provides a canvas),
# and `allow_flagging` may be deprecated — confirm against the pinned version.
iface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Image(type="pil", label="Draw a digit (0-9)"),
    outputs=gr.Textbox(label="Prediction Results"),
    title="Digit Recognition with CNN",
    description="""
    Draw a digit (0-9) in the box below. The model will predict which digit you drew.

    Instructions:
    1. Click and drag to draw a digit
    2. Make sure the digit is clear and centered
    3. The model will show the predicted digit and confidence scores
    """,
    examples=_EXAMPLES,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)

# Launch the interface
if __name__ == "__main__":
    iface.launch()