File size: 3,136 Bytes
4dc3e99
 
 
 
 
 
9aad168
4dc3e99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9aad168
 
 
 
 
 
 
 
 
 
3a575e4
9aad168
 
 
 
 
 
 
5b3128d
9aad168
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from pathlib import Path
from typing import Callable, Optional
import os

import torch
from torch.utils.data import Dataset
from PIL import Image


class Preprocessed_fastMRI(torch.utils.data.Dataset):
    """FastMRI samples from preprocessed ``.pt`` files for faster loading.

    Every file whose name ends in ``.pt`` anywhere below ``root``
    (subdirectories included) is one sample. Each file is expected to hold a
    mapping with a ``'data'`` tensor, which is cast to ``float32`` on load.

    Args:
        root: Directory searched recursively for ``.pt`` files.
        transform: Optional callable applied to the loaded tensor.
        preprocess: If True, ``__getitem__`` additionally returns the sample's
            file stem (name without directory or extension).
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        preprocess: bool = False,
    ) -> None:
        self.root = root
        self.transform = transform
        self.preprocess = preprocess

        # Paths relative to ``root`` are stored (not bare filenames) so that
        # files found in subdirectories by os.walk can be opened again in
        # __getitem__. Sorting makes the index -> file mapping reproducible,
        # since os.walk yields entries in arbitrary filesystem order.
        self.sample_identifiers = sorted(
            os.path.relpath(os.path.join(dirpath, file), self.root)
            for dirpath, _, files in os.walk(self.root)
            for file in files
            if file.endswith(".pt")
        )

    def __len__(self) -> int:
        """Return the number of ``.pt`` samples found under ``root``."""
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        """Load sample ``idx``.

        Returns:
            The ``float32`` tensor (after ``transform``), or
            ``(tensor, stem)`` when ``preprocess`` is True.
        """
        fname = self.sample_identifiers[idx]

        # weights_only=True restricts unpickling to tensors/primitive containers.
        tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
        img = tensor['data'].float()

        if self.transform is not None:
            img = self.transform(img)

        if not self.preprocess:
            return img

        # Path.stem drops any leading directories and the final extension.
        return img, Path(fname).stem


class Preprocessed_LIDCIDRI(torch.utils.data.Dataset):
    """LIDC-IDRI samples from preprocessed ``.pt`` files for faster loading.

    Every file whose name ends in ``.pt`` anywhere below ``root``
    (subdirectories included) is one sample. Each file is expected to hold a
    mapping with a ``'data'`` tensor, which is cast to ``float32`` on load.

    Args:
        root: Directory searched recursively for ``.pt`` files.
        transform: Optional callable applied to the loaded tensor before the
            channel dimension is added.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ) -> None:
        self.root = root
        self.transform = transform

        # Paths relative to ``root`` are stored (not bare filenames) so that
        # files found in subdirectories by os.walk can be opened again in
        # __getitem__. Sorting makes the index -> file mapping reproducible,
        # since os.walk yields entries in arbitrary filesystem order.
        self.sample_identifiers = sorted(
            os.path.relpath(os.path.join(dirpath, file), self.root)
            for dirpath, _, files in os.walk(self.root)
            for file in files
            if file.endswith(".pt")
        )

    def __len__(self) -> int:
        """Return the number of ``.pt`` samples found under ``root``."""
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        """Load sample ``idx`` as a ``float32`` tensor with a new leading dim.

        ``transform`` (if any) is applied first, then ``unsqueeze(0)`` adds a
        size-1 dimension in front (used as the channel dim).
        """
        fname = self.sample_identifiers[idx]

        # weights_only=True restricts unpickling to tensors/primitive containers.
        tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
        img = tensor['data'].float()

        if self.transform is not None:
            img = self.transform(img)

        return img.unsqueeze(0)  # add channel dim


class LsdirMiniDataset(torch.utils.data.Dataset):
    """Image dataset over the PNG/JPEG files directly inside ``root``.

    Only regular files in ``root`` itself are used (no recursion). Matching
    is case-insensitive on the ``.png``, ``.jpg`` and ``.jpeg`` extensions.

    Args:
        root: Directory containing the image files.
        transform: Optional callable applied to the RGB ``PIL.Image``.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ) -> None:
        self.root = root
        # '.jpg' is accepted alongside '.jpeg' (same format, more common
        # suffix). isfile() skips directories whose names happen to end in an
        # image extension. Sorting makes the index -> image mapping
        # reproducible, since os.listdir returns entries in arbitrary order.
        self.image_files = sorted(
            f
            for f in os.listdir(self.root)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
            and os.path.isfile(os.path.join(self.root, f))
        )
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of image files found in ``root``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if one is set."""
        img_path = os.path.join(self.root, self.image_files[idx])
        # The context manager releases the file handle deterministically;
        # convert() returns a new, fully loaded image, so it outlives the file.
        with Image.open(img_path) as im:
            img = im.convert("RGB")  # Ensure consistent 3-channel format
        if self.transform is not None:
            img = self.transform(img)

        return img