Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import keras.backend as K
|
|
7 |
|
8 |
from matplotlib import pyplot as plt
|
9 |
from PIL import Image
|
10 |
-
import keras
|
11 |
|
12 |
|
13 |
resized_shape = (768, 768, 3)
|
@@ -21,14 +21,13 @@ IMG_SCALING = (1, 1)
|
|
21 |
# gdown.download(url, output, quiet=False)
|
22 |
# return output
|
23 |
|
24 |
-
|
25 |
-
model_file = 'seg_unet_model.h5'
|
26 |
|
27 |
#Custom objects for model
|
28 |
|
29 |
def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
|
30 |
-
targets =
|
31 |
-
inputs =
|
32 |
intersection = K.sum(targets * inputs)
|
33 |
dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
|
34 |
inputs = K.clip(inputs, eps, 1.0 - eps)
|
@@ -38,20 +37,20 @@ def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
|
|
38 |
return combo
|
39 |
|
40 |
def dice_coef(y_true, y_pred, smooth=1):
|
41 |
-
y_pred =
|
42 |
-
y_true =
|
43 |
intersection = K.sum(y_true * y_pred, axis=[1,2,3])
|
44 |
union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
|
45 |
return K.mean((2 * intersection + smooth) / (union + smooth), axis=0)
|
46 |
|
47 |
def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
|
48 |
-
pt_1 =
|
49 |
-
pt_0 =
|
50 |
focal_loss_fixed = -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1+K.epsilon())) - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
|
51 |
return focal_loss_fixed
|
52 |
|
53 |
# Load the model
|
54 |
-
seg_model = keras.models.load_model(
|
55 |
|
56 |
# inputs = gr.inputs.Image(type="pil", label="Upload an image")
|
57 |
# image_output = gr.outputs.Image(type="pil", label="Output Image")
|
@@ -68,7 +67,7 @@ def gen_pred(img, model=seg_model):
|
|
68 |
img = img[::IMG_SCALING[0], ::IMG_SCALING[1]]
|
69 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
70 |
img = img/255
|
71 |
-
img =
|
72 |
pred = model.predict(img)
|
73 |
pred = np.squeeze(pred, axis=0)
|
74 |
fig = plt.figure(figsize=(3, 3))
|
|
|
7 |
|
8 |
from matplotlib import pyplot as plt
|
9 |
from PIL import Image
|
10 |
+
from tensorflow import keras
|
11 |
|
12 |
|
13 |
resized_shape = (768, 768, 3)
|
|
|
21 |
# gdown.download(url, output, quiet=False)
|
22 |
# return output
|
23 |
|
24 |
+
model_file = "./seg_unet_model.h5"
|
|
|
25 |
|
26 |
#Custom objects for model
|
27 |
|
28 |
def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
|
29 |
+
targets = tf.dtypes.cast(K.flatten(y_true), tf.float32)
|
30 |
+
inputs = tf.dtypes.cast(K.flatten(y_pred), tf.float32)
|
31 |
intersection = K.sum(targets * inputs)
|
32 |
dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
|
33 |
inputs = K.clip(inputs, eps, 1.0 - eps)
|
|
|
37 |
return combo
|
38 |
|
39 |
def dice_coef(y_true, y_pred, smooth=1):
|
40 |
+
y_pred = tf.dtypes.cast(y_pred, tf.int32)
|
41 |
+
y_true = tf.dtypes.cast(y_true, tf.int32)
|
42 |
intersection = K.sum(y_true * y_pred, axis=[1,2,3])
|
43 |
union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
|
44 |
return K.mean((2 * intersection + smooth) / (union + smooth), axis=0)
|
45 |
|
46 |
def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
|
47 |
+
pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
|
48 |
+
pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
|
49 |
focal_loss_fixed = -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1+K.epsilon())) - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
|
50 |
return focal_loss_fixed
|
51 |
|
52 |
# Load the model
|
53 |
+
seg_model = keras.models.load_model('seg_unet_model.h5', custom_objects={'Combo_loss': Combo_loss, 'focal_loss_fixed': focal_loss_fixed, 'dice_coef': dice_coef})
|
54 |
|
55 |
# inputs = gr.inputs.Image(type="pil", label="Upload an image")
|
56 |
# image_output = gr.outputs.Image(type="pil", label="Output Image")
|
|
|
67 |
img = img[::IMG_SCALING[0], ::IMG_SCALING[1]]
|
68 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
69 |
img = img/255
|
70 |
+
img = tf.expand_dims(img, axis=0)
|
71 |
pred = model.predict(img)
|
72 |
pred = np.squeeze(pred, axis=0)
|
73 |
fig = plt.figure(figsize=(3, 3))
|