# fatchecker / predict_unet.py
# Page header captured with this file: commit 3368fe8 by bumble-bee,
# "Add unet versions(v0, v1, v2, v3)" (2.34 kB).
import os
import numpy as np
import skimage.io as skio
import skimage.transform as trans
from skimage.color import rgb2gray
from matplotlib import pyplot as plt
import sys
sys.path.append("/panfs/jay/groups/29/umii/mo000007/zooniverse/UNet")
from utils import *
from unet import unet
from unet_3plus import UNet_3Plus, UNet_3Plus_DeepSup, UNet_3Plus_DeepSup_CGM
# Directory holding the pretrained weight files ("<name>.hdf5") for every version.
# NOTE(review): this is a /home path while sys.path above points at a /panfs
# path for the same project; confirm both are reachable in deployment.
_MODEL_DIR = "/home/umii/mo000007/zooniverse/UNet/trained_models"

# Supported U-Net versions -> weight-file base name.
#   v0: UNet, v1: UNet 3+, v2: UNet 3+ with deep supervision,
#   v3: UNet 3+ with deep supervision and CGM.
_MODEL_NAMES = {
    "v0": "unetv0_sgd500_neptune",
    "v1": "unetv1_sgd500_neptune",
    "v2": "unetv2_sgd500_neptune",
    "v3": "unetv3_sgd500_neptune",
}

# Versions whose model.predict returns a list of outputs (deep supervision);
# only the first output is used for the mask.
_DEEP_SUPERVISION_VERSIONS = frozenset({"v2", "v3"})

# Model input geometry: single-channel 256x256 images.
_INPUT_HEIGHT, _INPUT_WIDTH = 256, 256

# Models already built and loaded, keyed by canonical version, so weights are
# read from disk once per process instead of on every prediction.
_MODEL_CACHE = {}


def _canonical_version(unet_type):
    """Map *unet_type* to a key of ``_MODEL_NAMES``; unknown values fall back to "v3"."""
    return unet_type if unet_type in _MODEL_NAMES else "v3"


def _load_model(version):
    """Return the network for canonical *version*, building it on first use."""
    model = _MODEL_CACHE.get(version)
    if model is not None:
        return model

    model_file = os.path.join(_MODEL_DIR, _MODEL_NAMES[version] + ".hdf5")
    input_shape = [_INPUT_HEIGHT, _INPUT_WIDTH, 1]
    output_channels = 1
    # presumably each constructor loads the weights from model_file — verify in unet/unet_3plus
    if version == "v0":
        model = unet(model_file)
    elif version == "v1":
        model = UNet_3Plus(input_shape, output_channels, model_file)
    elif version == "v2":
        model = UNet_3Plus_DeepSup(input_shape, output_channels, model_file)
    else:
        model = UNet_3Plus_DeepSup_CGM(input_shape, output_channels, model_file)

    _MODEL_CACHE[version] = model
    return model


def _split_rgb_and_gray(img):
    """Return ``(rgb, gray)`` for a grayscale (H, W), RGB (H, W, 3) or RGBA (H, W, 4) image.

    Grayscale input is replicated into three channels; an alpha channel is dropped.
    Raises ValueError for any other shape.
    """
    if img.ndim == 2:
        rgb = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[-1] in (3, 4):
        rgb = img[..., :3]
    else:
        raise ValueError(
            f"expected a grayscale, RGB or RGBA image, got shape {img.shape}"
        )
    return rgb, rgb2gray(rgb)


def predict_model(input, unet_type, *, threshold=0.5):
    """Segment *input* with the chosen U-Net and return it with the mask painted blue.

    Args:
        input: image array — grayscale (H, W), RGB (H, W, 3) or RGBA (H, W, 4);
            alpha is ignored. It is resized to 256x256 before prediction.
        unet_type: "v0" (UNet), "v1" (UNet 3+), "v2" (UNet 3+ with deep
            supervision) or "v3" (UNet 3+ with CGM). Any other value uses "v3".
        threshold: predicted value at or above which a pixel is foreground.

    Returns:
        The resized image as a float RGB array of shape (256, 256, 3), with
        foreground pixels set to blue (0.0, 0.0, 1.0).

    Raises:
        ValueError: if *input* is not a 2-D, RGB or RGBA image.
    """
    version = _canonical_version(unet_type)

    # skimage's resize takes output_shape as (rows, cols) == (height, width).
    resized = trans.resize(input, (_INPUT_HEIGHT, _INPUT_WIDTH))
    rgb, gray = _split_rgb_and_gray(resized)
    result_img = rgb.copy()
    batch = gray.reshape(1, _INPUT_HEIGHT, _INPUT_WIDTH, 1)

    model = _load_model(version)
    results = model.predict(batch)
    # Deep-supervised models return several outputs; keep only the first one.
    if version in _DEEP_SUPERVISION_VERSIONS:
        results = results[0]
    prediction = np.asarray(results)[0, :, :, 0]

    # resize() returns floats (rescaled to [0, 1] for integer inputs), so full
    # intensity is 1.0; writing 255 would put pixels far outside the image range.
    result_img[prediction >= threshold] = (0.0, 0.0, 1.0)
    return result_img