|
import gradio as gr |
|
import tensorflow as tf |
|
import gdown |
|
from PIL import Image |
|
|
|
import os |
|
import cv2 |
|
import numpy as np |
|
import keras.backend as K |
|
|
|
|
|
|
|
|
|
# Image shapes, presumably (height, width, channels).
# NOTE(review): neither shape is referenced by the code visible in this
# file -- confirm which one the model expects before relying on them.
input_shape = (32, 32, 3)
resized_shape = (768, 768, 3)

# (row step, column step) applied when an image is subsampled by strided
# slicing; (1, 1) keeps every pixel.
IMG_SCALING = (1, 1)
|
|
|
|
|
|
|
def download_model():
    """Make sure the segmentation model weights exist locally.

    The weights are fetched from Google Drive with ``gdown`` only when the
    file is not already present in the working directory, so restarting the
    app does not download it again. Delete the local file to force a fresh
    download.

    Returns:
        The local file name of the model weights.

    Raises:
        RuntimeError: If the file is still missing after the download
            attempt.
    """
    url = "https://drive.google.com/uc?id=1FhICkeGn6GcNXWTDn1s83ctC-6Mo1UXk"
    output = "seg_unet_model.h5"

    if not os.path.isfile(output):
        gdown.download(url, output, quiet=False)
        # Fail here with a clear message instead of letting the later model
        # load trip over a missing file.
        if not os.path.isfile(output):
            raise RuntimeError(
                f"Failed to download model weights to {output!r}"
            )
    return output
|
|
|
# Fetch the weights at import time so the file is on disk before the model
# is loaded further down.
model_file = download_model()
|
|
|
|
|
|
|
def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1, alpha=0.5, ce_ratio=0.5):
    """Combo loss: a weighted cross-entropy term minus a Dice term.

    Both tensors are flattened and cast to float32 before any arithmetic.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor; values are clipped to
            ``[eps, 1 - eps]`` before taking logarithms.
        eps: Clipping margin that keeps ``log`` away from 0.
        smooth: Additive smoothing applied to the Dice numerator and
            denominator.
        alpha: Weighting inside the cross-entropy term.
        ce_ratio: Share of the cross-entropy term in the result; the Dice
            term is weighted by ``1 - ce_ratio``.

    Returns:
        A scalar tensor, ``ce_ratio * weighted_ce - (1 - ce_ratio) * dice``.

    NOTE(review): the original body read module-level names ``ALPHA`` and
    ``CE_RATIO`` that are not defined anywhere in the visible file. They are
    now keyword parameters; 0.5 is the commonly published default for both
    -- TODO confirm against the configuration used for training.
    """
    flat_true = tf.dtypes.cast(K.flatten(y_true), tf.float32)
    flat_pred = tf.dtypes.cast(K.flatten(y_pred), tf.float32)

    # Dice is computed on the unclipped predictions.
    intersection = K.sum(flat_true * flat_pred)
    dice = (2. * intersection + smooth) / (
        K.sum(flat_true) + K.sum(flat_pred) + smooth
    )

    clipped = K.clip(flat_pred, eps, 1.0 - eps)
    # NOTE(review): alpha multiplies the whole bracket, so the negative-class
    # term ends up weighted by alpha * (1 - alpha). Kept exactly as in the
    # original expression so the loss value does not change.
    out = -(
        alpha * (
            (flat_true * K.log(clipped))
            + ((1 - alpha) * (1.0 - flat_true) * K.log(1.0 - clipped))
        )
    )
    weighted_ce = K.mean(out, axis=-1)

    return (ce_ratio * weighted_ce) - ((1 - ce_ratio) * dice)
|
|
|
def dice_coef(y_true, y_pred, smooth=1):
    """Batch-mean Dice coefficient computed on integer-cast tensors.

    Args:
        y_true: Ground-truth mask tensor; must be 4-D because the sums run
            over axes 1, 2 and 3.
        y_pred: Predicted mask tensor of the same rank.
        smooth: Additive smoothing for numerator and denominator.

    Returns:
        The per-sample Dice values averaged over axis 0.
    """
    # NOTE(review): casting to int32 truncates fractional values, so any
    # prediction below 1 is counted as 0 -- confirm this is intended rather
    # than rounding or thresholding.
    y_pred = tf.dtypes.cast(y_pred, tf.int32)
    y_true = tf.dtypes.cast(y_true, tf.int32)

    intersection = K.sum(y_true * y_pred, axis=[1, 2, 3])
    # Sum of both mask totals (the Dice denominator), not a set union.
    total = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred, axis=[1, 2, 3])

    return K.mean((2 * intersection + smooth) / (total + smooth), axis=0)
|
|
|
|
|
# Load the trained model from the downloaded weights file. The custom loss
# and metric are passed in so Keras can resolve the names 'Combo_loss' and
# 'dice_coef' when it rebuilds the model.
seg_model = tf.keras.models.load_model(
    model_file,
    custom_objects={'Combo_loss': Combo_loss, 'dice_coef': dice_coef},
)
|
|
|
|
|
|
|
|
|
|
|
# Image upload widget; type="pil" asks Gradio to pass the upload to the
# prediction function as a PIL image.
inputs = gr.inputs.Image(type="pil", label="Upload an image")
|
|
|
|
|
def gen_pred(img=inputs, model=seg_model):
    """Run the segmentation model on one image and return its output mask.

    Args:
        img: The uploaded image, as a PIL image or a NumPy array.
        model: Object exposing a Keras-style ``predict`` method; defaults to
            the model loaded at import time.

    Returns:
        The model output for the image with the batch axis removed. A
        single trailing channel axis is dropped as well, so a one-channel
        mask comes back as a 2-D array.
    """
    # The upload widget is configured with type="pil", but the strided
    # slicing and the OpenCV call below need an array. Converting to RGB
    # first also normalises RGBA / greyscale uploads to three channels.
    if isinstance(img, Image.Image):
        img = img.convert("RGB")
    img = np.asarray(img)

    img = img[::IMG_SCALING[0], ::IMG_SCALING[1]]
    # NOTE(review): an array built from a PIL image is already in RGB order,
    # so this swap hands the model BGR data. Kept from the original --
    # confirm the channel order used during training before removing it.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = img / 255
    img = tf.expand_dims(img, axis=0)  # add the batch axis

    pred = model.predict(img)
    pred = np.squeeze(pred, axis=0)  # remove the batch axis again
    # An (H, W, 1) array cannot be rendered as an image by PIL-backed
    # outputs; collapse a single trailing channel to a plain 2-D mask.
    if pred.ndim == 3 and pred.shape[-1] == 1:
        pred = pred[..., 0]

    # Only the mask is returned: the interface declares one image output,
    # so returning the input component alongside it would not match.
    return pred
|
|
|
|
|
# Text shown on the Gradio page. The title is raw HTML so the heading can
# be centred.
title = "<h1 style='text-align: center;'>Semantic Segmentation</h1>"
description = "Upload an image and get prediction mask"
|
|
|
|
|
# Build the web UI around gen_pred and start serving it.
# NOTE(review): the example entries are bare file names, so the images
# presumably have to sit next to this script -- confirm they are shipped.
gr.Interface(
    fn=gen_pred,
    inputs=inputs,
    outputs="image",
    title=title,
    examples=[
        ["003e2c95d.jpg"],
        ["003b50a15.jpg"],
        ["003b48a9e.jpg"],
        ["0038cbe45.jpg"],
        ["00371aa92.jpg"],
    ],
    description=description,
    enable_queue=True,
).launch()