File size: 1,052 Bytes
f85e212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27

import torch 
import numpy as np 

class ToTensor16bit(object):
    """Convert a 16-bit image into a channel-first int32 tensor.

    PyTorch can not handle uint16, only int16, so the pixel data is widened
    to int32 first. A 2-D input ([H, W]) gets a trailing channel axis added;
    any other input is taken as already having its channels on the last
    axis. The last axis is then moved to the front ([C, H, W] for images).
    """

    def __call__(self, image):
        # Always work on a private int32 copy so the caller's data is untouched.
        pixels = np.array(image, dtype=np.int32, copy=True)  # [H, W, C] or [H, W]
        if pixels.ndim == 2:
            pixels = pixels[..., np.newaxis]  # [H, W] -> [H, W, 1]
        channel_first = np.moveaxis(pixels, -1, 0)  # [C, H, W]
        return torch.from_numpy(channel_first)

class Normalize(object):
    """Rescale the image to [0,1] and ensure float32 dtype.

    The tensor is min-max scaled over all of its elements. A constant image
    (max == min) has no intensity range to stretch; it is returned as all
    zeros instead of dividing by zero and producing NaN.
    """

    def __call__(self, image):
        image = image.type(torch.FloatTensor)
        lowest = image.min()
        span = image.max() - lowest
        if span == 0:
            # Constant input: avoid 0/0 -> NaN.
            return torch.zeros_like(image)
        return (image - lowest) / span


class RandomBackground(object):
    """Fill Background (intensity ==0) with random values"""

    def __call__(self, image):
        # Background is every element that is exactly zero.
        background = image == 0
        # One uniform sample from [0, 1) per background element; the input
        # tensor is modified in place and returned.
        noise = torch.rand(int(background.sum()))
        image[background] = noise
        return image