Spaces:
Running
on
T4
Running
on
T4
Change dataset for LSDIR
Browse files- datasets.py +24 -0
- evals.py +4 -7
datasets.py
CHANGED
@@ -4,6 +4,7 @@ import os
|
|
4 |
|
5 |
import torch
|
6 |
from torch.utils.data import Dataset
|
|
|
7 |
|
8 |
|
9 |
class Preprocessed_fastMRI(torch.utils.data.Dataset):
|
@@ -82,3 +83,26 @@ class Preprocessed_LIDCIDRI(torch.utils.data.Dataset):
|
|
82 |
img = self.transform(img)
|
83 |
|
84 |
img = img.unsqueeze(0) # add channel dim
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
import torch
|
6 |
from torch.utils.data import Dataset
|
7 |
+
from PIL import Image
|
8 |
|
9 |
|
10 |
class Preprocessed_fastMRI(torch.utils.data.Dataset):
|
|
|
83 |
img = self.transform(img)
|
84 |
|
85 |
img = img.unsqueeze(0) # add channel dim
|
86 |
+
return img
|
87 |
+
|
88 |
+
|
89 |
+
class LsdirMiniDataset(torch.utils.data.Dataset):
    """Dataset over the PNG images in a single flat directory (LSDIR mini split).

    Each item is the PIL image loaded from one ``.png`` file under ``root``,
    passed through ``transform`` when one is provided.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,  # NOTE(review): requires `from typing import Optional, Callable` at file top — confirm it is imported
    ) -> None:
        """
        Args:
            root: Directory containing the ``.png`` images (case-insensitive
                extension match; non-PNG files are ignored).
            transform: Optional callable applied to each loaded PIL image.
        """
        self.root = root
        # Sort for a deterministic index -> file mapping: os.listdir order is
        # arbitrary and varies across platforms/filesystems, which would make
        # dataset[i] refer to different images on different machines.
        self.image_files = sorted(
            f for f in os.listdir(self.root) if f.lower().endswith('.png')
        )
        self.transform = transform

    def __len__(self) -> int:
        """Number of PNG files found in ``root``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load (and optionally transform) the ``idx``-th image.

        Returns the PIL image itself, or whatever ``transform`` produces
        (e.g. a tensor when ``transforms.ToTensor`` is used).
        """
        img_path = os.path.join(self.root, self.image_files[idx])
        # NOTE(review): no .convert('RGB') here — the returned mode depends on
        # the file (could be L/RGBA); confirm downstream transforms expect that.
        img = Image.open(img_path)
        if self.transform:
            img = self.transform(img)
        return img
|
evals.py
CHANGED
@@ -6,7 +6,7 @@ import torch
|
|
6 |
from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
|
7 |
from torchvision import transforms
|
8 |
|
9 |
-
from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI
|
10 |
from utils import get_model
|
11 |
|
12 |
DEFAULT_MODEL_PARAMS = {
|
@@ -491,9 +491,7 @@ class BaselineModel(torch.nn.Module):
|
|
491 |
|
492 |
|
493 |
class EvalDataset(torch.utils.data.Dataset):
|
494 |
-
"""
|
495 |
-
We expect that images are 480x480.
|
496 |
-
"""
|
497 |
all_datasets = ["Natural", "MRI", "CT"]
|
498 |
|
499 |
def __init__(self, dataset_name: str, device_str: str = "cpu") -> None:
|
@@ -504,9 +502,8 @@ class EvalDataset(torch.utils.data.Dataset):
|
|
504 |
if self.name == 'Natural':
|
505 |
self.root = 'img_samples/LSDIR_samples'
|
506 |
self.transform = transforms.Compose([transforms.ToTensor()])
|
507 |
-
self.dataset =
|
508 |
-
|
509 |
-
transform=self.transform)
|
510 |
elif self.name == 'MRI':
|
511 |
self.root = 'img_samples/FastMRI_samples'
|
512 |
self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
|
|
|
6 |
from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
|
7 |
from torchvision import transforms
|
8 |
|
9 |
+
from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
|
10 |
from utils import get_model
|
11 |
|
12 |
DEFAULT_MODEL_PARAMS = {
|
|
|
491 |
|
492 |
|
493 |
class EvalDataset(torch.utils.data.Dataset):
|
494 |
+
""""""
|
|
|
|
|
495 |
all_datasets = ["Natural", "MRI", "CT"]
|
496 |
|
497 |
def __init__(self, dataset_name: str, device_str: str = "cpu") -> None:
|
|
|
502 |
if self.name == 'Natural':
|
503 |
self.root = 'img_samples/LSDIR_samples'
|
504 |
self.transform = transforms.Compose([transforms.ToTensor()])
|
505 |
+
self.dataset = LsdirMiniDataset(root=self.root,
|
506 |
+
transform=self.transform)
|
|
|
507 |
elif self.name == 'MRI':
|
508 |
self.root = 'img_samples/FastMRI_samples'
|
509 |
self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
|