|
import base64
import logging
import os
from io import BytesIO

import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.applications.resnet import ResNet152, preprocess_input, decode_predictions
from tensorflow.keras.preprocessing.image import img_to_array
|
|
|
|
|
# Path to the saved Keras model that predict_image runs inference with.
MODEL_PATH = "resnet152-image-classifier.h5"

try:
    model = tf.keras.models.load_model(MODEL_PATH)
except Exception as e:
    # Without a model the app cannot serve predictions, so abort at import.
    # SystemExit carrying a message prints it to stderr and exits with status 1.
    # The site-module exit() helper would exit with status 0, reporting
    # success to process managers, and is missing under `python -S` and in
    # frozen builds.
    raise SystemExit(f"Error loading model: {e}") from e
|
|
|
def decode_image_from_base64(base64_str): |
|
""" |
|
Decodes a base64 string to a PIL image. |
|
""" |
|
|
|
image_data = base64.b64decode(base64_str) |
|
|
|
image = Image.open(BytesIO(image_data)) |
|
return image |
|
|
|
def predict_image(image):
    """
    Classify an image and return its top 3 predictions.

    Args:
        image: A PIL image (what the ``gr.Image(type="pil")`` input delivers),
            a base64-encoded image string, or ``None`` when nothing was
            submitted.

    Returns:
        dict mapping each of the top-3 labels produced by
        ``decode_predictions`` to its confidence as a Python float. On
        failure, ``{"Error": <message>}`` is returned instead of raising.
    """
    if image is None:
        # Gradio passes None when the user submits without providing an image.
        return {"Error": "No image provided."}

    try:
        if isinstance(image, str):
            image = decode_image_from_base64(image)

        # The network needs 3-channel input. RGBA (PNG with alpha), L
        # (grayscale) and P (palette) images would otherwise produce an array
        # with the wrong number of channels.
        image = image.convert("RGB").resize((224, 224))
        image_array = img_to_array(image)
        image_array = preprocess_input(image_array)
        image_array = np.expand_dims(image_array, axis=0)  # add batch axis

        # verbose=0 stops Keras printing a progress bar on every request.
        predictions = model.predict(image_array, verbose=0)
        # decode_predictions maps outputs to ImageNet class names. This
        # presumes the loaded model has the 1000-class ImageNet head; confirm.
        decoded_predictions = decode_predictions(predictions, top=3)[0]

        return {label: float(confidence) for _, label, confidence in decoded_predictions}
    except Exception as e:
        # UI boundary: report the message to the user instead of crashing,
        # but keep the full traceback in the server log for debugging.
        logging.exception("Prediction failed")
        return {"Error": str(e)}
|
|
|
|
|
# Only offer example images that actually exist. The paths are relative to
# the working directory, and a missing file may make Gradio fail when it
# loads or serves the examples.
_EXAMPLE_IMAGES = [path for path in ("dog.jpg", "cat.jpg") if os.path.isfile(path)]

interface = gr.Interface(
    fn=predict_image,
    # type="pil" hands predict_image a PIL image. The `tool` argument is
    # omitted for two reasons: "editor" is already Gradio 3's default for
    # uploaded images, and Gradio 4 removed the parameter from gr.Image.
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title="ResNet152 Image Classifier",
    description="Upload an image, and the model will predict what's in the image.",
    examples=_EXAMPLE_IMAGES or None,
)
|
|
|
|
|
def _main():
    """Launch the Gradio classifier UI when this file is run as a script."""
    interface.launch()


if __name__ == "__main__":
    _main()
|
|