Updates
Browse files- brlp_lite.py +29 -6
brlp_lite.py
CHANGED
@@ -32,10 +32,12 @@
|
|
32 |
# }
|
33 |
|
34 |
import os
|
|
|
35 |
from typing import Optional, Union
|
36 |
import pandas as pd
|
37 |
import argparse
|
38 |
import numpy as np
|
|
|
39 |
import warnings
|
40 |
import torch
|
41 |
import torch.nn as nn
|
@@ -43,7 +45,7 @@ from torch import Tensor
|
|
43 |
from torch.optim.optimizer import Optimizer
|
44 |
from torch.nn import L1Loss
|
45 |
from torch.utils.data import DataLoader
|
46 |
-
from torch.
|
47 |
from torch.amp import GradScaler
|
48 |
|
49 |
from generative.networks.nets import (
|
@@ -57,8 +59,12 @@ from monai import transforms
|
|
57 |
from monai.utils import set_determinism
|
58 |
from monai.data.meta_tensor import MetaTensor
|
59 |
import torch.serialization
|
60 |
-
|
|
|
|
|
61 |
torch.serialization.add_safe_globals([MetaTensor])
|
|
|
|
|
62 |
|
63 |
from tqdm import tqdm
|
64 |
import matplotlib.pyplot as plt
|
@@ -381,6 +387,22 @@ def train(
|
|
381 |
train_df = dataset_df[dataset_df.split == 'train']
|
382 |
trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
|
383 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
train_loader = DataLoader(
|
385 |
dataset=trainset,
|
386 |
num_workers=num_workers,
|
@@ -440,13 +462,14 @@ def train(
|
|
440 |
total_counter = 0
|
441 |
|
442 |
for epoch in range(n_epochs):
|
|
|
443 |
autoencoder.train()
|
444 |
progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
|
445 |
-
progress_bar.set_description(f'Epoch {epoch
|
446 |
|
447 |
for step, batch in progress_bar:
|
448 |
# Generator Training
|
449 |
-
with autocast(enabled=True):
|
450 |
images = batch["image"].to(device)
|
451 |
reconstruction, z_mu, z_sigma = autoencoder(images)
|
452 |
|
@@ -462,7 +485,7 @@ def train(
|
|
462 |
gradacc_g.step(loss_g, step)
|
463 |
|
464 |
# Discriminator Training
|
465 |
-
with autocast(enabled=True):
|
466 |
logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
|
467 |
d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
|
468 |
logits_real = discriminator(images.contiguous().detach())[-1]
|
@@ -604,4 +627,4 @@ def main():
|
|
604 |
|
605 |
|
606 |
if __name__ == '__main__':
|
607 |
-
main()
|
|
|
32 |
# }
|
33 |
|
34 |
import os
|
35 |
+
os.environ["PYTORCH_WEIGHTS_ONLY"] = "False"
|
36 |
from typing import Optional, Union
|
37 |
import pandas as pd
|
38 |
import argparse
|
39 |
import numpy as np
|
40 |
+
|
41 |
import warnings
|
42 |
import torch
|
43 |
import torch.nn as nn
|
|
|
45 |
from torch.optim.optimizer import Optimizer
|
46 |
from torch.nn import L1Loss
|
47 |
from torch.utils.data import DataLoader
|
48 |
+
from torch.amp import autocast
|
49 |
from torch.amp import GradScaler
|
50 |
|
51 |
from generative.networks.nets import (
|
|
|
59 |
from monai.utils import set_determinism
|
60 |
from monai.data.meta_tensor import MetaTensor
|
61 |
import torch.serialization
|
62 |
+
from numpy.core.multiarray import _reconstruct
|
63 |
+
from numpy import ndarray, dtype
|
64 |
+
torch.serialization.add_safe_globals([_reconstruct])
|
65 |
torch.serialization.add_safe_globals([MetaTensor])
|
66 |
+
torch.serialization.add_safe_globals([ndarray])
|
67 |
+
torch.serialization.add_safe_globals([dtype])
|
68 |
|
69 |
from tqdm import tqdm
|
70 |
import matplotlib.pyplot as plt
|
|
|
387 |
train_df = dataset_df[dataset_df.split == 'train']
|
388 |
trainset = get_dataset_from_pd(train_df, transforms_fn, cache_dir)
|
389 |
|
390 |
+
print(f"[DEBUG] Using cache_dir={cache_dir}")
|
391 |
+
print(f"[DEBUG] trainset length={len(trainset)}")
|
392 |
+
|
393 |
+
try:
|
394 |
+
sample_debug = trainset[0] # Force a transform on the first record
|
395 |
+
print("[DEBUG] Successfully loaded sample 0 from trainset.")
|
396 |
+
except Exception as e:
|
397 |
+
print("[DEBUG] Error loading sample 0:", e)
|
398 |
+
|
399 |
+
import glob
|
400 |
+
|
401 |
+
hashfiles = glob.glob(os.path.join(cache_dir, "*.pt"))
|
402 |
+
print(f"[DEBUG] Found {len(hashfiles)} cached .pt files in {cache_dir}")
|
403 |
+
if hashfiles:
|
404 |
+
print("[DEBUG] Example cache file:", hashfiles[0])
|
405 |
+
|
406 |
train_loader = DataLoader(
|
407 |
dataset=trainset,
|
408 |
num_workers=num_workers,
|
|
|
462 |
total_counter = 0
|
463 |
|
464 |
for epoch in range(n_epochs):
|
465 |
+
print(f"[DEBUG] Starting epoch {epoch}/{n_epochs-1}")
|
466 |
autoencoder.train()
|
467 |
progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
|
468 |
+
progress_bar.set_description(f'Epoch {epoch}')
|
469 |
|
470 |
for step, batch in progress_bar:
|
471 |
# Generator Training
|
472 |
+
with autocast(device, enabled=True):
|
473 |
images = batch["image"].to(device)
|
474 |
reconstruction, z_mu, z_sigma = autoencoder(images)
|
475 |
|
|
|
485 |
gradacc_g.step(loss_g, step)
|
486 |
|
487 |
# Discriminator Training
|
488 |
+
with autocast(device, enabled=True):
|
489 |
logits_fake = discriminator(reconstruction.contiguous().detach())[-1]
|
490 |
d_loss_fake = adv_loss_fn(logits_fake, target_is_real=False, for_discriminator=True)
|
491 |
logits_real = discriminator(images.contiguous().detach())[-1]
|
|
|
627 |
|
628 |
|
629 |
if __name__ == '__main__':
|
630 |
+
main()
|