File size: 11,458 Bytes
0c8d55e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
from typing import List
import numpy as np
import cv2
from PIL import Image, ImageChops
import torch
import torch.nn.functional as F

def concat_images_row(images: List[Image.Image], bg_color=(0, 0, 0)) -> Image.Image:
    """Concatenate PIL images horizontally onto one canvas.

    Args:
        images: images to join, left to right.
        bg_color: canvas fill color; pass (0, 0, 0, 0) for a transparent
            background when inputs have alpha.

    Returns:
        A new image holding all inputs side by side.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError("images 列表不能为空")

    # Promote to RGBA when any input carries an alpha channel (RGBA/LA/PA).
    has_alpha = any(img.mode.endswith("A") for img in images)
    mode = "RGBA" if has_alpha else "RGB"

    # Canvas: sum of widths, height of the tallest image.
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)
    canvas = Image.new(mode, (total_width, max_height), bg_color)

    offset = 0
    for img in images:
        converted = img if img.mode == mode else img.convert(mode)
        # In RGBA mode, use the image as its own mask so alpha is respected.
        canvas.paste(converted, (offset, 0), converted if mode == "RGBA" else None)
        offset += converted.width

    return canvas


def downsample_mask_pytorch(pil_mask: Image.Image, factor: int) -> Image.Image:
    """Downsample a binary mask with max-pooling: an output block is white
    if ANY pixel inside the corresponding input block is white.

    Args:
        pil_mask: binary PIL image (mode '1' or 'L', values 0/255).
        factor: pooling kernel size and stride (the downsample ratio).

    Returns:
        The downsampled binary mask as a mode-'1' PIL image.
    """
    # 0/255 uint8 -> 0/1 float tensor shaped [1, 1, H, W].
    gray = np.array(pil_mask.convert('L'), dtype=np.uint8)
    x = (torch.from_numpy(gray).float() / 255.0).unsqueeze(0).unsqueeze(0)

    # Max-pool so any white pixel in a block survives.
    pooled = F.max_pool2d(x, kernel_size=factor, stride=factor)

    # Back to 0/255 uint8 and then to a binary PIL image.
    out = (pooled[0, 0] > 0).to(torch.uint8) * 255
    return Image.fromarray(out.cpu().numpy(), mode='L').convert('1')

def create_all_white_like(pil_img: Image.Image) -> Image.Image:
    """Return an all-white binary (mode '1') image the same size as *pil_img*."""
    width, height = pil_img.size
    # NumPy arrays are (rows, cols) == (H, W) — the reverse of PIL's (W, H).
    white = np.full((height, width), 255, dtype=np.uint8)
    return Image.fromarray(white, mode='L').convert('1')

def union_masks_np(masks: List[Image.Image]) -> Image.Image:
    """Pixel-wise union (logical OR) of binary PIL masks.

    Args:
        masks: list of mode-'1' or mode-'L' binary images (0/255).

    Returns:
        The union as a mode-'1' PIL image.

    Raises:
        ValueError: if ``masks`` is empty.
    """
    if not masks:
        raise ValueError("输入的 masks 列表不能为空")

    # Fold the masks together with OR; > 127 binarizes 'L'-mode inputs.
    acc = None
    for m in masks:
        bits = np.array(m.convert('L'), dtype=np.uint8) > 127
        acc = bits if acc is None else (acc | bits)

    # Back to 0/255 and a binary PIL image.
    return Image.fromarray(acc.astype(np.uint8) * 255, mode='L').convert('1')

def intersect_masks_np(masks: List[Image.Image]) -> Image.Image:
    """Pixel-wise intersection (logical AND) of binary PIL masks.

    Args:
        masks: list of mode-'1' or mode-'L' binary images (0/255).

    Returns:
        The intersection as a mode-'1' PIL image.

    Raises:
        ValueError: if ``masks`` is empty.
    """
    if not masks:
        raise ValueError("输入的 masks 列表不能为空")

    # Fold the masks together with AND; > 127 binarizes 'L'-mode inputs.
    acc = None
    for m in masks:
        bits = np.array(m.convert('L'), dtype=np.uint8) > 127
        acc = bits if acc is None else (acc & bits)

    # Back to 0/255 and a binary PIL image.
    return Image.fromarray(acc.astype(np.uint8) * 255, mode='L').convert('1')

def close_small_holes(pil_mask, kernel_size=5):
    """Fill small black holes in a binary mask with a morphological closing.

    Args:
        pil_mask: binary PIL image (mode '1' or 'L').
        kernel_size: structuring-element size; larger fills bigger holes,
            usually odd.

    Returns:
        The closed mask as a PIL image ('L' mode, values 0/255).
    """
    # Binarize to strict 0/255 first.
    gray = np.array(pil_mask.convert('L'))
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)

    # Elliptical structuring element, then close (dilate followed by erode).
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)

    return Image.fromarray(closed)


def get_mask(src_image, tgt_image, threshold=1):
    """Binary difference mask between two images.

    Pixels whose grayscale absolute difference is at least *threshold*
    become white (foreground); the rest become black (background).
    """
    diff_gray = ImageChops.difference(src_image, tgt_image).convert("L")
    return diff_gray.point(lambda v: 255 if v >= threshold else 0).convert("1")

def filter_small_components(pil_mask, area_threshold=0.10):
    """Remove connected white components that are too small.

    A component is kept when its pixel count is at least
    ``area_threshold`` times the TOTAL number of white pixels in the mask
    (NOT the full image area — the previous docstring was misleading, and
    a dead ``total_area = h * w`` assignment suggested the wrong basis).

    Args:
        pil_mask: binary PIL image (mode 'L' or '1', values 0/255).
        area_threshold: minimum component area as a fraction of the total
            white area.

    Returns:
        The filtered mask as a PIL image ('L' mode, values 0/255).
    """
    # 1. Binarize to a strict 0/255 NumPy array.
    mask = np.array(pil_mask.convert('L'))
    _, bin_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)

    # 2. Label 8-connected white components (label 0 is the background).
    num_labels, labels = cv2.connectedComponents(bin_mask, connectivity=8)

    # Reference area = count of white pixels, so the threshold is relative
    # to the foreground. An all-black mask gives total_area == 0 and the
    # loop below simply keeps nothing.
    total_area = np.count_nonzero(bin_mask)

    # 3. Keep only the sufficiently large components.
    output = np.zeros_like(bin_mask)
    for lbl in range(1, num_labels):
        comp_mask = (labels == lbl)
        if comp_mask.sum() >= area_threshold * total_area:
            output[comp_mask] = 255

    # 4. Back to PIL.
    return Image.fromarray(output)



def is_binary_255(t: torch.Tensor) -> bool:
    """Return True iff *t* contains only the values 0 and/or 255."""
    vals = torch.unique(t)  # sorted unique values, same dtype as t
    candidates = (
        torch.tensor([0], dtype=t.dtype),
        torch.tensor([255], dtype=t.dtype),
        torch.tensor([0, 255], dtype=t.dtype),
    )
    return any(torch.equal(vals, c) for c in candidates)

def get_weight(mask_u_ds, weight_type='log'):
    """Convert a binary 0/255 mask into a per-pixel loss-weight map.

    White (255) pixels receive a weight >= 1 that grows as the white area
    shrinks relative to the full mask; black (0) pixels receive weight 1.

    Args:
        mask_u_ds: binary mask (PIL image or ndarray) containing only 0/255.
        weight_type: 'log' -> log2(total/white) + 1;
                     'exp' -> 2 ** (sqrt(total/white) - 1).

    Returns:
        Float tensor of shape (1, H, W) holding the weights.

    Raises:
        ValueError: if the mask has no white pixels — previously this
            divided by zero and produced an ``inf`` weight that silently
            passed the ``weight >= 1`` assert.
        NotImplementedError: for an unknown ``weight_type``.
    """
    mask_u_ds_tensor = torch.from_numpy(np.array(mask_u_ds)).float()
    # Inlined binary check (was is_binary_255): the sorted unique values
    # must be exactly {0}, {255}, or {0, 255}.
    uniq = torch.unique(mask_u_ds_tensor)
    assert any(
        torch.equal(uniq, torch.tensor(vals, dtype=mask_u_ds_tensor.dtype))
        for vals in ([0], [255], [0, 255])
    ), "is_binary_255(mask_u_ds_tensor)"
    mask_u_ds_tensor_bool = mask_u_ds_tensor.bool()
    white_count = mask_u_ds_tensor_bool.sum()
    if white_count == 0:
        # Bug fix: total/white would divide by zero and yield an infinite
        # weight; fail loudly instead of poisoning the loss.
        raise ValueError('mask_u_ds contains no white (255) pixels; cannot compute a finite weight')
    # Inverse white-area ratio: large when the edited region is small.
    x = mask_u_ds_tensor_bool.numel() / white_count
    if weight_type == 'log':
        weight = torch.log2(x) + 1
    elif weight_type == 'exp':
        weight = 2 ** (x**0.5 - 1)
    else:
        raise NotImplementedError(f'Support log | exp, but found {weight_type}')
    weight = torch.round(weight, decimals=6)
    assert weight >= 1, \
        f"weight >= 1 but {weight}, {mask_u_ds_tensor_bool.shape}, mask_u_ds_tensor_bool.numel(): {mask_u_ds_tensor_bool.numel()}, mask_u_ds_tensor_bool.sum(): {mask_u_ds_tensor_bool.sum()}"
    # Map 255 -> computed weight, 0 -> 1.0, in place on the float tensor.
    mask_u_ds_tensor[mask_u_ds_tensor == 255] = weight
    mask_u_ds_tensor[mask_u_ds_tensor == 0] = 1.0
    return mask_u_ds_tensor.unsqueeze(0)  # h w -> 1 h w

def get_weight_mask(pil_pixel_values, prompt=None, weight_type='log', need_weight='true'):
    """Compute the downsampled edit-region mask and per-pixel loss weights.

    Each reference image (all but the last entry) is diffed against the
    target image (the last entry); the per-reference masks are cleaned
    (hole closing, small-component removal) and intersected, downsampled
    by the VAE factor (8), and converted into a weight map via get_weight.

    Args:
        pil_pixel_values: reference images followed by the target image (PIL).
        prompt: only used in the error message when the mask is too small.
        weight_type: 'log' or 'exp', forwarded to get_weight.
        need_weight: string flag; 'false' (case-insensitive) short-circuits
            to an all-white mask, i.e. uniform weighting.

    Returns:
        Tuple (mask_intersect_ds, weight): the downsampled mask (PIL image)
        and the weight tensor of shape (1, h, w).

    Raises:
        ValueError: if the intersected mask is non-empty but its white-area
            ratio is below area_threshold.
    """
    # area_threshold = 1/64
    area_threshold = 0.001
    # base_kernel_size_factor = (5 / 448) ** 2
    # if len(pil_pixel_values) > 0:
    #     w, h = pil_pixel_values[-1].size
    #     kernel_size = max(int((base_kernel_size_factor * h * w) ** 0.5), 3)
    # else:
    kernel_size = 5

    if need_weight.lower() == 'false':
        # Uniform weighting: treat the entire target image as the edit region.
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
        mask_intersect_ds = downsample_mask_pytorch(mask_intersect, factor=8)  # factor is downsample ratio of vae
        mask_intersect_ds = close_small_holes(mask_intersect_ds, kernel_size=kernel_size)
        weight = get_weight(mask_intersect_ds, weight_type)
        return mask_intersect_ds, weight

    filtered_masks = []
    for ii, j in enumerate(pil_pixel_values[:-1]):
        # each reference image will compare with target image to get mask
        mask = get_mask(j, pil_pixel_values[-1], threshold=18)
        # fill small holes
        fill_mask = close_small_holes(mask, kernel_size=kernel_size)
        # del small components (threshold relative to white area)
        filtered_mask = filter_small_components(fill_mask, area_threshold=0.3)
        # filtered_mask = fill_mask
        filtered_masks.append(filtered_mask)
    if len(filtered_masks) == 0:
        # t2i task do not have reference image; weight the whole target.
        assert len(pil_pixel_values) == 1, "len(pil_pixel_values) == 1"
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
    else:
        mask_intersect = intersect_masks_np(filtered_masks)
    # white area / total area must be greater than area_threshold (just a threshold)
    mask_intersect_area_ratio = np.array(mask_intersect).astype(np.float32).sum() / np.prod(np.array(mask_intersect).shape)
    # print(mask_intersect_area_ratio)
    if mask_intersect_area_ratio < area_threshold:
        if mask_intersect_area_ratio == 0.0:
            # mask_intersect_area_ratio == 0 mean reconstruct data in stage 1
            assert len(pil_pixel_values) == 2, "len(pil_pixel_values) == 2"
            mask_intersect = create_all_white_like(pil_pixel_values[-1])
        else:
            # concat_images_row(pil_pixel_values + [mask_intersect], bg_color=(255,255,255)).show()
            raise ValueError(f'TOO SMALL mask_intersect_area_ratio: {mask_intersect_area_ratio}, prompt: {prompt}')
    mask_intersect_ds = downsample_mask_pytorch(mask_intersect, factor=8)  # factor is downsample ratio of vae
    mask_intersect_ds = close_small_holes(mask_intersect_ds, kernel_size=kernel_size)
    weight = get_weight(mask_intersect_ds, weight_type)
    return mask_intersect_ds, weight



def get_weight_mask_test(pil_pixel_values, prompt=None, weight_type='log'):
    """Inspection variant of get_weight_mask: return only the white-area
    ratio of the intersected difference mask, without building weights.

    Args:
        pil_pixel_values: reference images followed by the target image (PIL).
        prompt: unused; kept for signature compatibility with get_weight_mask.
        weight_type: unused; kept for signature compatibility.

    Returns:
        float: fraction of pixels that are white in the intersection mask.
    """
    # NOTE: a dead `area_threshold = 1/64` local was removed — the
    # component filter below passes its own literal 1/64.
    base_kernel_size_factor = (5 / 448) ** 2
    if len(pil_pixel_values) > 0:
        # Scale the closing kernel with image area (~5 px at 448x448), min 3.
        w, h = pil_pixel_values[-1].size
        kernel_size = max(int((base_kernel_size_factor * h * w) ** 0.5), 3)
    else:
        kernel_size = 5

    filtered_masks = []
    for ii, j in enumerate(pil_pixel_values[:-1]):
        # each reference image will compare with target image to get mask
        mask = get_mask(j, pil_pixel_values[-1], threshold=18)
        # fill small holes
        fill_mask = close_small_holes(mask, kernel_size=kernel_size)
        # del small components (threshold relative to white area)
        filtered_mask = filter_small_components(fill_mask, area_threshold=1/64)
        filtered_masks.append(filtered_mask)
    if len(filtered_masks) == 0:
        # t2i task do not have reference image; whole target counts as edited.
        assert len(pil_pixel_values) == 1, "len(pil_pixel_values) == 1"
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
    else:
        mask_intersect = intersect_masks_np(filtered_masks)
    # Fraction of white pixels in the intersection mask.
    mask_intersect_area_ratio = np.array(mask_intersect).astype(np.float32).sum() / np.prod(np.array(mask_intersect).shape)
    return mask_intersect_area_ratio