from pathlib import Path
from datetime import datetime

import torch
import torch.nn as nn
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import numpy as np
import torchio as tio

from medical_diffusion.data.datamodules import SimpleDataModule
from medical_diffusion.data.datasets import AIROGSDataset, MSIvsMSS_2_Dataset, CheXpert_2_Dataset
from medical_diffusion.models.pipelines import DiffusionPipeline
from medical_diffusion.models.estimators import UNet
from medical_diffusion.external.stable_diffusion.unet_openai import UNetModel
from medical_diffusion.models.noise_schedulers import GaussianNoiseScheduler
from medical_diffusion.models.embedders import LabelEmbedder, TimeEmbbeding
from medical_diffusion.models.embedders.latent_embedders import VAE, VAEGAN, VQVAE, VQGAN

import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
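# NOTE: the 'file_system' sharing strategy avoids "Too many open files" errors
# when many DataLoader workers exchange tensors via shared memory.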

if __name__ == "__main__":
    # ------------ Load Data ----------------
    # ds = AIROGSDataset(
    #     crawler_ext='jpg',
    #     augment_horizontal_flip=False,
    #     augment_vertical_flip=False,
    #     # path_root='/home/gustav/Documents/datasets/AIROGS/data_256x256/',
    #     path_root='/mnt/hdd/datasets/eye/AIROGS/data_256x256',
    # )

    # ds = MSIvsMSS_2_Dataset(
    #     crawler_ext='jpg',
    #     image_resize=None,
    #     image_crop=None,
    #     augment_horizontal_flip=False,
    #     augment_vertical_flip=False,
    #     # path_root='/home/gustav/Documents/datasets/Kather_2/train',
    #     path_root='/mnt/hdd/datasets/pathology/kather_msi_mss_2/train/',
    # )

    ds = CheXpert_2_Dataset(  # 256x256
        augment_horizontal_flip=False,
        augment_vertical_flip=False,
        path_root='/mnt/hdd/datasets/chest/CheXpert/ChecXpert-v10/preprocessed_tianyu'
    )

    dm = SimpleDataModule(
        ds_train=ds,
        batch_size=32,
        # num_workers=0,
        pin_memory=True,
        # weights=ds.get_weights()
    )
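    # Only a training split is provided; validation is disabled further below
    # (limit_val_batches=0), so "train/loss" is the only quantity monitored.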

    current_time = datetime.now().strftime("%Y_%m_%d_%H%M%S")
    path_run_dir = Path.cwd() / 'runs' / str(current_time)
    path_run_dir.mkdir(parents=True, exist_ok=True)
    accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'

    # ------------ Initialize Model ------------
    # cond_embedder = None
    cond_embedder = LabelEmbedder
    cond_embedder_kwargs = {
        'emb_dim': 1024,
        'num_classes': 2
    }

    time_embedder = TimeEmbbeding
    time_embedder_kwargs = {
        'emb_dim': 1024  # stable diffusion uses 4*model_channels (model_channels is about 256)
    }

    noise_estimator = UNet
    noise_estimator_kwargs = {
        'in_ch': 8,
        'out_ch': 8,
        'spatial_dims': 2,
        'hid_chs': [256, 256, 512, 1024],
        'kernel_sizes': [3, 3, 3, 3],
        'strides': [1, 2, 2, 2],
        'time_embedder': time_embedder,
        'time_embedder_kwargs': time_embedder_kwargs,
        'cond_embedder': cond_embedder,
        'cond_embedder_kwargs': cond_embedder_kwargs,
        'deep_supervision': False,
        'use_res_block': True,
        'use_attention': 'none',
    }
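    # The UNet operates in latent space: in_ch/out_ch = 8 are assumed to match the
    # number of latent channels produced by the pretrained latent embedder loaded
    # below (its configuration is not shown in this script), and num_classes=2
    # presumably matches the two-class labels of CheXpert_2_Dataset.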

    # ------------ Initialize Noise ------------
    noise_scheduler = GaussianNoiseScheduler
    noise_scheduler_kwargs = {
        'timesteps': 1000,
        'beta_start': 0.002,  # 0.0001, 0.0015
        'beta_end': 0.02,     # 0.01, 0.0195
        'schedule_strategy': 'scaled_linear'
    }
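    # 'scaled_linear' presumably mirrors the Stable Diffusion schedule, where
    # sqrt(beta) is spaced linearly between sqrt(beta_start) and sqrt(beta_end)
    # rather than beta itself.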

    # ------------ Initialize Latent Space ------------
    # latent_embedder = None
    # latent_embedder = VQVAE
    latent_embedder = VAE
    latent_embedder_checkpoint = 'runs/2022_12_12_133315_chest_vaegan/last_vae.ckpt'
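    # The checkpoint path points into a VAEGAN run; loading it with the plain VAE
    # class assumes that run saved the underlying VAE weights separately, as the
    # 'last_vae.ckpt' filename suggests.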

    # ------------ Initialize Pipeline ------------
    pipeline = DiffusionPipeline(
        noise_estimator=noise_estimator,
        noise_estimator_kwargs=noise_estimator_kwargs,
        noise_scheduler=noise_scheduler,
        noise_scheduler_kwargs=noise_scheduler_kwargs,
        latent_embedder=latent_embedder,
        latent_embedder_checkpoint=latent_embedder_checkpoint,
        estimator_objective='x_T',
        estimate_variance=False,
        use_self_conditioning=False,
        use_ema=False,
        classifier_free_guidance_dropout=0.5,  # Disable during training by setting to 0
        do_input_centering=False,
        clip_x0=False,
        sample_every_n_steps=1000
    )
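    # estimator_objective='x_T' presumably means the UNet predicts the added noise
    # (epsilon) rather than the denoised image x_0. With
    # classifier_free_guidance_dropout=0.5, the class condition is dropped for half
    # of the training samples so classifier-free guidance can be applied at
    # sampling time.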

    # pipeline_old = pipeline.load_from_checkpoint('runs/2022_11_27_085654_chest_diffusion/last.ckpt')
    # pipeline.noise_estimator.load_state_dict(pipeline_old.noise_estimator.state_dict(), strict=True)

    # -------------- Training Initialization ---------------
    to_monitor = "train/loss"  # "pl/val_loss"
    min_max = "min"
    save_and_sample_every = 100

    early_stopping = EarlyStopping(
        monitor=to_monitor,
        min_delta=0.0,  # minimum change in the monitored quantity to qualify as an improvement
        patience=30,    # number of checks with no improvement
        mode=min_max
    )
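    # NOTE: early_stopping is constructed but not passed to the Trainer below (the
    # corresponding callbacks line is commented out); with validation disabled it
    # could only monitor the training loss anyway.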

    checkpointing = ModelCheckpoint(
        dirpath=str(path_run_dir),  # dirpath
        monitor=to_monitor,
        every_n_train_steps=save_and_sample_every,
        save_last=True,
        save_top_k=2,
        mode=min_max,
    )

    trainer = Trainer(
        accelerator=accelerator,
        # devices=[0],
        # precision=16,
        # amp_backend='apex',
        # amp_level='O2',
        # gradient_clip_val=0.5,
        default_root_dir=str(path_run_dir),
        callbacks=[checkpointing],
        # callbacks=[checkpointing, early_stopping],
        enable_checkpointing=True,
        check_val_every_n_epoch=1,
        log_every_n_steps=save_and_sample_every,
        auto_lr_find=False,
        # limit_train_batches=1000,
        limit_val_batches=0,  # 0 = disable validation - Note: Early Stopping no longer available
        min_epochs=100,
        max_epochs=1001,
        num_sanity_val_steps=2,
    )

    # ---------------- Execute Training ----------------
    trainer.fit(pipeline, datamodule=dm)

    # ------------- Save path to best model -------------
    pipeline.save_best_checkpoint(trainer.logger.log_dir, checkpointing.best_model_path)
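
    # ------------- (Sketch) Sampling after training -------------
    # A minimal, untested sketch of how images might be generated from the trained
    # pipeline. The `sample` method name, its arguments, and the latent size
    # (8 channels at the VAE's downsampled resolution) are assumptions, not taken
    # from this script; verify against the repository's sampling code:
    #
    #   pipeline = DiffusionPipeline.load_from_checkpoint(checkpointing.best_model_path)
    #   condition = torch.tensor([1] * 4)                  # hypothetical class labels
    #   latents = pipeline.sample(4, (8, 32, 32), condition=condition, guidance_scale=1.0)
    #   images = pipeline.latent_embedder.decode(latents)  # decode latents back to image space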