File size: 3,437 Bytes
1899d85
8c16ebc
d905150
 
 
 
 
ea62197
8c16ebc
f2264dc
d905150
 
 
 
 
 
7aeba2e
 
 
 
 
 
d905150
1b88331
bc79d00
d905150
 
 
 
7fe3aed
 
d905150
 
 
 
 
 
 
 
 
7fe3aed
 
d905150
 
 
 
3b78676
7fe3aed
 
3b78676
 
 
d905150
bc79d00
d905150
d22dc1f
ea62197
d905150
 
67680bf
 
7e182e3
8a7486c
8b7ee47
 
 
 
d905150
 
 
73747ef
d905150
5a36cdc
7058c31
67680bf
ea62197
 
c25a556
 
 
d905150
2302330
d905150
 
 
c48b854
5a36cdc
d905150
0e5a4ff
d905150
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os, io
import cv2
import gradio as gr
import tensorflow as tf
import numpy as np
import keras.backend as K

from matplotlib import pyplot as plt
from PIL import Image
import keras


# Presumably the (height, width, channels) the network was trained on.
# NOTE(review): not referenced by the code visible in this file — confirm.
resized_shape = (768, 768, 3)

# Row / column stride applied to the uploaded image before prediction;
# (1, 1) keeps every pixel.
IMG_SCALING = (1, 1)

# Saved U-Net model, resolved relative to the current working directory.
# An earlier (now removed) helper fetched it with gdown from
# https://drive.google.com/uc?id=1FhICkeGn6GcNXWTDn1s83ctC-6Mo1UXk
model_file = 'seg_unet_model.h5'

#Custom objects for model

def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1, alpha=0.5, ce_ratio=0.5):
    """Combo loss: weighted binary cross-entropy minus a smoothed Dice score.

    Computes ``ce_ratio * weighted_ce - (1 - ce_ratio) * dice`` over all
    elements of the (flattened) tensors, so lower is better and the value can
    be negative.

    Args:
        y_true: ground-truth mask tensor; flattened and cast to float32.
        y_pred: predicted probabilities; flattened and cast to float32
            (assumed to have the same number of elements as ``y_true``).
        eps: clip margin that keeps both logarithms finite.
        smooth: additive smoothing in the Dice numerator and denominator.
        alpha: class weight inside the cross-entropy term.
        ce_ratio: weight of the cross-entropy term relative to the Dice term.

    Returns:
        Scalar loss tensor.
    """
    # NOTE(review): the original referenced module globals ALPHA / CE_RATIO
    # that are never defined (NameError on any call). 0.5 / 0.5 are assumed
    # defaults — confirm they match the values used when the model was trained.
    ops = keras.ops
    targets = ops.cast(ops.reshape(y_true, (-1,)), dtype="float32")
    inputs = ops.cast(ops.reshape(y_pred, (-1,)), dtype="float32")

    intersection = ops.sum(targets * inputs)
    dice = (2.0 * intersection + smooth) / (ops.sum(targets) + ops.sum(inputs) + smooth)

    # Clip only after Dice so the overlap uses the raw predictions.
    inputs = ops.clip(inputs, eps, 1.0 - eps)
    # NOTE(review): as written, alpha also scales the (1 - alpha)-weighted
    # negative term. Kept unchanged so the loss matches the original formula.
    out = -(alpha * ((targets * ops.log(inputs))
                     + ((1 - alpha) * (1.0 - targets) * ops.log(1.0 - inputs))))
    weighted_ce = ops.mean(out, axis=-1)

    return (ce_ratio * weighted_ce) - ((1 - ce_ratio) * dice)

def dice_coef(y_true, y_pred, smooth=1, threshold=0.5):
    """Batch-mean Dice coefficient between a binarised prediction and the truth.

    The sums run over axes 1, 2 and 3, so both tensors are expected to be
    rank 4 with the batch on axis 0.

    Args:
        y_true: ground-truth mask tensor.
        y_pred: predicted probabilities, same shape as ``y_true``.
        smooth: additive smoothing in the Dice numerator and denominator.
        threshold: predictions strictly greater than this count as foreground.

    Returns:
        Scalar tensor: per-sample Dice averaged over the batch.
    """
    ops = keras.ops
    # Binarise before counting. The original cast probabilities straight to
    # int32, which truncates every value below 1.0 to 0. It also passed the
    # misspelled keyword `tdtype`, so the call raised TypeError.
    y_pred = ops.cast(ops.greater(y_pred, threshold), dtype="float32")
    y_true = ops.cast(y_true, dtype="float32")

    intersection = ops.sum(y_true * y_pred, axis=(1, 2, 3))
    union = ops.sum(y_true, axis=(1, 2, 3)) + ops.sum(y_pred, axis=(1, 2, 3))
    return ops.mean((2.0 * intersection + smooth) / (union + smooth), axis=0)

def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
    """Binary focal loss, averaged over all elements.

    Args:
        y_true: ground-truth labels. Elements equal to 1 feed the positive
            term and elements equal to 0 feed the negative term; any other
            value contributes to neither.
        y_pred: predicted probabilities, same shape as ``y_true``.
        gamma: focusing exponent of the modulating factor.
        alpha: weight of the positive term; the negative term gets 1 - alpha.

    Returns:
        Scalar loss tensor.
    """
    ops = keras.ops
    eps = K.epsilon()

    # Non-positives become 1 so (1 - 1) ** gamma removes them from this term.
    pt_1 = ops.where(ops.equal(y_true, 1), y_pred, ops.ones_like(y_pred))
    # Non-negatives become 0 so 0 ** gamma removes them from this term.
    # The original filled with 1, which made every positive pixel add a
    # constant -(1 - alpha) * log(eps) to the loss.
    pt_0 = ops.where(ops.equal(y_true, 0), y_pred, ops.zeros_like(y_pred))

    pos_term = ops.mean(alpha * ops.power(1.0 - pt_1, gamma) * ops.log(pt_1 + eps))
    neg_term = ops.mean((1 - alpha) * ops.power(pt_0, gamma) * ops.log(1.0 - pt_0 + eps))
    return -pos_term - neg_term

# Load the trained model once at import time. The custom loss and metric
# functions are registered under the names the saved file refers to.
seg_model = keras.models.load_model(
    model_file,
    custom_objects=dict(
        Combo_loss=Combo_loss,
        focal_loss_fixed=focal_loss_fixed,
        dice_coef=dice_coef,
    ),
)

# Subplot grid used by gen_pred: a single panel for the predicted mask.
rows, columns = 1, 1

def gen_pred(img, model=seg_model):
    """Run the segmentation model on one uploaded image and plot the result.

    Args:
        img: PIL image from the Gradio ``Image(type='pil')`` input. It may be
            None if the user submits without uploading anything.
        model: Keras model used for prediction; defaults to ``seg_model``.

    Returns:
        A matplotlib Figure showing the predicted mask, for the "plot" output.

    Raises:
        gr.Error: if no image was provided.
    """
    # A bare Figure is not registered with pyplot's global state. It is never
    # leaked into pyplot's open-figure list across requests, and it is safe
    # to build from Gradio's worker threads.
    from matplotlib.figure import Figure

    if img is None:
        raise gr.Error("Please upload an image.")

    # PIL already yields RGB. The old RGB->BGR->RGB round trip was a no-op.
    rgb = np.asarray(img.convert('RGB'))
    rgb = rgb[::IMG_SCALING[0], ::IMG_SCALING[1]]
    # NOTE(review): the image is not resized to resized_shape. This relies on
    # the model accepting the upload's native size — confirm.
    batch = (rgb.astype(np.float32) / 255.0)[np.newaxis, ...]

    pred = model.predict(batch, verbose=0)
    pred = np.squeeze(pred, axis=0)
    # Drop a trailing singleton channel so imshow gets a 2-D array and
    # applies its colormap.
    if pred.ndim == 3 and pred.shape[-1] == 1:
        pred = pred[..., 0]

    fig = Figure(figsize=(3, 3))
    ax = fig.add_subplot(rows, columns, 1)
    ax.imshow(pred)
    ax.axis('off')
    return fig

title = "<h1 style='text-align: center;'>Semantic Segmentation (Airbus Ship Detection Challenge)</h1>"
description = "Upload an image and get prediction mask"

# Build the web UI around gen_pred. Request queuing is enabled via `.queue()`.
# The `enable_queue=True` constructor argument was deprecated in Gradio 3.x
# and removed in Gradio 4, where it raises TypeError.
demo = gr.Interface(
    fn=gen_pred,
    inputs=[gr.components.Image(type='pil')],
    outputs=["plot"],
    title=title,
    examples=[["00c3db267.jpg"], ["00dc34840.jpg"], ["00371aa92.jpg"]],
    description=description,
)
demo.queue().launch()