|
import numpy as np |
|
import torch |
|
from medpy import metric |
|
from scipy.ndimage import zoom |
|
import torch.nn as nn |
|
import SimpleITK as sitk |
|
|
|
|
|
class DiceLoss(nn.Module):
    """Soft Dice loss averaged over ``n_classes`` channels.

    The target is given as a map of integer class indices and is expanded to
    a one-hot tensor internally; the prediction must carry its class axis at
    dimension 1.
    """

    def __init__(self, n_classes):
        """Store the number of classes the loss iterates over."""
        super().__init__()
        self.n_classes = n_classes

    def _one_hot_encoder(self, input_tensor):
        """Expand a class-index tensor into a float one-hot tensor.

        A new axis of length ``n_classes`` is inserted at dimension 1;
        channel ``i`` is 1.0 where ``input_tensor == i`` and 0.0 elsewhere.
        """
        class_masks = [
            (input_tensor == class_idx).unsqueeze(1)
            for class_idx in range(self.n_classes)
        ]
        return torch.cat(class_masks, dim=1).float()

    def _dice_loss(self, score, target):
        """Return ``1 - dice`` for a single class channel.

        The denominator uses squared sums of ``score`` and ``target``; the
        ``smooth`` term keeps the ratio defined when both are all zero.
        """
        target = target.float()
        smooth = 1e-5
        intersect = torch.sum(score * target)
        y_sum = torch.sum(target * target)
        z_sum = torch.sum(score * score)
        dice = (2 * intersect + smooth) / (z_sum + y_sum + smooth)
        return 1 - dice

    def forward(self, inputs, target, weight=None, softmax=False):
        """Compute the weighted mean Dice loss over all classes.

        Args:
            inputs: Prediction tensor with the class axis at dimension 1.
            target: Tensor of class indices; after one-hot expansion its
                shape must equal ``inputs.size()``.
            weight: Optional per-class weights indexable by class index.
                Defaults to 1 for every class.
            softmax: When true, apply ``torch.softmax`` over dimension 1 of
                ``inputs`` before computing the loss.

        Returns:
            The sum of weighted per-class Dice losses divided by
            ``n_classes``.

        Raises:
            AssertionError: If ``inputs`` and the one-hot target differ in
                shape.
        """
        if softmax:
            inputs = torch.softmax(inputs, dim=1)
        target = self._one_hot_encoder(target)
        if weight is None:
            weight = [1] * self.n_classes
        assert inputs.size() == target.size(), 'predict {} & target {} shape do not match'.format(inputs.size(), target.size())

        # Accumulate on tensors only: no .item() call here, so evaluating the
        # loss does not force a device-to-host synchronisation per class.
        loss = 0.0
        for class_idx in range(self.n_classes):
            class_loss = self._dice_loss(inputs[:, class_idx], target[:, class_idx])
            loss += class_loss * weight[class_idx]
        return loss / self.n_classes
|
|
|
|
|
def calculate_metric_percase(pred, gt):
    """Compute Dice and HD95 for one prediction / ground-truth pair.

    Both inputs are binarised (every value greater than 0 becomes 1) on
    private copies, so the caller's arrays are left unmodified.

    Args:
        pred: Array-like predicted mask.
        gt: Array-like ground-truth mask.

    Returns:
        A ``(dice, hd95)`` tuple:
        * both masks non-empty: values from ``metric.binary.dc`` and
          ``metric.binary.hd95``;
        * prediction non-empty and ground truth sums to 0: ``(1, 0)``;
        * otherwise: ``(0, 0)``.
    """
    # Work on copies: the original assigned into the arguments in place,
    # silently rewriting arrays owned by the caller.
    pred = np.array(pred)
    gt = np.array(gt)
    pred[pred > 0] = 1
    gt[gt > 0] = 1

    pred_sum = pred.sum()
    gt_sum = gt.sum()

    if pred_sum > 0 and gt_sum > 0:
        dice = metric.binary.dc(pred, gt)
        hd95 = metric.binary.hd95(pred, gt)
        return dice, hd95
    if pred_sum > 0 and gt_sum == 0:
        # NOTE(review): a non-empty prediction against an empty ground truth
        # scores Dice 1 here, which rewards false positives. Kept as-is so
        # previously reported numbers stay comparable -- confirm intent.
        return 1, 0
    return 0, 0
|
|
|
|
|
def test_single_volume(image, label, net, classes, patch_size=(256, 256), test_save_path=None, case=None, z_spacing=1):
    """Run ``net`` over one case and score the prediction per class.

    A 3-D image is processed slice by slice along its first axis: each slice
    is resampled to ``patch_size`` when its shape differs, passed through the
    network, and the arg-max label map is resampled back (nearest neighbour)
    to the slice's original shape. Any other rank is passed to the network in
    a single forward call.

    Args:
        image: Torch tensor; dimension 0 is squeezed away before processing.
        label: Torch tensor of labels, squeezed the same way as ``image``.
        net: Model exposing ``eval()`` and callable on a float tensor with
            two leading singleton dimensions added; its output is soft-maxed
            and arg-maxed over dimension 1.
        classes: Number of classes; metrics are computed for labels
            ``1 .. classes - 1`` (label 0 is skipped, presumably background).
        patch_size: Two-element sequence giving the network's input height
            and width.
        test_save_path: Directory in which to write the prediction, image
            and label as ``.nii.gz`` files; nothing is written when ``None``.
        case: Case name used as the file-name prefix when saving.
        z_spacing: Third component of the spacing set on saved volumes.

    Returns:
        A list with one ``calculate_metric_percase`` result per evaluated
        class.
    """
    image = image.squeeze(0).cpu().detach().numpy()
    label = label.squeeze(0).cpu().detach().numpy()

    # Feed inputs on whatever device the model lives on rather than assuming
    # CUDA, so CPU-only hosts and non-default GPUs work as well.
    try:
        device = next(net.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Switching to eval mode once is enough; it does not change per slice.
    net.eval()

    if len(image.shape) == 3:
        prediction = np.zeros_like(label)
        for ind, slice_2d in enumerate(image):
            x, y = slice_2d.shape[0], slice_2d.shape[1]
            needs_resize = x != patch_size[0] or y != patch_size[1]
            if needs_resize:
                # Cubic interpolation for intensities on the way in.
                slice_2d = zoom(slice_2d, (patch_size[0] / x, patch_size[1] / y), order=3)
            net_input = torch.from_numpy(slice_2d).unsqueeze(0).unsqueeze(0).float().to(device)
            with torch.no_grad():
                outputs = net(net_input)
                out = torch.argmax(torch.softmax(outputs, dim=1), dim=1).squeeze(0)
            out = out.cpu().detach().numpy()
            if needs_resize:
                # Nearest neighbour on the way out so labels stay integral.
                pred = zoom(out, (x / patch_size[0], y / patch_size[1]), order=0)
            else:
                pred = out
            prediction[ind] = pred
    else:
        net_input = torch.from_numpy(image).unsqueeze(0).unsqueeze(0).float().to(device)
        with torch.no_grad():
            out = torch.argmax(torch.softmax(net(net_input), dim=1), dim=1).squeeze(0)
        prediction = out.cpu().detach().numpy()

    metric_list = [
        calculate_metric_percase(prediction == class_idx, label == class_idx)
        for class_idx in range(1, classes)
    ]

    if test_save_path is not None:
        spacing = (1, 1, z_spacing)
        # Same files and write order as before: prediction, image, label.
        outputs_to_save = (
            (prediction, "_pred.nii.gz"),
            (image, "_img.nii.gz"),
            (label, "_gt.nii.gz"),
        )
        for array, suffix in outputs_to_save:
            itk_image = sitk.GetImageFromArray(array.astype(np.float32))
            itk_image.SetSpacing(spacing)
            sitk.WriteImage(itk_image, test_save_path + '/' + case + suffix)
    return metric_list