|
|
|
|
|
|
|
|
|
import os |
|
join = os.path.join |
|
import argparse |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import tifffile as tif |
|
import monai |
|
from tqdm import tqdm |
|
from utils.postprocess import mask_overlay |
|
from monai.transforms import Activations, AddChanneld, AsChannelFirstd, AsDiscrete, Compose, EnsureTyped, EnsureType |
|
from models.unicell_modules import MiT_B2_UNet_MultiHead, MiT_B3_UNet_MultiHead |
|
import matplotlib.pyplot as plt |
|
from skimage import io, exposure, segmentation, morphology |
|
from utils.postprocess import watershed_post |
|
from utils.multi_task_sliding_window_inference import multi_task_sliding_window_inference |
|
import gradio as gr |
|
|
|
def normalize_channel(img, lower=0.1, upper=99.9):
    """Contrast-stretch one image channel to the uint8 range using percentiles.

    The ``lower``/``upper`` percentiles are computed over the non-zero pixels
    only, so zero-valued background/padding does not bias the stretch.

    Args:
        img: numpy array holding a single image channel.
        lower: lower percentile (0-100) used as the bottom of the input range.
        upper: upper percentile (0-100) used as the top of the input range.

    Returns:
        ``img`` rescaled by ``exposure.rescale_intensity`` to the uint8 output
        range, or ``img`` unchanged when it has no non-zero pixels or when the
        percentile range is too narrow (<= 0.001) to stretch meaningfully.
    """
    non_zero_vals = img[np.nonzero(img)]
    # np.percentile fails on an empty array; an all-zero channel has nothing
    # to stretch, so hand it back unchanged just like the low-range case.
    if non_zero_vals.size == 0:
        return img

    p_low, p_high = np.percentile(non_zero_vals, [lower, upper])
    if p_high - p_low > 0.001:
        return exposure.rescale_intensity(img, in_range=(p_low, p_high), out_range='uint8')
    return img
|
|
|
def preprocess(img_data):
    """Convert an image into a 3-channel uint8 array, stretching each channel.

    Grayscale inputs (2-D, or 3-D with a single trailing channel) are
    replicated to three channels. Inputs with more than three channels keep
    the first three. Inputs with fewer than three channels are zero-padded.
    Each channel that has any non-zero pixel is contrast-stretched with
    ``normalize_channel`` using the 1st/99th percentiles; all-zero channels
    stay zero.

    Args:
        img_data: image array shaped (H, W) or (H, W, C).

    Returns:
        uint8 array shaped (H, W, 3).

    Raises:
        ValueError: if ``img_data`` is neither 2-D nor 3-D.
    """
    img_data = np.asarray(img_data)
    if img_data.ndim == 3 and img_data.shape[-1] == 1:
        # A single trailing channel axis is just a grayscale image.
        img_data = img_data[:, :, 0]

    if img_data.ndim == 2:
        img_data = np.repeat(np.expand_dims(img_data, axis=-1), 3, axis=-1)
    elif img_data.ndim == 3:
        if img_data.shape[-1] > 3:
            img_data = img_data[:, :, :3]
        elif img_data.shape[-1] < 3:
            # Pad missing channels with zeros so channels 0..2 always exist.
            missing = 3 - img_data.shape[-1]
            pad = np.zeros(img_data.shape[:2] + (missing,), dtype=img_data.dtype)
            img_data = np.concatenate([img_data, pad], axis=-1)
    else:
        raise ValueError(f"expected a 2-D or 3-D image, got shape {img_data.shape}")

    pre_img_data = np.zeros(img_data.shape, dtype=np.uint8)
    for i in range(3):
        img_channel_i = img_data[:, :, i]
        # normalize_channel needs at least one non-zero pixel to compute percentiles.
        if np.any(img_channel_i):
            pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=1, upper=99)
    return pre_img_data
|
|
|
|
|
# Loaded models keyed by device string, so './model.pth' is read once per
# device instead of on every request.
_MODEL_CACHE = {}


def _get_model(device):
    """Return the eval-mode MiT-B2 multi-head UNet with './model.pth' weights, cached per device."""
    key = str(device)
    model = _MODEL_CACHE.get(key)
    if model is None:
        model = MiT_B2_UNet_MultiHead(in_channels=3, out_channels=3, regress_class=1, img_size=256).to(device)
        # NOTE(review): torch.load unpickles the file; only ever point it at a trusted checkpoint.
        checkpoint = torch.load('./model.pth', map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        _MODEL_CACHE[key] = model
    return model


def inference(pre_img_data):
    """Segment a preprocessed image and draw instance boundaries on a copy of it.

    Args:
        pre_img_data: (H, W, 3) image array, as built by ``predict``. It is not
            modified.

    Returns:
        ``(test_pred_mask, overlay)``:
        - ``test_pred_mask`` is the uint16 instance mask produced by
          ``watershed_post``, squeezed to drop singleton axes.
        - ``overlay`` is a copy of ``pre_img_data`` with dilated (disk(1))
          inner instance boundaries painted green (0, 255, 0).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _get_model(device)

    # Scale to [0, 1] by the image maximum; guard the all-zero image, which
    # would otherwise produce 0/0 = NaN input to the network.
    max_val = np.max(pre_img_data)
    scale = max_val if max_val > 0 else 1
    test_npy = pre_img_data / scale

    with torch.no_grad():
        # HWC -> NCHW float32 batch of one.
        test_tensor = torch.from_numpy(np.expand_dims(test_npy, 0)).permute(0, 3, 1, 2).float().to(device)
        val_pred, val_pred_dist = multi_task_sliding_window_inference(
            inputs=test_tensor, roi_size=(256, 256), sw_batch_size=8, predictor=model)

    # Channel 1 of the first head is combined with the second head's output by
    # watershed_post (presumably foreground probability + distance map — confirm).
    val_seg_inst = watershed_post(val_pred_dist.squeeze(1).cpu().numpy(), val_pred.squeeze(1).cpu().numpy()[:, 1])
    test_pred_mask = val_seg_inst.squeeze().astype(np.uint16)

    boundary = segmentation.find_boundaries(test_pred_mask, connectivity=1, mode='inner')
    boundary = morphology.binary_dilation(boundary, morphology.disk(1))

    # Draw on a copy so the caller's array is left untouched.
    overlay = pre_img_data.copy()
    overlay[boundary] = (0, 255, 0)

    return test_pred_mask, overlay
|
|
|
|
|
def predict(img):
    """Gradio callback: read an uploaded image, segment it, and save the labels.

    Args:
        img: the upload delivered by ``gr.File`` — either a path string /
            path-like, or a file object exposing its path as ``.name``.

    Returns:
        ``(overlay, path)``:
        - ``overlay`` is the boundary-overlay image returned by ``inference``.
        - ``path`` is a zlib-compressed TIFF of the uint16 instance labels,
          written to a fresh temporary directory for this request.
    """
    import tempfile

    if isinstance(img, (str, os.PathLike)):
        img_name = os.fspath(img)
    else:
        img_name = img.name
    print('##########', img_name)

    if img_name.lower().endswith(('.tif', '.tiff')):
        img_data = tif.imread(img_name)
    else:
        img_data = io.imread(img_name)

    # A single trailing channel axis is treated as a grayscale image.
    if img_data.ndim == 3 and img_data.shape[-1] == 1:
        img_data = img_data[:, :, 0]

    if img_data.ndim == 2:
        if np.any(img_data):
            gray = normalize_channel(img_data, lower=0.1, upper=99.9)
        else:
            # Percentiles cannot be computed for an all-zero image.
            gray = np.zeros(img_data.shape, dtype=np.uint8)
        pre_img_data = np.repeat(np.expand_dims(gray, -1), repeats=3, axis=-1)
    else:
        pre_img_data = np.zeros((img_data.shape[0], img_data.shape[1], 3), dtype=np.uint8)
        # Use at most the first three channels; missing or all-zero ones stay zero.
        for i in range(min(img_data.shape[-1], 3)):
            img_channel_i = img_data[:, :, i]
            if np.any(img_channel_i):
                pre_img_data[:, :, i] = normalize_channel(img_channel_i, lower=0.1, upper=99.9)

    seg_labels, seg_overlay = inference(pre_img_data)

    # A per-request directory keeps concurrent users of the shared demo from
    # overwriting each other's result, while keeping the download filename.
    out_path = join(tempfile.mkdtemp(), 'segmentation.tiff')
    tif.imwrite(out_path, seg_labels, compression='zlib')

    return seg_overlay, out_path
|
|
|
# Gradio front end: one uploaded image file in; the segmentation overlay and a
# downloadable label file out. Components are built in the same order as before.
_input_widget = gr.File(label="Input image (png, bmp, jpg, tif, tiff)")
_output_widgets = [
    gr.Image(label="Segmentation overlay"),
    gr.File(label="Download segmentation"),
]

unicell_api = gr.Interface(
    fn=predict,
    inputs=_input_widget,
    outputs=_output_widgets,
    title="UniCell Online Demo",
    examples=['demo.png', 'demo.tif'],
)

# Start the web app immediately (this module is run as the demo script).
unicell_api.launch(share=True)
|
|
|
|