{
    "imports": [
        "$import torch",
        "$from datetime import datetime",
        "$from pathlib import Path"
    ],
    "bundle_root": ".",
    "dataset_dir": "",
    "dataset": "",
    "evaluator": "",
    "inferer": "",
    "load_old": 1,
    "model_dir": "$@bundle_root + '/models'",
    "output_dir": "$@bundle_root + '/output'",
    "create_output_dir": "$Path(@output_dir).mkdir(parents=True, exist_ok=True)",
    "gender": 0.0,
    "age": 0.1,
    "ventricular_vol": 0.2,
    "brain_vol": 0.4,
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "conditioning": "$torch.tensor([[@gender, @age, @ventricular_vol, @brain_vol]]).to(@device).unsqueeze(1)",
    "out_file": "$datetime.now().strftime('sample_%H%M%S_%d%m%Y') + '_' + str(@gender) + '_' + str(@age) + '_' + str(@ventricular_vol) + '_' + str(@brain_vol)",
    "autoencoder_def": {
        "_target_": "monai.networks.nets.AutoencoderKL",
        "spatial_dims": 3,
        "in_channels": 1,
        "out_channels": 1,
        "latent_channels": 3,
        "channels": [
            64,
            128,
            128,
            128
        ],
        "num_res_blocks": 2,
        "norm_num_groups": 32,
        "norm_eps": 1e-06,
        "attention_levels": [
            false,
            false,
            false,
            false
        ],
        "with_encoder_nonlocal_attn": false,
        "with_decoder_nonlocal_attn": false
    },
    "network_def": "@autoencoder_def",
    "load_autoencoder_path": "$@model_dir + '/autoencoder.pt'",
    "load_autoencoder_func": "$@autoencoder_def.load_old_state_dict if bool(@load_old) else @autoencoder_def.load_state_dict",
    "load_autoencoder": "$@load_autoencoder_func(torch.load(@load_autoencoder_path, map_location=@device))",
    "autoencoder": "$@autoencoder_def.to(@device)",
    "diffusion_def": {
        "_target_": "monai.networks.nets.DiffusionModelUNet",
        "spatial_dims": 3,
        "in_channels": 7,
        "out_channels": 3,
        "channels": [
            256,
            512,
            768
        ],
        "num_res_blocks": 2,
        "attention_levels": [
            false,
            true,
            true
        ],
        "norm_num_groups": 32,
        "norm_eps": 1e-06,
        "resblock_updown": true,
        "num_head_channels": [
            0,
            512,
            768
        ],
        "with_conditioning": true,
        "transformer_num_layers": 1,
        "cross_attention_dim": 4,
        "upcast_attention": true,
        "use_flash_attention": false
    },
    "load_diffusion_path": "$@model_dir + '/model.pt'",
    "load_diffusion_func": "$@diffusion_def.load_old_state_dict if bool(@load_old) else @diffusion_def.load_state_dict",
    "load_diffusion": "$@load_diffusion_func(torch.load(@load_diffusion_path, map_location=@device))",
    "diffusion": "$@diffusion_def.to(@device)",
    "scheduler": {
        "_target_": "monai.networks.schedulers.DDIMScheduler",
        "_requires_": [
            "@load_diffusion",
            "@load_autoencoder"
        ],
        "beta_start": 0.0015,
        "beta_end": 0.0205,
        "num_train_timesteps": 1000,
        "schedule": "scaled_linear_beta",
        "clip_sample": false
    },
    "noise": "$torch.randn((1, 3, 20, 28, 20)).to(@device)",
    "set_timesteps": "$@scheduler.set_timesteps(num_inference_steps=50)",
    "sampler": {
        "_target_": "scripts.sampler.Sampler",
        "_requires_": "@set_timesteps"
    },
    "sample": "$@sampler.sampling_fn(@noise, @autoencoder, @diffusion, @scheduler, @conditioning)",
    "saver": {
        "_target_": "SaveImage",
        "_requires_": "@create_output_dir",
        "output_dir": "@output_dir",
        "output_postfix": "@out_file"
    },
    "run": "$@saver(@sample[0][0])"
}