Spaces:
Runtime error
Runtime error
File size: 2,013 Bytes
8db8b38 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
def iou(y_true, y_pred):
    """Intersection-over-union metric, evaluated with NumPy inside the graph.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor with a shape compatible for elementwise product.

    Returns:
        A float32 tensor holding one IoU value computed over *all* elements of
        the inputs (the whole batch at once). The 1e-15 term in numerator and
        denominator makes two empty inputs score 1.0 instead of 0/0.
    """
    def _numpy_iou(truth, guess):
        # Runs eagerly on NumPy arrays handed over by tf.numpy_function.
        overlap = np.sum(truth * guess)
        combined = np.sum(truth) + np.sum(guess) - overlap
        score = (overlap + 1e-15) / (combined + 1e-15)
        # numpy_function was told to expect float32, so cast explicitly.
        return score.astype(np.float32)

    return tf.numpy_function(_numpy_iou, [y_true, y_pred], tf.float32)
def dice_coef(y_true, y_pred):
    """Dice coefficient between a ground-truth and a predicted mask.

    Both inputs are flattened per batch item and then summed over every
    element, so the score is computed over the whole batch at once.

    Args:
        y_true: ground-truth tensor.
        y_pred: prediction tensor with a shape compatible for elementwise product.

    Returns:
        Scalar tensor ``(2*|A∩B| + eps) / (|A| + |B| + eps)``.
    """
    smooth = 1e-15
    y_true = tf.keras.layers.Flatten()(y_true)
    y_pred = tf.keras.layers.Flatten()(y_pred)
    intersection = tf.reduce_sum(y_true * y_pred)
    denominator = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    # Smooth the denominator as well as the numerator. Otherwise two empty
    # masks give 1e-15 / 0 = inf, and dice_loss becomes -inf.
    return (2.0 * intersection + smooth) / (denominator + smooth)
def dice_loss(y_true, y_pred):
    """Dice loss: one minus the Dice coefficient of the two inputs."""
    coefficient = dice_coef(y_true, y_pred)
    return 1.0 - coefficient
def read_image(file, target_size=(256, 256)):
    """Load an image as a normalized float32 RGB array.

    Args:
        file: path or binary file object accepted by ``PIL.Image.open``.
        target_size: ``(width, height)`` passed to ``Image.resize``.

    Returns:
        ``np.ndarray`` of shape ``(height, width, 3)``, dtype float32, with
        values scaled from 0..255 down to 0..1.
    """
    # Image.open is lazy. The context manager releases any file handle Pillow
    # opened itself, while convert() has already loaded the pixel data into
    # an independent image.
    with Image.open(file) as source:
        img = source.convert('RGB')
    img = img.resize(target_size)
    x = np.array(img, dtype=np.float32)
    x = x / 255.0
    return x
def preprocess_image(img):
    """Turn a single image array into a one-item batch for the model.

    A 3-D array with four channels (RGBA) has its alpha channel dropped.
    Arrays that are not 3-D, such as 2-D grayscale, are not sliced. A grayscale
    image that happens to be four pixels wide is therefore not mistaken for RGBA.

    Args:
        img: NumPy image array.

    Returns:
        ``img`` (minus any alpha channel) with a new leading batch axis of size 1.
    """
    if img.ndim == 3 and img.shape[-1] == 4:
        img = img[..., :3]
    return np.expand_dims(img, axis=0)
def predict_image(model, img):
    """Run ``model.predict`` on an already-batched input and return its raw output."""
    return model.predict(img)
def visualize_prediction(img, pred):
    """Build a two-panel figure: the input image and the model's prediction.

    Args:
        img: image array drawn as-is in the left panel.
        pred: model output; its first batch item (``pred[0, ...]``) is drawn
            with a gray colormap in the right panel.

    Returns:
        The matplotlib Figure. ``plt.close`` is called on it first so pyplot
        stops tracking it; the Figure object is still returned to the caller.
    """
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(10, 5))
    left_ax.imshow(img)
    left_ax.set_title('Original Image')
    # NOTE(review): assumes pred[0] is an image-like mask (e.g. HxW or HxWx1) — confirm.
    right_ax.imshow(pred[0, ...], cmap='gray')
    right_ax.set_title('Predicted Image')
    plt.close(fig)
    return fig
# Load the trained model. Every custom function defined above is registered
# (the dice_loss loss plus the dice_coef / iou metrics), so Keras can resolve
# whichever of them the saved model was compiled with.
model = tf.keras.models.load_model(
    "oryx_road_segmentation_model.h5",
    custom_objects={'dice_coef': dice_coef, 'dice_loss': dice_loss, 'iou': iou},
)
def process_image(image):
    """Gradio callback: read the upload, run the model, and return the comparison plot."""
    rgb = read_image(image)
    batch = preprocess_image(rgb)
    prediction = predict_image(model, batch)
    # Plot the un-batched image next to the prediction.
    return visualize_prediction(rgb, prediction)
# Web UI: a file upload goes through process_image and comes back as a plot.
iface = gr.Interface(
    fn=process_image,
    inputs="file",
    outputs="plot",
    title="orYx Models - Road Segmentation",
)
iface.launch()
|