yuragoithf commited on
Commit
75d8751
·
verified ·
1 Parent(s): cda072c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -11
app.py CHANGED
@@ -7,7 +7,7 @@ import keras.backend as K
7
 
8
  from matplotlib import pyplot as plt
9
  from PIL import Image
10
- import keras
11
 
12
 
13
  resized_shape = (768, 768, 3)
@@ -21,14 +21,13 @@ IMG_SCALING = (1, 1)
21
  # gdown.download(url, output, quiet=False)
22
  # return output
23
 
24
- #model_file = "./seg_unet_model.h5"
25
- model_file = 'seg_unet_model.h5'
26
 
27
  #Custom objects for model
28
 
29
  def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
30
- targets = keras.ops.cast(K.flatten(y_true), dtype="float32")
31
- inputs = keras.ops.cast(K.flatten(y_pred), dtype="float32")
32
  intersection = K.sum(targets * inputs)
33
  dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
34
  inputs = K.clip(inputs, eps, 1.0 - eps)
@@ -38,20 +37,20 @@ def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
38
  return combo
39
 
40
  def dice_coef(y_true, y_pred, smooth=1):
41
- y_pred = keras.ops.cast(y_pred, dtype="int32")
42
- y_true = keras.ops.cast(y_true, tdtype="int32")
43
  intersection = K.sum(y_true * y_pred, axis=[1,2,3])
44
  union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
45
  return K.mean((2 * intersection + smooth) / (union + smooth), axis=0)
46
 
47
  def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
48
- pt_1 = keras.ops.where(keras.ops.equal(y_true, 1), y_pred, keras.ops.ones_like(y_pred))
49
- pt_0 = keras.ops.where(keras.ops.equal(y_true, 0), y_pred, keras.ops.ones_like(y_pred))
50
  focal_loss_fixed = -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1+K.epsilon())) - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
51
  return focal_loss_fixed
52
 
53
  # Load the model
54
- seg_model = keras.models.load_model(model_file, custom_objects={'Combo_loss': Combo_loss, 'focal_loss_fixed': focal_loss_fixed, 'dice_coef': dice_coef})
55
 
56
  # inputs = gr.inputs.Image(type="pil", label="Upload an image")
57
  # image_output = gr.outputs.Image(type="pil", label="Output Image")
@@ -68,7 +67,7 @@ def gen_pred(img, model=seg_model):
68
  img = img[::IMG_SCALING[0], ::IMG_SCALING[1]]
69
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
70
  img = img/255
71
- img = keras.ops.expand_dims(img, axis=0)
72
  pred = model.predict(img)
73
  pred = np.squeeze(pred, axis=0)
74
  fig = plt.figure(figsize=(3, 3))
 
7
 
8
  from matplotlib import pyplot as plt
9
  from PIL import Image
10
+ from tensorflow import keras
11
 
12
 
13
  resized_shape = (768, 768, 3)
 
21
  # gdown.download(url, output, quiet=False)
22
  # return output
23
 
24
+ model_file = "./seg_unet_model.h5"
 
25
 
26
  #Custom objects for model
27
 
28
  def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
29
+ targets = tf.dtypes.cast(K.flatten(y_true), tf.float32)
30
+ inputs = tf.dtypes.cast(K.flatten(y_pred), tf.float32)
31
  intersection = K.sum(targets * inputs)
32
  dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
33
  inputs = K.clip(inputs, eps, 1.0 - eps)
 
37
  return combo
38
 
39
  def dice_coef(y_true, y_pred, smooth=1):
40
+ y_pred = tf.dtypes.cast(y_pred, tf.int32)
41
+ y_true = tf.dtypes.cast(y_true, tf.int32)
42
  intersection = K.sum(y_true * y_pred, axis=[1,2,3])
43
  union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
44
  return K.mean((2 * intersection + smooth) / (union + smooth), axis=0)
45
 
46
  def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
47
+ pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
48
+ pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
49
  focal_loss_fixed = -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1+K.epsilon())) - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
50
  return focal_loss_fixed
51
 
52
  # Load the model
53
+ seg_model = keras.models.load_model('seg_unet_model.h5', custom_objects={'Combo_loss': Combo_loss, 'focal_loss_fixed': focal_loss_fixed, 'dice_coef': dice_coef})
54
 
55
  # inputs = gr.inputs.Image(type="pil", label="Upload an image")
56
  # image_output = gr.outputs.Image(type="pil", label="Output Image")
 
67
  img = img[::IMG_SCALING[0], ::IMG_SCALING[1]]
68
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
69
  img = img/255
70
+ img = tf.expand_dims(img, axis=0)
71
  pred = model.predict(img)
72
  pred = np.squeeze(pred, axis=0)
73
  fig = plt.figure(figsize=(3, 3))