File size: 3,572 Bytes
8c16ebc
 
 
d905150
 
4e2e389
d905150
 
 
8c16ebc
c0ca22c
d905150
 
 
 
 
 
58b426c
 
 
d905150
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b40ecd
d905150
cd397c1
14f2a96
d905150
 
7e182e3
 
 
 
 
 
 
 
 
 
d905150
51439bf
 
14f2a96
d4d25d3
 
d905150
 
 
 
 
 
d921f3c
c0ca22c
d4d25d3
00c13b0
7e182e3
14f2a96
7e182e3
d905150
 
 
 
 
 
 
14f2a96
d905150
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import cv2
import gdown
import gradio as gr
import tensorflow as tf
import urllib.request
import numpy as np
import keras.backend as K

from PIL import Image
from matplotlib import cm

#from tensorflow import keras

# Nominal (height, width, channels) of images fed to the model.
# NOTE(review): presumably matches the training input size — confirm.
resized_shape = (768, 768, 3)
# Row/column subsampling strides applied to each image before inference
# (``img[::rows, ::cols]``); (1, 1) keeps full resolution.
IMG_SCALING = (1, 1)

# Download the model file
def download_model(
    url="https://drive.google.com/uc?id=1FhICkeGn6GcNXWTDn1s83ctC-6Mo1UXk",
    output="seg_unet_model.h5",
    overwrite=False,
):
    """Make sure the trained segmentation model is available locally.

    Args:
        url: Google Drive download URL passed to ``gdown.download``.
        output: Local path the model file is written to.
        overwrite: When False (the default), an existing file at ``output`` is
            reused instead of being downloaded again on every start.

    Returns:
        str: The local path of the model file (``output``).

    Raises:
        RuntimeError: If ``gdown.download`` signals failure by returning None.
    """
    if not overwrite and os.path.exists(output):
        return output
    downloaded = gdown.download(url, output, quiet=False)
    # Some gdown versions print an error and return None instead of raising;
    # fail here rather than later with a confusing "file not found" on load.
    if downloaded is None:
        raise RuntimeError(f"Failed to download model from {url} to {output}")
    return output

# Runs at import time so the model file is on disk before the model is loaded below.
model_file = download_model()

# Custom objects needed to deserialize the saved model.

def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1, alpha=0.5, ce_ratio=0.5):
    """Combo loss: weighted binary cross-entropy blended with the Dice score.

    The original referenced undefined globals ``ALPHA``/``CE_RATIO`` (NameError
    on any call). They are now keyword parameters. The 0.5 defaults follow the
    common reference implementation; NOTE(review): confirm against training.

    Args:
        y_true: Ground-truth mask tensor; flattened before use.
        y_pred: Predicted probabilities with the same number of elements.
        eps: Clip bound that keeps ``log`` finite.
        smooth: Additive smoothing for the Dice term.
        alpha: Weight on the positive-class CE term; ``1 - alpha`` weights the
            negative-class term, so values > 0.5 penalise false negatives more.
        ce_ratio: Weight of the weighted CE; ``1 - ce_ratio`` multiplies the
            Dice score, which is subtracted (higher Dice lowers the loss).

    Returns:
        Scalar tensor.
    """
    targets = tf.dtypes.cast(K.flatten(y_true), tf.float32)
    probs = tf.dtypes.cast(K.flatten(y_pred), tf.float32)

    intersection = K.sum(targets * probs)
    dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(probs) + smooth)

    probs = K.clip(probs, eps, 1.0 - eps)
    # alpha weights only the positive term and (1 - alpha) only the negative
    # term; the original wrapped both terms in alpha.
    out = -((alpha * targets * K.log(probs))
            + ((1 - alpha) * (1.0 - targets) * K.log(1.0 - probs)))
    weighted_ce = K.mean(out, axis=-1)
    return (ce_ratio * weighted_ce) - ((1 - ce_ratio) * dice)

def dice_coef(y_true, y_pred, smooth=1, threshold=0.5):
    """Batch-mean hard Dice coefficient.

    Predictions are binarised at ``threshold`` before comparison. The original
    cast probabilities to int32, which truncates everything below 1.0 to 0.
    Sums run over axes 1-3 and the mean over axis 0, so inputs must be at
    least 4-D with the batch on axis 0.

    Args:
        y_true: Ground-truth mask; assumed binary (0/1) — TODO confirm.
        y_pred: Predicted probabilities.
        smooth: Additive smoothing that avoids 0/0 on empty masks.
        threshold: Probability above which a pixel counts as positive.

    Returns:
        Scalar tensor: mean Dice over the batch.
    """
    y_pred = tf.dtypes.cast(y_pred > threshold, tf.float32)
    y_true = tf.dtypes.cast(y_true, tf.float32)
    intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
    union = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])
    return K.mean((2. * intersection + smooth) / (union + smooth), axis=0)

# Load the trained model from the path download_model() returned. The
# custom_objects mapping lets Keras resolve these loss/metric names if they
# appear in the saved configuration.
seg_model = tf.keras.models.load_model(
    model_file,
    custom_objects={'Combo_loss': Combo_loss, 'dice_coef': dice_coef},
)

# Gradio I/O components. The upload widget delivers a PIL image to the handler.
inputs = gr.inputs.Image(
    label="Upload an image",
    source="upload",
    type="pil",
)
image_output = gr.outputs.Image(label="Output Image", type="numpy")

def fig2img(fig):
    """Render a Matplotlib figure into memory and reopen it as a PIL Image.

    The figure is saved into an in-memory BytesIO stream (no file on disk),
    which is rewound and handed to ``Image.open``.
    """
    from io import BytesIO

    stream = BytesIO()
    fig.savefig(stream)
    stream.seek(0)
    return Image.open(stream)


def _to_model_batch(img):
    """Convert an uploaded image into a normalised float32 batch of one."""
    if img is None:
        raise ValueError("No image was uploaded.")
    if isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    # Uploads may be RGBA, grayscale or palette images; the model takes RGB.
    img = img.convert("RGB")
    height, width = resized_shape[:2]
    if img.size != (width, height):  # PIL sizes are (width, height)
        img = img.resize((width, height))
    arr = np.asarray(img)
    arr = arr[::IMG_SCALING[0], ::IMG_SCALING[1]]
    arr = arr.astype(np.float32) / 255.0
    return arr[np.newaxis, ...]


def _colorize_mask(pred):
    """Map a 2-D prediction (or one with a trailing channel of 1) to an RGBA image."""
    # cm.gist_earth on an (H, W, 1) array yields (H, W, 1, 4), which
    # Image.fromarray rejects, so drop the singleton channel first.
    # NOTE(review): assumes a single-channel model output — confirm.
    if pred.ndim == 3 and pred.shape[-1] == 1:
        pred = pred[..., 0]
    return Image.fromarray(np.uint8(cm.gist_earth(pred) * 255))


def gen_pred(img=inputs, model=seg_model):
    """Segment the uploaded image and return the predicted mask as a colour image.

    The original ignored ``img`` and always read ``./003e2c95d.jpg``. The
    uploaded image is now what gets segmented.

    Args:
        img: Image from the Gradio input (a PIL image, since the input uses
            ``type="pil"``) or an RGB numpy array. It is converted to RGB,
            resized to ``resized_shape`` (H, W), subsampled by
            ``IMG_SCALING`` and scaled to [0, 1].
        model: Keras model used for inference; defaults to ``seg_model``.

    Returns:
        PIL.Image.Image: The prediction mapped through matplotlib's
        ``gist_earth`` colormap (RGBA).

    Raises:
        ValueError: If no image was supplied.
    """
    batch = _to_model_batch(img)
    pred = model.predict(batch)
    pred = np.squeeze(pred, axis=0)  # drop the batch dimension
    return _colorize_mask(pred)

title = "<h1 style='text-align: center;'>Semantic Segmentation</h1>"
description = "Upload an image and get prediction mask"

# Build the app. gen_pred receives the uploaded PIL image; its colour-mapped
# mask is shown in the labelled image_output component, which was previously
# defined but never used.
demo = gr.Interface(
    fn=gen_pred,
    inputs=inputs,
    outputs=image_output,
    title=title,
    description=description,
    examples=[["003e2c95d.jpg"], ["003b50a15.jpg"], ["003b48a9e.jpg"], ["0038cbe45.jpg"], ["00371aa92.jpg"]],
    enable_queue=True,
)
demo.launch()