English · medical · brain-data · mri

jesseab committed
Commit dff7579 · 1 Parent(s): 8b07494

Files changed (1)
  1. brlp_lite.py +29 -6
brlp_lite.py CHANGED
@@ -32,10 +32,12 @@
 # }
 
 import os
+os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"
 from typing import Optional, Union
 import pandas as pd
 import argparse
 import numpy as np
+
 import warnings
 import torch
 import torch.nn as nn
@@ -43,7 +45,7 @@ from torch import Tensor
 from torch.optim.optimizer import Optimizer
 from torch.nn import L1Loss
 from torch.utils.data import DataLoader
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from torch.amp import GradScaler
 
 from generative.networks.nets import (
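
The hunk above swaps the deprecated torch.cuda.amp.autocast import for torch.amp.autocast, matching the torch.amp.GradScaler import already present. Below is a minimal, self-contained sketch of the newer torch.amp API (PyTorch 2.4+); the model, optimizer, and tensors are placeholders, not objects from brlp_lite.py, and it assumes a CUDA device is available.

# Sketch only: the torch.amp replacements for the removed torch.cuda.amp import.
# Placeholder model/optimizer; assumes a CUDA device.
import torch
from torch.amp import autocast, GradScaler

model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.Adam(model.parameters())
scaler = GradScaler("cuda")                 # device type is now an explicit argument

x = torch.randn(4, 8, device="cuda")
with autocast("cuda", enabled=True):        # was: torch.cuda.amp.autocast(enabled=True)
    loss = model(x).pow(2).mean()
scaler.scale(loss).backward()               # scale the loss before backward
scaler.step(optimizer)                      # unscales gradients, then steps
scaler.update()
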
@@ -57,8 +59,12 @@ from monai import transforms
 from monai.utils import set_determinism
 from monai.data.meta_tensor import MetaTensor
 import torch.serialization
-
+from numpy.core.multiarray import _reconstruct
+from numpy import ndarray, dtype
+torch.serialization.add_safe_globals([_reconstruct])
 torch.serialization.add_safe_globals([MetaTensor])
+torch.serialization.add_safe_globals([ndarray])
+torch.serialization.add_safe_globals([dtype])
 
 from tqdm import tqdm
 import matplotlib.pyplot as plt
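
The extra add_safe_globals calls above address the stricter torch.load default on recent PyTorch: since 2.6, torch.load runs with weights_only=True unless told otherwise, and unpickling objects such as MONAI's MetaTensor or NumPy's array-reconstruction helpers fails unless those globals are allowlisted. A small sketch of the same idea, with the allowlist collapsed into one call; the checkpoint path is a placeholder:

# Sketch: allowlisting pickled classes for weights_only loading (add_safe_globals is a
# PyTorch 2.4+ API; weights_only=True is the default since 2.6).
# "some_checkpoint.pt" is a placeholder path.
import torch
import torch.serialization
from monai.data.meta_tensor import MetaTensor
from numpy.core.multiarray import _reconstruct
from numpy import ndarray, dtype

# add_safe_globals accumulates, so one call with a list is equivalent to several calls.
torch.serialization.add_safe_globals([MetaTensor, _reconstruct, ndarray, dtype])

try:
    item = torch.load("some_checkpoint.pt")   # weights_only=True by default on PyTorch >= 2.6
except Exception as err:
    # Without the allowlist this typically raises an UnpicklingError naming the blocked global.
    print("load failed:", err)
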
@@ -381,6 +387,22 @@ def train(
     train_df = dataset_df[dataset_df.split == 'train']
     trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
 
+    print(f"[DEBUG] Using cache_dir={cache_dir}")
+    print(f"[DEBUG] trainset length={len(trainset)}")
+
+    try:
+        sample_debug = trainset[0]  # Force a transform on the first record
+        print("[DEBUG] Successfully loaded sample 0 from trainset.")
+    except Exception as e:
+        print("[DEBUG] Error loading sample 0:", e)
+
+    import glob
+
+    hashfiles = glob.glob(os.path.join(cache_dir, "*.pt"))
+    print(f"[DEBUG] Found {len(hashfiles)} cached .pt files in {cache_dir}")
+    if hashfiles:
+        print("[DEBUG] Example cache file:", hashfiles[0])
+
     train_loader = DataLoader(
         dataset=trainset,
         num_workers=num_workers,
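
The [DEBUG] block added above probes the on-disk cache that get_dataset_from_pd fills under cache_dir; the glob for *.pt files suggests a MONAI PersistentDataset-style cache of hashed, torch-saved items. A hedged sketch of the same kind of inspection, standalone rather than inside train(); the cache_dir value is a placeholder and the cache layout is an assumption:

# Sketch: inspecting a directory of hashed .pt cache files (PersistentDataset-style layout
# is an assumption here). cache_dir is a placeholder path.
import glob
import os
import torch

cache_dir = "/tmp/brlp_cache"
hashfiles = sorted(glob.glob(os.path.join(cache_dir, "*.pt")))
print(f"{len(hashfiles)} cached items under {cache_dir}")

if hashfiles:
    # weights_only=False trusts a cache this process wrote itself; with the default
    # weights_only=True, the safe globals registered at import time must cover every
    # pickled class inside the cached item.
    item = torch.load(hashfiles[0], weights_only=False)
    print(type(item), list(item.keys()) if isinstance(item, dict) else None)
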
@@ -440,13 +462,14 @@ def train(
     total_counter = 0
 
     for epoch in range(n_epochs):
+        print(f"[DEBUG] Starting epoch {epoch}/{n_epochs-1}")
         autoencoder.train()
         progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
-        progress_bar.set_description(f'Epoch {epoch + 1}/{n_epochs}')
+        progress_bar.set_description(f'Epoch {epoch}')
 
         for step, batch in progress_bar:
             # Generator Training
-            with autocast(enabled=True):
+            with autocast(device, enabled=True):
                 images = batch["image"].to(device)
                 reconstruction, z_mu, z_sigma = autoencoder(images)
 
@@ -462,7 +485,7 @@ def train(
             gradacc_g.step(loss_g, step)
 
             # Discriminator Training
-            with autocast(enabled=True):
+            with autocast(device, enabled=True):
                 logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
                 d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
                 logits_real = discriminator(images.contiguous().detach())[-1]
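
Both autocast sites in the training loop now pass device as the first argument, which torch.amp.autocast takes as a device-type string ("cuda" or "cpu"); if device were a torch.device object, device.type would be the value to pass. In the discriminator step itself, the reconstruction is detached so the update does not backpropagate into the autoencoder; adv_loss_fn is presumably MONAI Generative's PatchAdversarialLoss, whose construction lies outside this diff. A standalone sketch of that pattern with placeholder networks, shapes, and device handling:

# Sketch of the discriminator update above under new-style autocast.
# PatchDiscriminator / PatchAdversarialLoss come from the MONAI Generative ("generative")
# package this file already imports from; shapes and hyperparameters are placeholders.
import torch
from torch.amp import autocast
from generative.networks.nets import PatchDiscriminator
from generative.losses import PatchAdversarialLoss

device = "cuda" if torch.cuda.is_available() else "cpu"   # device-type string, as autocast expects
discriminator = PatchDiscriminator(spatial_dims=3, num_layers_d=3, num_channels=32, in_channels=1).to(device)
adv_loss_fn = PatchAdversarialLoss(criterion="least_squares")

images = torch.randn(2, 1, 32, 32, 32, device=device)
reconstruction = torch.randn(2, 1, 32, 32, 32, device=device, requires_grad=True)

with autocast(device, enabled=True):
    # detach() keeps this update from backpropagating into the generator/autoencoder
    logits_fake = discriminator(reconstruction.contiguous().detach())[-1]   # last output = patch logits
    d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
    logits_real = discriminator(images.contiguous().detach())[-1]
    d_loss_real = adv_loss_fn(logits_real, target_is_real=True, for_discriminator=True)
    d_loss = 0.5 * (d_loss_fake + d_loss_real)
d_loss.backward()
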
@@ -604,4 +627,4 @@ def main():
 
 
 if __name__ == '__main__':
-    main()
+    main()
 