File size: 1,378 Bytes
73666ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch
from configs import *

def batch_mean_and_sd(loader):
    """Compute the per-channel mean and population standard deviation of a dataset.

    Every pixel of every image yielded by ``loader`` is weighted equally, so
    batches of different sizes (e.g. a short final batch) are handled correctly.

    Args:
        loader: Iterable yielding ``(images, labels)`` pairs, where ``images``
            is a 4-D tensor laid out as (batch, channels, height, width).
            Labels are ignored.

    Returns:
        Tuple ``(mean, std)`` of 1-D float32 tensors with one entry per channel.

    Raises:
        ValueError: If ``loader`` yields no pixels at all.
    """
    pixel_count = 0
    channel_sum = None
    channel_sq_sum = None

    for images, _ in loader:
        b, _, h, w = images.shape
        # Accumulate in float64 so long datasets do not lose precision.
        images = images.to(torch.float64)
        batch_sum = images.sum(dim=[0, 2, 3])
        batch_sq_sum = (images ** 2).sum(dim=[0, 2, 3])
        if channel_sum is None:
            # Start from this batch's sums rather than an uninitialised buffer;
            # the channel count is taken from the data instead of being fixed.
            channel_sum, channel_sq_sum = batch_sum, batch_sq_sum
        else:
            channel_sum = channel_sum + batch_sum
            channel_sq_sum = channel_sq_sum + batch_sq_sum
        pixel_count += b * h * w

    if pixel_count == 0:
        raise ValueError("loader yielded no pixels; cannot compute mean/std")

    mean = channel_sum / pixel_count
    # E[x^2] - E[x]^2 can round to a tiny negative value for (near-)constant
    # channels; clamp so sqrt does not produce NaN.
    variance = (channel_sq_sum / pixel_count - mean ** 2).clamp_min(0)
    return mean.float(), variance.sqrt().float()


def main():
    """Print the per-channel mean and std of the ImageFolder dataset for ``TASK``.

    Uses the settings ``COMBINED_DATA_DIR``, ``TASK`` and ``BATCH_SIZE``, which
    are star-imported from ``configs``.
    """
    # NOTE(review): plain string concatenation, so COMBINED_DATA_DIR presumably
    # ends with a path separator — confirm in configs.
    data_path = COMBINED_DATA_DIR + str(TASK)

    transform_img = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),  # Convert to tensor
            transforms.Grayscale(num_output_channels=3),  # Grayscale, replicated into 3 channels
        ]
    )

    image_data = torchvision.datasets.ImageFolder(root=data_path, transform=transform_img)

    # Statistics do not depend on sample order, so no shuffling is needed.
    loader = DataLoader(image_data, batch_size=BATCH_SIZE, num_workers=1)

    mean, std = batch_mean_and_sd(loader)
    print("mean and std: \n", mean, std)

if __name__ == '__main__':
    # Compute and print the dataset statistics only when run as a script,
    # not when this module is imported.
    main()