import numpy as np
import gradio as gr
import tensorflow as tf
import cv2

# App title
title = "Welcome to your first sketch recognition app!"

# App description
# NOTE(review): the newline-only segments look like leftovers of stripped
# HTML markup (e.g. a centered image / paragraph tags) -- confirm against
# the upstream example before relying on the rendered layout.
head = (
    "\n"
    "\n\nThe model is trained to classify numbers (from 0 to 9). "
    "To test it, draw your number in the space provided.\n\n"
    "\n"
)

# GitHub repository link
ref = "Find the complete code [here](https://github.com/ovh/ai-training-examples/tree/main/apps/gradio/sketch-recognition)."

# Class names (from 0 to 9)
labels = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]

# Side length (in pixels) of the square image fed to the model
IMG_SIZE = 28

# Load model (trained on MNIST dataset)
model = tf.keras.models.load_model("./sketch_recognition_numbers_model.h5")

# Keys under which the sketch component may deliver the drawn image.
# 'image' is tried first so payloads that already worked keep their behavior;
# 'composite' is the key the previous version of predict() relied on.
_IMAGE_KEYS = ("image", "composite")

# OpenCV conversion code to grayscale, indexed by channel count
_TO_GRAY = {4: cv2.COLOR_RGBA2GRAY, 3: cv2.COLOR_RGB2GRAY}


def _extract_image(data):
    """Return the sketch as a NumPy array, or None when there is no image.

    Accepts either a bare array-like or a dict holding the image under one
    of ``_IMAGE_KEYS``.
    """
    if data is None:
        return None
    if isinstance(data, dict):
        for key in _IMAGE_KEYS:
            candidate = data.get(key)
            if candidate is not None:
                return np.asarray(candidate)
        return None
    return np.asarray(data)


def predict(data):
    """Classify a hand-drawn digit.

    Args:
        data: Value produced by the sketch input component: a dict holding
            the image under 'image' or 'composite', a bare image array, or
            None when nothing was submitted.

    Returns:
        A dict mapping each entry of ``labels`` to its predicted probability
        as a float, or None when no image is available.
    """
    img = _extract_image(data)
    if img is None:
        return None

    # Collapse RGBA / RGB to a single channel. The ndim guard keeps a 2-D
    # grayscale image that happens to be 3 or 4 pixels wide out of cvtColor.
    if img.ndim == 3 and img.shape[-1] in _TO_GRAY:
        img = cv2.cvtColor(img, _TO_GRAY[img.shape[-1]])

    # Resize to the model's input resolution
    img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

    # Normalize pixel values to [0, 1] (assumes 8-bit input -- TODO confirm)
    img = img / 255.0

    # Reshape to (batch, height, width, channels) as passed to the model
    img = img.reshape(1, IMG_SIZE, IMG_SIZE, 1)

    # Model predictions for the single image in the batch
    preds = model.predict(img)[0]

    # Return the probability for each class
    return {name: float(prob) for name, prob in zip(labels, preds)}


# Top 3 classes
label = gr.Label(num_top_classes=3)

# Open Gradio interface for sketch recognition
interface = gr.Interface(
    fn=predict,
    inputs=gr.Sketchpad(type='numpy'),
    outputs=label,
    title=title,
    description=head,
    article=ref,
)

interface.launch(share=True)