File size: 2,013 Bytes
8db8b38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt

def iou(y_true, y_pred):
    """Return the intersection-over-union of two arrays as a float32 tensor.

    The arithmetic runs in NumPy through ``tf.numpy_function``:
    the overlap is the sum of the element-wise product, and the union is the
    sum of both inputs minus that overlap. 1e-15 is added to the numerator and
    to the denominator, so two all-zero inputs score 1.0 instead of dividing
    by zero.
    """
    def _np_iou(truth, guess):
        overlap = (truth * guess).sum()
        total = truth.sum() + guess.sum()
        ratio = (overlap + 1e-15) / (total - overlap + 1e-15)
        return ratio.astype(np.float32)

    return tf.numpy_function(_np_iou, [y_true, y_pred], tf.float32)

def dice_coef(y_true, y_pred, smooth=1e-15):
    """Return the Dice coefficient between ``y_true`` and ``y_pred``.

    Both inputs are flattened and reduced over every element, so the whole
    batch is scored together:
    ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.

    Args:
        y_true: Ground-truth tensor; cast to ``y_pred``'s dtype before use.
        y_pred: Predicted tensor with the same number of elements as
            ``y_true``.
        smooth: Small constant added to both numerator and denominator.
            Without it in the denominator, two all-zero inputs would divide
            by zero and yield ``inf``. With it in both places they score 1.0.

    Returns:
        A scalar tensor.
    """
    # tf.reshape avoids building a new Keras Flatten layer on every call.
    # Summing a fully flattened tensor gives the same total as summing
    # Flatten's (batch, -1) output.
    y_pred = tf.reshape(y_pred, [-1])
    # Cast so integer/bool ground-truth masks can be multiplied with float
    # predictions.
    y_true = tf.reshape(tf.cast(y_true, y_pred.dtype), [-1])
    intersection = tf.reduce_sum(y_true * y_pred)
    total = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return (2.0 * intersection + smooth) / (total + smooth)

def dice_loss(y_true, y_pred):
    """Return the Dice loss: one minus ``dice_coef(y_true, y_pred)``."""
    coefficient = dice_coef(y_true, y_pred)
    return 1.0 - coefficient

def read_image(file, target_size=(256, 256)):
    """Load an image, force it to RGB, resize it, and scale it to [0, 1].

    Args:
        file: Anything ``PIL.Image.open`` accepts, i.e. a filesystem path or
            a binary file object.
        target_size: ``(width, height)`` passed to ``PIL.Image.Image.resize``.

    Returns:
        A float32 array of shape ``(height, width, 3)`` with values in
        ``[0.0, 1.0]``.
    """
    # Open in a context manager. A file that Pillow opened itself from a path
    # is then closed promptly instead of waiting for garbage collection.
    # convert() and resize() both return new, fully loaded images that remain
    # valid after the block exits.
    with Image.open(file) as src:
        rgb = src.convert('RGB').resize(target_size)
    x = np.array(rgb, dtype=np.float32)
    # 8-bit channel values 0..255 -> 0.0..1.0
    return x / 255.0

def preprocess_image(img):
    """Return *img* as a batch of one, dropping an alpha channel if present.

    A trailing axis of length 4 is treated as RGBA and trimmed to its first
    three channels; any other shape is left unchanged. A new leading axis of
    length 1 is then added so the array can be passed to ``model.predict``.
    """
    rgb = img[..., :3] if img.shape[-1] == 4 else img
    return rgb[np.newaxis, ...]

def predict_image(model, img):
    """Return ``model.predict`` applied to the prepared batch *img*."""
    return model.predict(img)

def visualize_prediction(img, pred, *, vmin=0.0, vmax=1.0):
    """Build a side-by-side figure of the input image and the predicted mask.

    Args:
        img: Image drawn in the left panel; passed straight to ``imshow``.
        pred: Model output. Its first item along axis 0 is drawn in the right
            panel. A trailing channel axis of size 1 is dropped so the mask is
            drawn as a 2-D grayscale image.
        vmin, vmax: Fixed grayscale range for the mask. The defaults assume
            the model emits per-pixel probabilities in [0, 1].
            NOTE(review): confirm the output activation is a sigmoid. Pass
            ``None`` for either to let matplotlib autoscale to the data range.

    Returns:
        The ``matplotlib.figure.Figure``. It is closed in pyplot, so repeated
        calls do not accumulate open figures, and is returned for the caller
        to render.
    """
    mask = np.asarray(pred[0, ...])
    # imshow only accepts (M, N), (M, N, 3) or (M, N, 4) on older matplotlib
    # versions, so drop a singleton channel explicitly.
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask[..., 0]

    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].imshow(img)
    axs[0].set_title('Original Image')
    # Pinning the range keeps a near-empty prediction dark instead of letting
    # autoscaling stretch its tiny values to full black-to-white contrast.
    axs[1].imshow(mask, cmap='gray', vmin=vmin, vmax=vmax)
    axs[1].set_title('Predicted Image')
    plt.close(fig)
    return fig

# Load the trained model. Every custom function referenced by its saved
# training configuration must be supplied by name so Keras can deserialize it.
# NOTE(review): dice_loss is presumably the loss the model was compiled with.
# Supplying an unused entry is harmless, and omitting a used one makes loading fail.
model = tf.keras.models.load_model(
    "oryx_road_segmentation_model.h5",
    custom_objects={'dice_coef': dice_coef, 'dice_loss': dice_loss, 'iou': iou},
)

def process_image(image):
    """Run the segmentation model on one uploaded image and return a plot.

    Args:
        image: The value Gradio passes for the ``"file"`` input; anything
            ``read_image`` can open (a path or a file object). Presumably
            ``None`` when the user submits without uploading a file.

    Returns:
        The comparison figure from ``visualize_prediction``. Returns ``None``,
        which Gradio shows as an empty plot, when no file was provided.
    """
    # Guard the empty submission; otherwise read_image(None) raises deep
    # inside Pillow.
    if image is None:
        return None
    img = read_image(image)
    batch = preprocess_image(img)
    pred = predict_image(model, batch)
    return visualize_prediction(img, pred)

# Web UI: one uploaded file in, one matplotlib plot out.
iface = gr.Interface(
    fn=process_image,
    inputs="file",
    outputs="plot",
    title="orYx Models - Road Segmentation",
)
iface.launch()