msong97 commited on
Commit
ec1223e
·
1 Parent(s): bc0aa28

Change dataset for LSdir

Browse files
Files changed (2) hide show
  1. datasets.py +24 -0
  2. evals.py +4 -7
datasets.py CHANGED
@@ -4,6 +4,7 @@ import os
4
 
5
  import torch
6
  from torch.utils.data import Dataset
 
7
 
8
 
9
  class Preprocessed_fastMRI(torch.utils.data.Dataset):
@@ -82,3 +83,26 @@ class Preprocessed_LIDCIDRI(torch.utils.data.Dataset):
82
  img = self.transform(img)
83
 
84
  img = img.unsqueeze(0) # add channel dim
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import torch
6
  from torch.utils.data import Dataset
7
+ from PIL import Image
8
 
9
 
10
  class Preprocessed_fastMRI(torch.utils.data.Dataset):
 
83
  img = self.transform(img)
84
 
85
  img = img.unsqueeze(0) # add channel dim
86
+ return img
87
+
88
+
89
+ class LsdirMiniDataset(torch.utils.data.Dataset):
90
+ def __init__(
91
+ self,
92
+ root: str,
93
+ transform: Optional[Callable] = None,
94
+ ) -> None:
95
+ self.root = root
96
+ self.image_files = [f for f in os.listdir(self.root) if f.lower().endswith('.png')]
97
+ self.transform = transform
98
+
99
+ def __len__(self) -> int:
100
+ return len(self.image_files)
101
+
102
+ def __getitem__(self, idx):
103
+ img_path = os.path.join(self.root, self.image_files[idx])
104
+ img = Image.open(img_path)
105
+ if self.transform:
106
+ img = self.transform(img)
107
+
108
+ return img
evals.py CHANGED
@@ -6,7 +6,7 @@ import torch
6
  from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
7
  from torchvision import transforms
8
 
9
- from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI
10
  from utils import get_model
11
 
12
  DEFAULT_MODEL_PARAMS = {
@@ -491,9 +491,7 @@ class BaselineModel(torch.nn.Module):
491
 
492
 
493
  class EvalDataset(torch.utils.data.Dataset):
494
- """
495
- We expect that images are 480x480.
496
- """
497
  all_datasets = ["Natural", "MRI", "CT"]
498
 
499
  def __init__(self, dataset_name: str, device_str: str = "cpu") -> None:
@@ -504,9 +502,8 @@ class EvalDataset(torch.utils.data.Dataset):
504
  if self.name == 'Natural':
505
  self.root = 'img_samples/LSDIR_samples'
506
  self.transform = transforms.Compose([transforms.ToTensor()])
507
- self.dataset = dinv.datasets.LsdirHR(root=self.root,
508
- download=False,
509
- transform=self.transform)
510
  elif self.name == 'MRI':
511
  self.root = 'img_samples/FastMRI_samples'
512
  self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
 
6
  from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
7
  from torchvision import transforms
8
 
9
+ from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
10
  from utils import get_model
11
 
12
  DEFAULT_MODEL_PARAMS = {
 
491
 
492
 
493
  class EvalDataset(torch.utils.data.Dataset):
494
+ """"""
 
 
495
  all_datasets = ["Natural", "MRI", "CT"]
496
 
497
  def __init__(self, dataset_name: str, device_str: str = "cpu") -> None:
 
502
  if self.name == 'Natural':
503
  self.root = 'img_samples/LSDIR_samples'
504
  self.transform = transforms.Compose([transforms.ToTensor()])
505
+ self.dataset = LsdirMiniDataset(root=self.root,
506
+ transform=self.transform)
 
507
  elif self.name == 'MRI':
508
  self.root = 'img_samples/FastMRI_samples'
509
  self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)