Spaces:
Runtime error
Runtime error
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import tensorflow as tf | |
import matplotlib.pyplot as plt | |
def iou(y_true, y_pred):
    """Return the intersection-over-union of two masks as a float32 tensor.

    The arithmetic is done in NumPy and wrapped with ``tf.numpy_function``
    so the function can be called with TensorFlow tensors.

    Args:
        y_true: Ground-truth mask.
        y_pred: Predicted mask.

    Returns:
        The result of ``tf.numpy_function`` with output dtype ``tf.float32``.
    """
    eps = 1e-15  # keeps the ratio defined when both masks sum to zero

    def _iou_numpy(truth, pred):
        overlap = (truth * pred).sum()
        # Inclusion-exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|.
        combined = truth.sum() + pred.sum() - overlap
        ratio = (overlap + eps) / (combined + eps)
        return ratio.astype(np.float32)

    return tf.numpy_function(_iou_numpy, [y_true, y_pred], tf.float32)
def dice_coef(y_true, y_pred):
    """Return the Dice coefficient between two masks.

    Computes ``(2 * sum(y_true * y_pred) + s) / (sum(y_true) + sum(y_pred) + s)``
    with a small smoothing constant ``s``.

    The constant is applied to BOTH numerator and denominator. With it only
    in the numerator, two all-zero masks would divide by zero and produce
    ``inf`` (and ``dice_loss`` would become ``-inf``); with it in both, that
    case evaluates to 1.0, i.e. a perfect match.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor.

    Returns:
        A scalar tensor holding the Dice coefficient.
    """
    smooth = 1e-15
    y_true = tf.keras.layers.Flatten()(y_true)
    y_pred = tf.keras.layers.Flatten()(y_pred)
    intersection = tf.reduce_sum(y_true * y_pred)
    total = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return (2. * intersection + smooth) / (total + smooth)
def dice_loss(y_true, y_pred):
    """Return the Dice loss, defined as one minus the Dice coefficient.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor.

    Returns:
        ``1.0 - dice_coef(y_true, y_pred)``.
    """
    coefficient = dice_coef(y_true, y_pred)
    return 1.0 - coefficient
def read_image(file, target_size=(256, 256)):
    """Load an image as a normalized float32 RGB array.

    Args:
        file: Anything ``PIL.Image.open`` accepts (a path or a file object).
        target_size: Size passed to ``Image.resize``; defaults to (256, 256).

    Returns:
        A float32 NumPy array of the resized RGB image with values divided
        by 255.0.
    """
    # The context manager releases the handle Pillow opens for path inputs
    # instead of leaving it to garbage collection. convert() produces a new
    # image, which is what gets used after the source image is closed.
    with Image.open(file) as source:
        rgb = source.convert('RGB')
    resized = rgb.resize(target_size)
    pixels = np.array(resized, dtype=np.float32)
    return pixels / 255.0
def preprocess_image(img):
    """Drop a fourth channel if present and prepend a batch axis.

    Args:
        img: Array whose last axis is the channel axis.

    Returns:
        The array with only its first three channels kept when the last axis
        has length 4 (unchanged otherwise), with a new leading axis of
        length 1.
    """
    has_fourth_channel = img.shape[-1] == 4
    channels = img[..., :3] if has_fourth_channel else img
    return np.expand_dims(channels, axis=0)
def predict_image(model, img):
    """Run ``model.predict`` on ``img`` and return its result unchanged.

    Args:
        model: Object exposing a ``predict`` method.
        img: Input forwarded to ``model.predict``.

    Returns:
        Whatever ``model.predict(img)`` returns.
    """
    return model.predict(img)
def visualize_prediction(img, pred):
    """Build a side-by-side figure of the input image and the prediction.

    Args:
        img: Image shown in the left panel.
        pred: Prediction; its first element along axis 0 is shown in the
            right panel with a gray colormap.

    Returns:
        The Matplotlib figure, already closed in pyplot's registry.
    """
    fig, (input_ax, pred_ax) = plt.subplots(1, 2, figsize=(10, 5))

    input_ax.imshow(img)
    input_ax.set_title('Original Image')

    # Assuming the prediction is a mask or similar.
    pred_ax.imshow(pred[0, ...], cmap='gray')
    pred_ax.set_title('Predicted Image')

    # Close the figure in pyplot before returning it; the figure object
    # itself is still handed to the caller.
    plt.close(fig)
    return fig
# Load the model with its custom loss and metrics. Every custom function the
# saved model references must be listed in custom_objects. dice_loss is
# registered alongside the metrics so loading does not fail on an unknown
# object if the model was compiled with it as its loss (NOTE(review): confirm
# which loss the .h5 file was saved with).
model = tf.keras.models.load_model(
    "oryx_road_segmentation_model.h5",
    custom_objects={
        'dice_coef': dice_coef,
        'dice_loss': dice_loss,
        'iou': iou,
    },
)
def process_image(image):
    """Run the full pipeline for one input: load, predict, and plot.

    Args:
        image: Input forwarded to ``read_image``.

    Returns:
        The figure produced by ``visualize_prediction`` for the loaded image
        and the module-level model's prediction.
    """
    img = read_image(image)
    batch = preprocess_image(img)
    prediction = predict_image(model, batch)
    figure = visualize_prediction(img, prediction)
    return figure
# Gradio UI: a file input goes through process_image and the returned figure
# is shown in a plot output.
iface = gr.Interface(
    fn=process_image,
    inputs="file",
    outputs="plot",
    title="orYx Models - Road Segmentation",
)
iface.launch()