|
from typing import Dict |
|
|
|
import numpy as np |
|
from einops import rearrange |
|
from monai.transforms.transform import Transform |
|
|
|
|
|
class OrientationGuidanceMultipleLabelDeepEditd(Transform):
    """Re-express per-label guidance points in the voxel grid of a reoriented image.

    For every label key in ``label_names``, the guidance points stored under that
    key in the data dictionary are mapped through
    ``inv(meta["affine"]) @ meta["original_affine"]`` of ``data[ref_image]``.
    Assuming both affines map voxel indices to world coordinates (the MONAI
    convention — confirm for your reader), this converts points given in the
    original image's voxel space into the voxel space of the current, e.g.
    RAS-reoriented, image.
    """

    def __init__(self, ref_image="image", label_names=None):
        """
        Args:
            ref_image: key of the reference image in the data dictionary. Its
                ``meta["affine"]`` and ``meta["original_affine"]`` define the
                transform applied to the guidance points.
            label_names: mapping (or any iterable) whose keys are the
                data-dictionary keys holding guidance points. ``None`` or empty
                means there is nothing to transform.
        """
        self.ref_image = ref_image
        self.label_names = label_names

    @staticmethod
    def _to_numpy(array):
        """Return ``array`` as a NumPy array, detaching torch-like tensors first."""
        if hasattr(array, "detach"):
            # Tensor affines may live on an accelerator or carry autograd state.
            return array.detach().cpu().numpy()
        return np.asarray(array)

    def transform_points(self, point, affine):
        """Apply a homogeneous affine matrix to a batch of points.

        Args:
            point: array-like of shape ``[..., N, D]`` (typically ``[bs, N, 3]``).
            affine: ``(D + 1) x (D + 1)`` homogeneous matrix (array or tensor).

        Returns:
            NumPy array with the same shape as ``point`` holding the transformed
            coordinates.
        """
        point = np.asarray(point)
        ones = np.ones(point.shape[:-1] + (1,))
        homogeneous = np.concatenate((point, ones), axis=-1)
        # Row-vector form of ``affine @ p`` for every point; matmul broadcasts
        # over all leading (batch / point) dimensions.
        transformed = homogeneous @ self._to_numpy(affine).T
        # Drop the homogeneous coordinate; correct for 2-D as well as 3-D affines.
        return transformed[..., :-1]

    def __call__(self, data):
        """Return a shallow copy of ``data`` with every non-empty guidance entry reoriented.

        Entries that are missing or empty are left as they are. Transformed
        entries are replaced by NumPy arrays of shape ``[N, D]``.
        """
        d: Dict = dict(data)
        # The composite matrix is loop-invariant; compute it lazily so the
        # reference image is only accessed when some label actually has points.
        voxel_to_voxel = None
        for key_label in self.label_names or ():
            points = d.get(key_label, [])
            if len(points) < 1:
                continue
            if voxel_to_voxel is None:
                meta = d[self.ref_image].meta
                # original voxel -> world (original_affine), world -> current voxel (inv(affine)).
                voxel_to_voxel = np.linalg.inv(self._to_numpy(meta["affine"])) @ self._to_numpy(
                    meta["original_affine"]
                )
            d[key_label] = self.transform_points(np.asarray(points)[None], voxel_to_voxel)[0]
        return d
|
|