# Hugging Face file-viewer residue (not part of the program): uploaded by
# yuragoithf, commit d905150 ("Upload 7 files"), file size 2.65 kB.
import gradio as gr
import tensorflow as tf
import gdown
from PIL import Image
import os
import cv2
import numpy as np
import keras.backend as K
#from tensorflow import keras
# NOTE(review): input_shape and resized_shape are not referenced anywhere in
# the visible code; presumably (height, width, channels) — confirm before use.
input_shape = (32, 32, 3)
resized_shape = (768, 768, 3)
# (row_step, col_step) strides gen_pred uses to subsample the uploaded image;
# (1, 1) keeps every pixel.
IMG_SCALING = (1, 1)
# Download the model file
def download_model(
    url="https://drive.google.com/uc?id=1FhICkeGn6GcNXWTDn1s83ctC-6Mo1UXk",
    output="seg_unet_model.h5",
):
    """Download the segmentation model weights and return their local path.

    Args:
        url: Location of the weights file, passed to ``gdown.download``.
        output: Local filename the download is written to.

    Returns:
        ``output``, the path of the downloaded file.

    Raises:
        RuntimeError: If gdown reports a failed download by returning None.
    """
    saved_path = gdown.download(url, output, quiet=False)
    # Fail here with a clear message rather than later inside load_model
    # with a confusing "file not found".
    if saved_path is None:
        raise RuntimeError(f"Could not download model weights from {url}")
    return output
# Runs at import time: fetches the weights and keeps the local path that
# download_model() returns.
model_file = download_model()
#Custom objects for model
def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1, alpha=0.5, ce_ratio=0.5):
    """Combo loss: a weighted cross-entropy term minus a Dice term.

    Both tensors are flattened and cast to float32 before any arithmetic.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor; clipped to ``[eps, 1 - eps]`` before
            the logarithms are taken.
        eps: Clipping margin that keeps ``log`` away from 0.
        smooth: Constant added to the Dice numerator and denominator.
        alpha: Weight used inside the cross-entropy term.
        ce_ratio: Share of the cross-entropy term in the result; the Dice
            term is weighted by ``1 - ce_ratio``.

    Returns:
        Scalar tensor ``ce_ratio * weighted_ce - (1 - ce_ratio) * dice``.
    """
    # NOTE(review): the original body read module globals ALPHA and CE_RATIO
    # that are defined nowhere in this file, so calling it raised NameError.
    # The 0.5 defaults are placeholders — confirm against the training config.
    targets = tf.dtypes.cast(K.flatten(y_true), tf.float32)
    inputs = tf.dtypes.cast(K.flatten(y_pred), tf.float32)

    # Dice is computed on the unclipped predictions.
    intersection = K.sum(targets * inputs)
    dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)

    inputs = K.clip(inputs, eps, 1.0 - eps)
    out = - (alpha * ((targets * K.log(inputs)) + ((1 - alpha) * (1.0 - targets) * K.log(1.0 - inputs))))
    weighted_ce = K.mean(out, axis=-1)
    combo = (ce_ratio * weighted_ce) - ((1 - ce_ratio) * dice)
    return combo
def dice_coef(y_true, y_pred, smooth=1):
    """Return the smoothed Dice coefficient averaged over axis 0.

    Both tensors are cast to int32, reduced over axes 1, 2 and 3 (so rank-4
    inputs are required), and the per-sample ratios are averaged over axis 0.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor.
        smooth: Constant added to numerator and denominator.

    Returns:
        Scalar tensor holding the mean Dice score.
    """
    reduce_axes = [1, 2, 3]
    # NOTE(review): the int32 cast drops the fractional part of y_pred, so
    # soft predictions below 1 presumably count as 0 — confirm this is wanted.
    truth = tf.dtypes.cast(y_true, tf.int32)
    guess = tf.dtypes.cast(y_pred, tf.int32)
    overlap = K.sum(truth * guess, axis=reduce_axes)
    total = K.sum(truth, axis=reduce_axes) + K.sum(guess, axis=reduce_axes)
    per_sample = (2 * overlap + smooth) / (total + smooth)
    return K.mean(per_sample, axis=0)
# Load the model
# Deserialize the U-Net from the downloaded file. The custom loss and metric
# are passed by name so Keras can resolve them while loading. (The original
# call was missing its closing parenthesis.)
seg_model = tf.keras.models.load_model(
    model_file,
    custom_objects={'Combo_loss': Combo_loss, 'dice_coef': dice_coef},
)
# Input widget; type="pil" means the prediction function receives a PIL image.
# NOTE(review): gr.inputs is Gradio's legacy namespace — confirm it exists in
# the Gradio version this app is pinned to.
inputs = gr.inputs.Image(type="pil", label="Upload an image")
# outputs = gr.outputs.HTML() #uncomment for single class output
def gen_pred(img=inputs, model=seg_model):
    """Run the segmentation model on one uploaded image and return its mask.

    Args:
        img: The uploaded picture. The Gradio input is declared with
            ``type="pil"``, so a PIL image is expected; any object
            ``np.asarray`` can convert is accepted as well.
        model: Object exposing ``predict(batch)``; defaults to the loaded
            segmentation model.

    Returns:
        The prediction for the single image with the batch axis removed. A
        trailing channel axis of size 1 is dropped so the result is a 2-D
        mask that an image output can render.
    """
    #rgb_path = os.path.join(test_image_dir,img)
    #img = cv2.imread(rgb_path)
    # A PIL image cannot be sliced like an array; convert it first. PIL already
    # yields RGB, so the BGR->RGB swap that the cv2.imread path needed is
    # dropped. NOTE(review): assumes the model was trained on RGB input.
    if hasattr(img, "convert"):
        img = img.convert("RGB")
    pixels = np.asarray(img)
    pixels = pixels[::IMG_SCALING[0], ::IMG_SCALING[1]]
    pixels = pixels / 255
    batch = np.expand_dims(pixels, axis=0)
    pred = model.predict(batch)
    pred = np.squeeze(pred, axis=0)
    # (H, W, 1) arrays cannot be turned into an image; hand back (H, W).
    if pred.ndim == 3 and pred.shape[-1] == 1:
        pred = pred[..., 0]
    # The interface declares a single image output, so return only the mask
    # (the original also returned the Gradio input component).
    return pred
# Page header and blurb shown above the upload widget.
title = "<h1 style='text-align: center;'>Semantic Segmentation</h1>"
description = "Upload an image and get prediction mask"
# css_code='body{background-image:url("file=wave.mp4");}'

# Build the UI first, then start serving it. Each example row holds one
# image filename.
demo = gr.Interface(
    fn=gen_pred,
    inputs=inputs,
    outputs="image",
    title=title,
    examples=[
        [name]
        for name in (
            "003e2c95d.jpg",
            "003b50a15.jpg",
            "003b48a9e.jpg",
            "0038cbe45.jpg",
            "00371aa92.jpg",
        )
    ],
    # css=css_code,
    description=description,
    enable_queue=True,
)
demo.launch()