File size: 1,299 Bytes
e452c35
a6546e9
 
8d0de00
a6546e9
 
 
 
 
 
 
8d0de00
a6546e9
 
8d0de00
 
 
 
 
a6546e9
 
d93c84c
 
8d0de00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import tensorflow as tf
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from huggingface_hub import from_pretrained_keras
import os
import sys

# Download the pretrained segmentation network from the Hugging Face Hub.
# compile=False skips restoring any training configuration; this file only
# ever calls the model for prediction (training=False in `infer`).
print('Loading model...')
model = from_pretrained_keras(
    "mostafapasha/ribs-segmentation-model",
    compile=False,
)
print('Successfully loaded model...')

# Sample chest X-ray files offered as clickable examples in the UI.
examples = [
    'examples/VinDr_RibCXR_train_008.png',
    'examples/VinDr_RibCXR_train_013.png',
]


def infer(img, threshold):
    """Run the rib-segmentation model on one image and return a binary mask.

    Args:
        img: Image as a numpy array (the Gradio input is declared with
            ``type="numpy"``). Accepted layouts are 2-D ``(H, W)`` or 3-D
            ``(H, W, C)``. For multi-channel input only channel 1 is used,
            presumably because the X-rays are grayscale stored as RGB so every
            channel carries the same data (TODO confirm). A single-channel
            ``(H, W, 1)`` array uses its only channel.
        threshold: Probability cut-off. A pixel is marked foreground when its
            sigmoid probability is strictly greater than this value.

    Returns:
        A 2-D float32 numpy array of 0.0 / 1.0 values. It is the model output
        with its batch and channel axes removed.

    Raises:
        ValueError: If ``img`` is None or is not a 2-D or 3-D array.
    """
    if img is None:
        # Presumably what Gradio passes when no image was uploaded; fail with
        # a readable message instead of a TypeError from indexing None.
        raise ValueError("No input image provided.")
    img = np.asarray(img)
    if img.ndim == 3:
        # An (H, W, 1) array has no channel 1, so fall back to channel 0 there.
        img = img[:, :, 0] if img.shape[2] == 1 else img[:, :, 1]
    elif img.ndim != 2:
        raise ValueError(
            f"Expected a 2-D or 3-D image array, got shape {img.shape}."
        )
    # Add batch and channel axes: (H, W) -> (1, H, W, 1).
    batch = img.reshape(1, img.shape[0], img.shape[1], 1)
    logits = model(batch, training=False)
    prob = tf.sigmoid(logits)
    pred = tf.cast(prob > threshold, dtype=tf.float32)
    # Drop the batch and channel axes to return a plain 2-D mask.
    return pred.numpy()[0, :, :, 0]

# Input components, in the order their values are passed to the handler:
# the X-ray image as a numpy array (shape=(512, 512) is forwarded to the
# component; presumably Gradio resizes to it, so confirm for the pinned
# Gradio version), then the probability threshold slider.
gr_input = [
    gr.inputs.Image(label="Image", type="numpy", shape=(512, 512)),
    gr.inputs.Slider(
        minimum=0,
        maximum=1,
        step=0.05,
        default=0.5,
        label="Segmentation Threshold",
    ),
]

# A single output component that displays the segmentation mask.
gr_output = [
    gr.outputs.Image(type="pil", label="Segmentation Mask"),
]

# Build the web UI around `infer`. Construct the Interface separately from
# launching it, so that `iface` refers to the Interface object itself rather
# than to whatever `launch()` returns.
iface = gr.Interface(
    fn=infer,
    title='ribs segmentation model',
    description='Keras implementation of ResUNET++ for xray ribs segmentation',
    inputs=gr_input,
    outputs=gr_output,
    examples=examples,
    flagging_dir="flagged",
)
# NOTE(review): `cache_examples` is passed to `launch()` here. Newer Gradio
# releases may only accept it on the `Interface` constructor, so check this
# against the pinned Gradio version before upgrading.
iface.launch(cache_examples=True)