# app.py — Hugging Face Space by ftx7go
# Last commit: 615f871 "Update app.py" (1.44 kB)
import gradio as gr
import tensorflow as tf
import numpy as np
from tensorflow.keras.preprocessing import image
from PIL import Image
import os
# Serialized Keras classifier (HDF5). The path is relative, so it is resolved
# against the process's current working directory.
MODEL_PATH = "my_keras_model.h5"
model = tf.keras.models.load_model(MODEL_PATH)

# Input resolution expected by the model. PIL's ``Image.resize`` takes
# (width, height); both sides are 224 here, so the order is immaterial.
INPUT_WIDTH = INPUT_HEIGHT = 224
image_size = (INPUT_WIDTH, INPUT_HEIGHT)
# Function to make predictions
def predict_image(img) -> str:
    """Classify an X-ray image as fractured or normal.

    Args:
        img: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without providing an image.

    Returns:
        A human-readable string with the predicted label and the model's
        confidence in *that* label.
    """
    # Gradio passes None when the image component is empty; answer politely
    # instead of raising AttributeError on ``None.resize``.
    if img is None:
        return "Please upload an X-ray image."
    # Normalise to 3 channels so RGBA PNGs or grayscale ("L") scans do not
    # yield 4- or 1-channel arrays. gr.Image already defaults to
    # image_mode="RGB", so images arriving through the UI are unaffected.
    if img.mode != "RGB":
        img = img.convert("RGB")
    img = img.resize(image_size)  # Resize image to model's expected size
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0) / 255.0  # Normalize to [0, 1]
    prediction = model.predict(img_array)

    # Assumes a single sigmoid output of shape (1, 1) whose value scores class
    # index 1 ("Normal"), matching the original thresholding.
    # NOTE(review): confirm the label order against the training pipeline.
    # Take an explicit scalar: int() on a 1-element ndarray is deprecated
    # since NumPy 1.25.
    probability = float(prediction[0][0])
    class_names = ['Fractured', 'Normal']
    is_normal = probability > 0.5  # Threshold at 0.5
    predicted_class = class_names[int(is_normal)]
    # Report confidence in the predicted label rather than the raw "Normal"
    # score, so a confident "Fractured" no longer shows a near-zero value.
    confidence = probability if is_normal else 1.0 - probability
    return f"Prediction: {predicted_class} (Confidence: {confidence:.2f})"
# Get image paths dynamically
SAMPLE_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")


def list_sample_images(directory, extensions=SAMPLE_IMAGE_EXTENSIONS):
    """Return sorted paths of the image files directly inside *directory*.

    Args:
        directory: Folder to scan (not recursive).
        extensions: File suffixes to accept; matched case-insensitively, so
            ``SCAN.JPG`` is picked up alongside ``scan.jpg``.

    Returns:
        A list of ``os.path.join(directory, name)`` paths, sorted by file name
        for a stable example order. A missing or unreadable directory yields
        an empty list, so the app can still start without sample images.
    """
    try:
        names = os.listdir(directory)
    except OSError:
        # Best effort: a missing samples folder must not prevent launching.
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    return [
        os.path.join(directory, name)
        for name in sorted(names)
        # Skip sub-directories that merely happen to end in ".png" etc.
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(directory, name))
    ]


sample_images_dir = "samples"
sample_images = list_sample_images(sample_images_dir)
# Assemble the Gradio UI: one PIL image in, one text verdict out, with the
# discovered sample files offered as clickable examples.
interface = gr.Interface(
    predict_image,
    gr.Image(type="pil"),
    gr.Textbox(),
    examples=sample_images,  # Preloaded images for testing
    title="Bone Fracture Detection",
    description="Upload an X-ray image or select a sample image to check for fractures.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    interface.launch()