from typing import Dict

import numpy as np
from einops import rearrange
from monai.transforms.transform import Transform


class OrientationGuidanceMultipleLabelDeepEditd(Transform):
    """Re-express guidance points in the frame of the current reference image.

    For every key in ``label_names`` that holds a non-empty list of points in
    the data dictionary, the points are mapped through
    ``inv(affine) @ original_affine``, both taken from the ``meta`` of the
    reference image. The original author describes this as converting the
    guidance to the RAS orientation.
    """

    def __init__(self, ref_image="image", label_names=None):
        """
        Args:
            ref_image: key of the image whose ``meta`` provides ``"affine"``
                and ``"original_affine"``.
            label_names: mapping whose keys name the guidance entries in the
                data dictionary. ``None`` means no entry is transformed.
        """
        self.ref_image = ref_image
        self.label_names = label_names

    @staticmethod
    def _to_numpy(matrix):
        """Return ``matrix`` as a NumPy array, accepting tensors or arrays."""
        # Tensor-like objects expose ``detach``; move them to host memory
        # first so the conversion also works for non-CPU tensors.
        if hasattr(matrix, "detach"):
            matrix = matrix.detach().cpu().numpy()
        return np.asarray(matrix)

    def transform_points(self, point, affine):
        """Apply a homogeneous affine to a batch of 3D points.

        Args:
            point: numpy array of shape ``[bs, N, 3]``.
            affine: matrix applied on the left of the homogeneous column
                vectors (4x4 for 3D points).

        Returns:
            numpy array of shape ``[bs, N, 3]`` with the transformed points.
        """
        bs, n = point.shape[:2]
        # Append a homogeneous coordinate of 1 to every point.
        point = np.concatenate((point, np.ones((bs, n, 1))), axis=-1)
        # Flatten the batch so all points go through a single matmul.
        point = rearrange(point, "b n d -> d (b n)")
        point = affine @ point
        # Restore the batch layout and drop the homogeneous coordinate.
        point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3]
        return point

    def __call__(self, data):
        """Transform the guidance points stored under each label key.

        Args:
            data: mapping containing the reference image and, optionally, one
                list of points per label key.

        Returns:
            A shallow copy of ``data`` in which every non-empty guidance entry
            is replaced by a numpy array of transformed points ``[N, 3]``.
        """
        d: Dict = dict(data)
        # ``label_names`` defaults to None; treat that as "nothing to do".
        label_names = self.label_names if self.label_names is not None else {}

        # Built lazily: the reference image is only accessed when at least
        # one key actually has points, and the inverse is computed only once.
        transform = None
        for key_label in label_names:
            points = d.get(key_label)
            # ``len`` rather than truthiness: points may be a numpy array.
            if points is None or len(points) < 1:
                continue
            if transform is None:
                meta = d[self.ref_image].meta
                # Convert both matrices to NumPy so the product does not
                # depend on mixed tensor/ndarray operator dispatch.
                transform = np.linalg.inv(self._to_numpy(meta["affine"])) @ self._to_numpy(
                    meta["original_affine"]
                )
            # Add a leading batch axis for transform_points, then remove it.
            reoriented_points = self.transform_points(np.array(points)[None], transform)
            d[key_label] = reoriented_points[0]
        return d