English
medical
brain-data
mri
jesseab commited on
Commit
8b07494
·
1 Parent(s): 3c0930a

Code fixes

Browse files
Files changed (1) hide show
  1. brlp_lite.py +8 -5
brlp_lite.py CHANGED
@@ -56,6 +56,9 @@ from monai.transforms.transform import Transform
56
  from monai import transforms
57
  from monai.utils import set_determinism
58
  from monai.data.meta_tensor import MetaTensor
 
 
 
59
 
60
  from tqdm import tqdm
61
  import matplotlib.pyplot as plt
@@ -93,9 +96,9 @@ def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
93
  if checkpoints_path is not None:
94
  assert os.path.exists(checkpoints_path), 'Invalid path'
95
  # Using context manager to allow MetaTensor
96
- with torch.serialization.safe_globals([MetaTensor]):
97
- network.load_state_dict(torch.load(checkpoints_path))
98
- #network.load_state_dict(torch.load(checkpoints_path, map_location='cpu'))
99
  return network
100
 
101
 
@@ -490,8 +493,8 @@ def train(
490
 
491
  # Save the model after each epoch.
492
  os.makedirs(output_dir, exist_ok=True)
493
- torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch + 1}.pth'))
494
- torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch + 1}.pth'))
495
 
496
  writer.close()
497
  print("Training completed and models saved.")
 
56
  from monai import transforms
57
  from monai.utils import set_determinism
58
  from monai.data.meta_tensor import MetaTensor
59
+ import torch.serialization
60
+
61
+ torch.serialization.add_safe_globals([MetaTensor])
62
 
63
  from tqdm import tqdm
64
  import matplotlib.pyplot as plt
 
96
  if checkpoints_path is not None:
97
  assert os.path.exists(checkpoints_path), 'Invalid path'
98
  # Using context manager to allow MetaTensor
99
+ #with torch.serialization.safe_globals([MetaTensor]):
100
+ network.load_state_dict(torch.load(checkpoints_path))
101
+ #network.load_state_dict(torch.load(checkpoints_path, map_location='cpu'))
102
  return network
103
 
104
 
 
493
 
494
  # Save the model after each epoch.
495
  os.makedirs(output_dir, exist_ok=True)
496
+ torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch}.pth'))
497
+ torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch}.pth'))
498
 
499
  writer.close()
500
  print("Training completed and models saved.")