import base64
import os
import sys
from io import BytesIO

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.applications.resnet import (
    ResNet152,
    decode_predictions,
    preprocess_input,
)
from tensorflow.keras.preprocessing.image import img_to_array

# Path to the saved Keras model. decode_predictions() below maps the model's
# output onto ImageNet class names, so this model must emit 1000-class
# ImageNet scores.
MODEL_PATH = "resnet152-image-classifier.h5"

# Spatial size the ResNet152 input expects.
INPUT_SIZE = (224, 224)
# Number of predictions returned / shown in the Label component.
TOP_K = 3
# Candidate example images offered in the UI (only those present on disk are used).
EXAMPLE_IMAGES = ["dog.jpg", "cat.jpg"]

try:
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception as e:  # broad on purpose: any load failure is fatal at startup
    print(f"Error loading model: {e}", file=sys.stderr)
    # Non-zero status so supervisors/CI see the failure; bare exit() returned 0.
    sys.exit(1)


def decode_image_from_base64(base64_str):
    """
    Decode a base64 string into a PIL image.

    Accepts either raw base64 or a data URL such as
    ``"data:image/png;base64,iVBOR..."``; everything up to and including the
    first comma of a ``data:`` URL is stripped before decoding. Without this,
    ``base64.b64decode`` would keep the alphabet characters of the prefix and
    corrupt the payload.

    Raises:
        binascii.Error: if the payload is not valid base64.
        PIL.UnidentifiedImageError: if the bytes are not a recognised image.
    """
    if base64_str.startswith("data:"):
        _, _, base64_str = base64_str.partition(",")
    image_data = base64.b64decode(base64_str)
    image = Image.open(BytesIO(image_data))
    # Image.open is lazy; decode now so a truncated/corrupt file fails here.
    image.load()
    return image


def predict_image(image):
    """
    Classify an image and return the top-3 ImageNet predictions.

    Args:
        image: A PIL image (as delivered by the ``type="pil"`` Gradio input)
            or a base64 string / data URL containing an image.

    Returns:
        dict mapping class label -> confidence (float), highest first, in the
        shape ``gr.Label`` expects.

    Raises:
        gr.Error: if no image was supplied or classification fails. Gradio
            shows the message to the user instead of trying to render a
            non-numeric confidence in the Label output.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    try:
        if isinstance(image, str):
            image = decode_image_from_base64(image)

        # Force 3-channel RGB: RGBA PNGs, grayscale ("L") and palette ("P")
        # images would otherwise yield 4- or 1-channel arrays the model rejects.
        image = image.convert("RGB").resize(INPUT_SIZE)
        image_array = img_to_array(image)
        image_array = preprocess_input(image_array)  # ResNet-specific normalisation
        image_array = np.expand_dims(image_array, axis=0)  # add batch dimension

        # verbose=0: avoid printing a Keras progress bar on every request.
        predictions = model.predict(image_array, verbose=0)
        decoded = decode_predictions(predictions, top=TOP_K)[0]
    except Exception as e:
        raise gr.Error(str(e)) from e

    results = {}
    for _, label, confidence in decoded:
        # decode_predictions returns entries sorted by score, descending.
        # A few ImageNet class names occur twice (e.g. "crane"), so keep the
        # first, highest-scoring occurrence instead of letting a later, lower
        # score overwrite it.
        results.setdefault(label, float(confidence))
    return results


def _existing_examples(paths):
    """Return the example paths that exist on disk, or None if none do."""
    found = [path for path in paths if os.path.isfile(path)]
    return found or None


# Create the Gradio interface.
# No `tool=` argument: it was removed in Gradio 4 (TypeError there), and
# "editor" is already the Gradio 3 default for uploaded images.
interface = gr.Interface(
    fn=predict_image,
    inputs=gr.Image(type="pil"),  # Accepts an image input
    outputs=gr.Label(num_top_classes=TOP_K),  # Shows top predictions with confidence
    title="ResNet152 Image Classifier",
    description="Upload an image, and the model will predict what's in the image.",
    examples=_existing_examples(EXAMPLE_IMAGES),  # Example images for users to test
)

# Launch the Gradio app
if __name__ == "__main__":
    interface.launch()