lehduong commited on
Commit
819ebdf
·
verified ·
1 Parent(s): a767b8d

Delete dataset/transforms.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataset/transforms.py +0 -133
dataset/transforms.py DELETED
@@ -1,133 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
-
4
- def crop(image, i, j, h, w):
5
- """
6
- Args:
7
- image (torch.tensor): Image to be cropped. Size is (C, H, W)
8
- """
9
- if len(image.size()) != 3:
10
- raise ValueError("image should be a 3D tensor")
11
- return image[..., i : i + h, j : j + w]
12
-
13
- def resize(image, target_size, interpolation_mode):
14
- if len(target_size) != 2:
15
- raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
16
- return F.interpolate(image.unsqueeze(0), size=target_size, mode=interpolation_mode, align_corners=False).squeeze(0)
17
-
18
- def resize_scale(image, target_size, interpolation_mode):
19
- if len(target_size) != 2:
20
- raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
21
- H, W = image.size(-2), image.size(-1)
22
- scale_ = target_size[0] / min(H, W)
23
- return F.interpolate(image.unsqueeze(0), scale_factor=scale_, mode=interpolation_mode, align_corners=False).squeeze(0)
24
-
25
- def resized_crop(image, i, j, h, w, size, interpolation_mode="bilinear"):
26
- """
27
- Do spatial cropping and resizing to the image
28
- Args:
29
- image (torch.tensor): Image to be cropped. Size is (C, H, W)
30
- i (int): i in (i,j) i.e coordinates of the upper left corner.
31
- j (int): j in (i,j) i.e coordinates of the upper left corner.
32
- h (int): Height of the cropped region.
33
- w (int): Width of the cropped region.
34
- size (tuple(int, int)): height and width of resized image
35
- Returns:
36
- image (torch.tensor): Resized and cropped image. Size is (C, H, W)
37
- """
38
- if len(image.size()) != 3:
39
- raise ValueError("image should be a 3D torch.tensor")
40
- image = crop(image, i, j, h, w)
41
- image = resize(image, size, interpolation_mode)
42
- return image
43
-
44
- def center_crop(image, crop_size):
45
- if len(image.size()) != 3:
46
- raise ValueError("image should be a 3D torch.tensor")
47
- h, w = image.size(-2), image.size(-1)
48
- th, tw = crop_size
49
- if h < th or w < tw:
50
- raise ValueError("height and width must be no smaller than crop_size")
51
- i = int(round((h - th) / 2.0))
52
- j = int(round((w - tw) / 2.0))
53
- return crop(image, i, j, th, tw)
54
-
55
- def center_crop_using_short_edge(image):
56
- if len(image.size()) != 3:
57
- raise ValueError("image should be a 3D torch.tensor")
58
- h, w = image.size(-2), image.size(-1)
59
- if h < w:
60
- th, tw = h, h
61
- i = 0
62
- j = int(round((w - tw) / 2.0))
63
- else:
64
- th, tw = w, w
65
- i = int(round((h - th) / 2.0))
66
- j = 0
67
- return crop(image, i, j, th, tw)
68
-
69
- class CenterCropResizeImage:
70
- """
71
- Resize the image while maintaining aspect ratio, and then crop it to the desired size.
72
- The resizing is done such that the area of padding/cropping is minimized.
73
- """
74
- def __init__(self, size, interpolation_mode="bilinear"):
75
- if isinstance(size, tuple):
76
- if len(size) != 2:
77
- raise ValueError(f"Size should be a tuple (height, width), instead got {size}")
78
- self.size = size
79
- else:
80
- self.size = (size, size)
81
- self.interpolation_mode = interpolation_mode
82
-
83
- def __call__(self, image):
84
- """
85
- Args:
86
- image (torch.Tensor): Image to be resized and cropped. Size is (C, H, W)
87
-
88
- Returns:
89
- torch.Tensor: Resized and cropped image. Size is (C, target_height, target_width)
90
- """
91
- target_height, target_width = self.size
92
- target_aspect = target_width / target_height
93
-
94
- # Get current image shape and aspect ratio
95
- _, height, width = image.shape
96
- height, width = float(height), float(width)
97
- current_aspect = width / height
98
-
99
- # Calculate crop dimensions
100
- if current_aspect > target_aspect:
101
- # Image is wider than target, crop width
102
- crop_height = height
103
- crop_width = height * target_aspect
104
- else:
105
- # Image is taller than target, crop height
106
- crop_height = width / target_aspect
107
- crop_width = width
108
-
109
- # Calculate crop coordinates (center crop)
110
- y1 = (height - crop_height) / 2
111
- x1 = (width - crop_width) / 2
112
-
113
- # Perform the crop
114
- cropped_image = crop(image, int(y1), int(x1), int(crop_height), int(crop_width))
115
-
116
- # Resize the cropped image to the target size
117
- resized_image = resize(cropped_image, self.size, self.interpolation_mode)
118
-
119
- return resized_image
120
-
121
- # Example usage
122
- if __name__ == "__main__":
123
- # Create a sample image tensor
124
- sample_image = torch.rand(3, 480, 640) # (C, H, W)
125
-
126
- # Initialize the transform
127
- transform = CenterCropResizeImage(size=(224, 224), interpolation_mode="bilinear")
128
-
129
- # Apply the transform
130
- transformed_image = transform(sample_image)
131
-
132
- print(f"Original image shape: {sample_image.shape}")
133
- print(f"Transformed image shape: {transformed_image.shape}")