from einops import rearrange
import requests
from io import BytesIO
from typing import Literal, Union
import math

from PIL import Image

import numpy as np
from diffusers.utils import load_image
import cv2
import torch

from mmcm.vision.utils.data_type_util import convert_images
from transformers.models.clip.image_processing_clip import to_numpy_array
from ..utils.vision_util import round_up_to_even


def get_image_from_input(image: Union[str, Image.Image]) -> Image.Image:
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            image = Image.open(BytesIO(requests.get(image).content)).convert("RGB")
        else:
            image = Image.open(image).convert("RGB")
    else:
        image = image.convert("RGB")
    assert isinstance(image, Image.Image)
    return image


def dynamic_resize_image(
    image: Image.Image,
    target_height: int,
    target_width: int,
    image_max_length: int = 910,
) -> Image.Image:
    """对图像进行预处理,目前会将短边resize到目标长度,同时限制长边长度

    Args:
        image (Image.Image): _description_
        target_height (int): _description_
        target_width (int): _description_
        image_max_length (int): _description_

    Returns:
        Image.Image: _description_
    """
    w, h = image.size
    if w > h:
        target_width = min(math.ceil(w * target_height / h), image_max_length)
        target_height = math.ceil(target_width / w * h)
    else:
        target_height = min(math.ceil(h * target_width / w), image_max_length)
        target_width = math.ceil(target_height / h * w)
    target_width = round_up_to_even(target_width)
    target_height = round_up_to_even(target_height)
    image = image.resize((target_width, target_height))
    return image
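

# Minimal usage sketch (not part of the original module, names are illustrative):
# resize a synthetic 1920x1080 image so its short side follows target_height while
# the long side is capped at image_max_length; both sides are rounded up to even.
def _demo_dynamic_resize_image():
    img = Image.new("RGB", (1920, 1080))
    out = dynamic_resize_image(
        img, target_height=576, target_width=576, image_max_length=910
    )
    print(out.size)  # (910, 512): long side capped at 910, aspect ratio kept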


def dynamic_crop_resize_image(
    image: Image.Image,
    target_height: int,
    target_width: int,
    resample=None,
) -> Image.Image:
    """获取图像有效部分,并resize到对应目标宽度和高度。
        如果图像宽高比大于 target_width / target_height,则保留全部高,截取宽的中心部位;
        如果图像宽高比小于 target_width / target_height,则保留全部宽,截取高的中心部位;
        最后,将截取的图像resize到目标宽高
    Args:
        image (Image.Image): 输入图像
        target_height (int): 目标高
        target_width (int): 目标宽

    Returns:
        Image.Image: 动态截取、resize生成的图像
    """
    w, h = image.size
    image_width_height_ratio = w / h
    target_width_height_ratio = target_width / target_height
    if image_width_height_ratio >= target_width_height_ratio:
        y1 = 0
        y2 = h - 1
        x1 = math.ceil((w - h * target_width / target_height) / 2)
        x2 = math.ceil(w - (w - h * target_width / target_height) / 2)
    else:
        x1 = 0
        x2 = w - 1
        y1 = math.ceil((h - w * target_height / target_width) / 2)
        y2 = math.ceil(h - (h - w * target_height / target_width) / 2)
    x1 = max(0, x1)
    x2 = min(x2, w - 1)
    y1 = max(0, y1)
    y2 = min(y2, h - 1)
    image = image.crop((x1, y1, x2, y2))
    image = image.resize((target_width, target_height), resample=resample)
    return image
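

# Minimal usage sketch (not part of the original module, names are illustrative):
# center-crop a wide synthetic image to a 16:9 region and resize it to 1024x576.
def _demo_dynamic_crop_resize_image():
    img = Image.new("RGB", (2000, 1000))
    out = dynamic_crop_resize_image(
        img, target_height=576, target_width=1024, resample=Image.Resampling.LANCZOS
    )
    print(out.size)  # (1024, 576)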


def get_canny(
    image: np.ndarray, low_threshold: float, high_threshold: float
) -> np.ndarray:
    """Run Canny edge detection and replicate the single-channel result into 3 channels."""
    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return image
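

# Minimal usage sketch (not part of the original module, names are illustrative):
# extract Canny edges from a random uint8 RGB image and replicate them into 3 channels.
def _demo_get_canny():
    img = (np.random.rand(128, 128, 3) * 255).astype(np.uint8)
    edges = get_canny(img, low_threshold=100, high_threshold=200)
    print(edges.shape)  # (128, 128, 3)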


def pad_matrix(matrix, target_shape):
    """Zero-pad an (h, w, c) array symmetrically so its spatial size becomes target_shape = (h1, w1)."""
    h, w, c = matrix.shape
    h1, w1 = target_shape

    if h1 < h or w1 < w:
        raise ValueError("Target shape must be larger than original shape.")

    pad_h = (h1 - h) // 2
    pad_w = (w1 - w) // 2

    padded_matrix = np.zeros((h1, w1, c))
    padded_matrix[pad_h : pad_h + h, pad_w : pad_w + w, :] = matrix

    return padded_matrix


def pad_tensor(tensor, shape):
    """Zero-pad a numpy array at the end of each axis until it reaches the target shape.

    Args:
        tensor: numpy array to pad.
        shape: tuple, target shape (same number of dimensions as ``tensor``).

    Returns:
        numpy array: the padded tensor.
    """
    # Amount of padding needed per axis (0 if that axis is already large enough).
    pad_shape = np.maximum(0, np.array(shape) - np.array(tensor.shape))
    # Pad only at the end of each axis with constant zeros.
    pad_width = tuple((0, int(x)) for x in pad_shape)
    padded_tensor = np.pad(tensor, pad_width, mode="constant")
    return padded_tensor
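

# Minimal usage sketch (not part of the original module, names are illustrative):
# pad a 4D array with trailing zeros so every axis reaches at least the requested size.
def _demo_pad_tensor():
    x = np.ones((1, 3, 8, 8))
    y = pad_tensor(x, (1, 3, 16, 16))
    print(y.shape)  # (1, 3, 16, 16)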


def batch_dynamic_crop_resize_images_v2(
    images: Union[torch.Tensor, np.ndarray],
    target_height: int,
    target_width: int,
    mode=Image.Resampling.LANCZOS,
) -> np.ndarray:
    """获取图像中心有效部分,并resize到对应目标宽度和高度。
        如果图像宽高比大于 target_width / target_height,则保留全部高,截取宽的中心部位;
        如果图像宽高比小于 target_width / target_height,则保留全部宽,截取高的中心部位;
        最后,将截取的图像resize到目标宽高
    Args:
        image (Image.Image): 输入图像
        target_height (int): 目标高
        target_width (int): 目标宽

    Returns:
        Image.Image: 动态截取、resize生成的图像
    """
    ndim = images.ndim
    if ndim == 4:
        b, c, h, w = images.shape
    elif ndim == 5:
        b, c, t, h, w = images.shape
        images = rearrange(images, "b c t h w->(b t) c h w")
    else:
        raise ValueError(f"ndim only support 4, 5 but given {ndim}")
    images = convert_images(
        images, data_channel_order="b c h w", return_type="pil", input_rgb_order="rgb"
    )
    images = [
        dynamic_crop_resize_image(
            image,
            target_height=target_height,
            target_width=target_width,
            resample=mode,
        )
        for image in images
    ]
    images = [to_numpy_array(x) for x in images]
    images = np.stack(images, axis=0)
    images = rearrange(images, "b h w c-> b c h w")
    if ndim == 5:
        images = rearrange(images, "(b t) c h w->b c t h w", b=b, t=t)
    return images


def batch_dynamic_crop_resize_images(
    images: Union[torch.Tensor, np.ndarray],
    target_height: int,
    target_width: int,
    mode: Literal[
        "nearest", "linear", "bilinear", "bicubic", "trilinear", "area", "nearest-exact"
    ] = "bilinear",
    align_corners=False,
) -> torch.Tensor:
    """获取图像中心有效部分,并resize到对应目标宽度和高度。
        如果图像宽高比大于 target_width / target_height,则保留全部高,截取宽的中心部位;
        如果图像宽高比小于 target_width / target_height,则保留全部宽,截取高的中心部位;
        最后,将截取的图像resize到目标宽高

    Warning: 该方法对于 b c t h w t=1时 会出现图像像素错位问题,所以新增了个使用Image.Resize的V2版本
    Args:
        image (Image.Image): 输入图像
        target_height (int): 目标高
        target_width (int): 目标宽

    Returns:
        Image.Image: 动态截取、resize生成的图像
    """
    if isinstance(images, np.ndarray):
        images = torch.from_numpy(images)
    ndim = images.ndim

    if ndim == 4:
        b, c, h, w = images.shape
    elif ndim == 5:
        b, c, t, h, w = images.shape
        images = rearrange(images, "b c t h w->(b t) c h w")
    else:
        raise ValueError(f"ndim only support 4, 5 but given {ndim}")
    image_width_height_ratio = w / h
    target_width_height_ratio = target_width / target_height
    if image_width_height_ratio >= target_width_height_ratio:
        y1 = 0
        y2 = h - 1
        x1 = math.ceil((w - h * target_width / target_height) / 2)
        x2 = math.ceil(w - (w - h * target_width / target_height) / 2)
    else:
        x1 = 0
        x2 = w - 1
        y1 = math.ceil((h - w * target_height / target_width) / 2)
        y2 = math.ceil(h - (h - w * target_height / target_width) / 2)
    x1 = max(0, x1)
    x2 = min(x2, w - 1)
    y1 = max(0, y1)
    y2 = min(y2, h - 1)
    images = images[:, :, y1:y2, x1:x2]
    images = torch.nn.functional.interpolate(
        images,
        (target_height, target_width),
        # align_corners is not forwarded: modes such as "nearest" and "area" do not accept it.
        mode=mode,
    )
    if ndim == 5:
        images = rearrange(images, "(b t) c h w->b c t h w", b=b, t=t)
    return images
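

# Minimal usage sketch (not part of the original module, names are illustrative):
# center-crop and resize a batch of video tensors (b c t h w) with torch interpolation.
def _demo_batch_dynamic_crop_resize_images():
    videos = torch.rand(2, 3, 4, 240, 360)
    out = batch_dynamic_crop_resize_images(
        videos, target_height=128, target_width=128, mode="bilinear"
    )
    print(out.shape)  # torch.Size([2, 3, 4, 128, 128])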


def his_match(src: np.ndarray, dst: np.ndarray) -> np.ndarray:
    """Per-channel histogram matching: remap ``dst`` so its histogram follows ``src``.

    Both inputs are float RGB arrays in [0, 1]; the result is returned in [0, 1].
    """
    src = (src * 255.0).astype(np.uint8)
    dst = (dst * 255.0).astype(np.uint8)
    res = np.zeros_like(dst)

    kw = dict(bins=256, range=(0, 256), density=True)
    for ch in range(3):
        # Cumulative distribution functions of the source and destination channels.
        hist_src, _ = np.histogram(src[:, :, ch], **kw)
        hist_dst, _ = np.histogram(dst[:, :, ch], **kw)
        cdf_src = np.cumsum(hist_src)
        cdf_dst = np.cumsum(hist_dst)
        # Lookup table mapping each destination intensity to the source intensity
        # with the closest cumulative probability.
        index = np.searchsorted(cdf_src, cdf_dst, side="left")
        np.clip(index, 0, 255, out=index)
        res[:, :, ch] = index[dst[:, :, ch]]
    return res / 255.0
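

# Minimal usage sketch (not part of the original module, names are illustrative):
# remap the colors of a random target frame so its histogram follows a reference frame.
def _demo_his_match():
    rng = np.random.default_rng(0)
    reference = rng.random((64, 64, 3))  # frame providing the color statistics
    target = rng.random((64, 64, 3))     # frame to be recolored
    matched = his_match(reference, target)
    print(matched.shape, float(matched.min()), float(matched.max()))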