|
|
|
import sys |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from transformers import PreTrainedModel |
|
|
|
from .ProbUNet_model import InjectionConvEncoder2D, InjectionUNet2D, InjectionConvEncoder3D, InjectionUNet3D, ProbabilisticSegmentationNet |
|
from .PULASkiConfigs import ProbUNetConfig |
|
|
|
class ProbUNet(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``ProbabilisticSegmentationNet``.

    The wrapped network is built from a ``ProbUNetConfig``. The config fields
    read here are ``dim``, ``latent_distribution``, ``in_channels``,
    ``out_channels``, ``num_feature_maps``, ``latent_size``, ``depth``,
    ``no_outact_op`` and ``prob_injection_at``.
    """

    config_class = ProbUNetConfig

    def __init__(self, config):
        """Build the probabilistic segmentation network described by *config*.

        Args:
            config: A ``ProbUNetConfig`` instance.

        Raises:
            ValueError: If ``config.dim`` is not 2 or 3, or if
                ``config.latent_distribution`` is not ``"normal"``.
        """
        super().__init__(config)

        # Select the 2D or 3D variants of the task network and of the
        # prior / posterior encoders based on ``config.dim``.
        if config.dim == 2:
            task_op = InjectionUNet2D
            prior_op = InjectionConvEncoder2D
            posterior_op = InjectionConvEncoder2D
        elif config.dim == 3:
            task_op = InjectionUNet3D
            prior_op = InjectionConvEncoder3D
            posterior_op = InjectionConvEncoder3D
        else:
            # Raise rather than sys.exit(): model construction must not kill
            # the host process; callers can catch and report a ValueError.
            raise ValueError(
                f"Invalid dim {config.dim!r}! Only configured for dim 2 and 3."
            )

        if config.latent_distribution == "normal":
            latent_distribution = torch.distributions.Normal
        else:
            raise ValueError(
                f"Invalid latent_distribution {config.latent_distribution!r}. "
                "Only normal has been implemented."
            )

        self.model = ProbabilisticSegmentationNet(
            in_channels=config.in_channels,
            out_channels=config.out_channels,
            num_feature_maps=config.num_feature_maps,
            latent_size=config.latent_size,
            depth=config.depth,
            latent_distribution=latent_distribution,
            task_op=task_op,
            task_kwargs={
                # nn.Identity (no output activation) when no_outact_op is
                # truthy, otherwise a sigmoid on the task network output.
                "output_activation_op": nn.Identity if config.no_outact_op else nn.Sigmoid,
                # presumably forwarded to the activation layers of task_op
                # — TODO confirm in ProbUNet_model.
                "activation_kwargs": {"inplace": True},
                "injection_at": config.prob_injection_at,
            },
            prior_op=prior_op,
            # Separate dict literals (not a shared object) so prior and
            # posterior kwargs can never alias each other.
            prior_kwargs={"activation_kwargs": {"inplace": True}, "norm_depth": 2},
            posterior_op=posterior_op,
            posterior_kwargs={"activation_kwargs": {"inplace": True}, "norm_depth": 2},
        )

    def forward(self, x):
        """Run the wrapped ``ProbabilisticSegmentationNet`` on *x*.

        Returns whatever ``self.model(x)`` returns, unchanged.
        """
        return self.model(x)