Code fixes
Browse files- brlp_lite.py +8 -5
brlp_lite.py
CHANGED
@@ -56,6 +56,9 @@ from monai.transforms.transform import Transform
|
|
56 |
from monai import transforms
|
57 |
from monai.utils import set_determinism
|
58 |
from monai.data.meta_tensor import MetaTensor
|
|
|
|
|
|
|
59 |
|
60 |
from tqdm import tqdm
|
61 |
import matplotlib.pyplot as plt
|
@@ -93,9 +96,9 @@ def load_if(checkpoints_path: Optional[str], network: nn.Module) -> nn.Module:
|
|
93 |
if checkpoints_path is not None:
|
94 |
assert os.path.exists(checkpoints_path), 'Invalid path'
|
95 |
# Using context manager to allow MetaTensor
|
96 |
-
with torch.serialization.safe_globals([MetaTensor]):
|
97 |
-
|
98 |
-
|
99 |
return network
|
100 |
|
101 |
|
@@ -490,8 +493,8 @@ def train(
|
|
490 |
|
491 |
# Save the model after each epoch.
|
492 |
os.makedirs(output_dir, exist_ok=True)
|
493 |
-
torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch
|
494 |
-
torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch
|
495 |
|
496 |
writer.close()
|
497 |
print("Training completed and models saved.")
|
|
|
56 |
from monai import transforms
|
57 |
from monai.utils import set_determinism
|
58 |
from monai.data.meta_tensor import MetaTensor
|
59 |
+
import torch.serialization
|
60 |
+
|
61 |
+
torch.serialization.add_safe_globals([MetaTensor])
|
62 |
|
63 |
from tqdm import tqdm
|
64 |
import matplotlib.pyplot as plt
|
|
|
96 |
if checkpoints_path is not None:
|
97 |
assert os.path.exists(checkpoints_path), 'Invalid path'
|
98 |
# Using context manager to allow MetaTensor
|
99 |
+
#with torch.serialization.safe_globals([MetaTensor]):
|
100 |
+
network.load_state_dict(torch.load(checkpoints_path))
|
101 |
+
#network.load_state_dict(torch.load(checkpoints_path, map_location='cpu'))
|
102 |
return network
|
103 |
|
104 |
|
|
|
493 |
|
494 |
# Save the model after each epoch.
|
495 |
os.makedirs(output_dir, exist_ok=True)
|
496 |
+
torch.save(discriminator.state_dict(), os.path.join(output_dir, f'discriminator-ep-{epoch}.pth'))
|
497 |
+
torch.save(autoencoder.state_dict(), os.path.join(output_dir, f'autoencoder-ep-{epoch}.pth'))
|
498 |
|
499 |
writer.close()
|
500 |
print("Training completed and models saved.")
|