Spaces:
Runtime error
Runtime error
File size: 11,458 Bytes
0c8d55e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 |
from typing import List
import numpy as np
import cv2
from PIL import Image, ImageChops
import torch
import torch.nn.functional as F
def concat_images_row(images: List[Image.Image], bg_color=(0, 0, 0)) -> Image.Image:
    """
    Concatenate several PIL images horizontally (left to right).

    Args:
        images: images to join; each is pasted at y=0 on a shared canvas.
        bg_color: canvas fill color, black by default; pass (0, 0, 0, 0)
            for transparency when the inputs carry an alpha channel.

    Returns:
        A new image of width sum(widths) and height max(heights).

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError("images 列表不能为空")
    # Use RGBA as the common mode whenever any input has an alpha channel.
    has_alpha = any(img.mode.endswith("A") for img in images)
    target_mode = "RGBA" if has_alpha else "RGB"
    # Canvas dimensions: widths add up, height is the tallest input.
    sizes = [img.size for img in images]
    canvas_w = sum(w for w, _ in sizes)
    canvas_h = max(h for _, h in sizes)
    canvas = Image.new(target_mode, (canvas_w, canvas_h), bg_color)
    # Paste each image at the running x offset.
    cursor = 0
    for img in images:
        if img.mode != target_mode:
            img = img.convert(target_mode)
        # In RGBA mode the image doubles as its own paste mask so alpha is honored.
        paste_mask = img if target_mode == "RGBA" else None
        canvas.paste(img, (cursor, 0), paste_mask)
        cursor += img.width
    return canvas
def downsample_mask_pytorch(pil_mask: Image.Image, factor: int) -> Image.Image:
    """
    Downsample a binary mask with max-pooling so that any white pixel inside
    a factor x factor block keeps the corresponding output pixel white.

    Args:
        pil_mask: binary PIL image (mode '1' or 'L', values 0/255).
        factor: pooling kernel size and stride (the downsample ratio).

    Returns:
        The downsampled binary mask as a mode-'1' PIL image.
    """
    gray = np.array(pil_mask.convert('L'), dtype=np.uint8)
    # Shape [1, 1, H, W], values scaled to 0/1 floats for max_pool2d.
    batched = torch.from_numpy(gray).float().div_(255.0)[None, None]
    pooled = F.max_pool2d(batched, kernel_size=factor, stride=factor)
    # Any positive response means the block contained at least one white pixel.
    result = (pooled[0, 0] > 0).to(torch.uint8) * 255
    return Image.fromarray(result.cpu().numpy(), mode='L').convert('1')
def create_all_white_like(pil_img: Image.Image) -> Image.Image:
    """
    Build an all-white binary mask (mode '1') with the same size as ``pil_img``.
    """
    width, height = pil_img.size
    # NumPy arrays are (rows, cols) == (H, W), the transpose of PIL's (W, H).
    white = np.full((height, width), 255, dtype=np.uint8)
    return Image.fromarray(white, mode='L').convert('1')
def union_masks_np(masks: List[Image.Image]) -> Image.Image:
    """
    Compute the pixel-wise union (logical OR) of a list of binary masks.

    Args:
        masks: binary PIL images (mode '1' or 'L').

    Returns:
        The union as a mode-'1' PIL image: 255 wherever ANY mask is white.

    Raises:
        ValueError: if ``masks`` is empty.
    """
    if not masks:
        raise ValueError("输入的 masks 列表不能为空")
    # Threshold each mask at 127 to get boolean foreground maps.
    booleans = [np.array(m.convert('L'), dtype=np.uint8) > 127 for m in masks]
    # Pixel-wise OR across all masks.
    merged = np.logical_or.reduce(booleans)
    # Back to 0/255 and then to a binary PIL image.
    return Image.fromarray(merged.astype(np.uint8) * 255, mode='L').convert('1')
def intersect_masks_np(masks: List[Image.Image]) -> Image.Image:
    """
    Compute the pixel-wise intersection (logical AND) of binary masks.

    Args:
        masks: binary PIL images (mode '1' or 'L').

    Returns:
        The intersection as a mode-'1' PIL image: 255 only where EVERY
        mask is white.

    Raises:
        ValueError: if ``masks`` is empty.
    """
    if not masks:
        raise ValueError("输入的 masks 列表不能为空")
    # Threshold each mask at 127 to get boolean foreground maps.
    booleans = [np.array(m.convert('L'), dtype=np.uint8) > 127 for m in masks]
    # Pixel-wise AND across all masks.
    common = np.logical_and.reduce(booleans)
    # Back to 0/255 and then to a binary PIL image.
    return Image.fromarray(common.astype(np.uint8) * 255, mode='L').convert('1')
def close_small_holes(pil_mask, kernel_size=5):
    """
    Morphologically close a binary mask to fill small black holes.

    Args:
        pil_mask: binary PIL image (mode '1' or 'L').
        kernel_size: structuring-element size; larger fills bigger holes,
            usually an odd number.

    Returns:
        The closed mask as a PIL image ('L' mode, values 0/255).
    """
    gray = np.array(pil_mask.convert('L'))
    # Force strict 0/255 values before the morphology pass.
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    # Elliptical structuring element of the requested size.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    return Image.fromarray(closed)
def get_mask(src_image, tgt_image, threshold=1):
    """
    Binary change mask between two images.

    Pixels whose grayscale absolute difference reaches ``threshold`` become
    foreground (white, 255); everything else becomes background (black, 0).
    """
    delta = ImageChops.difference(src_image, tgt_image).convert("L")
    return delta.point(lambda v: 255 if v >= threshold else 0).convert("1")
def filter_small_components(pil_mask, area_threshold=0.10):
    """
    Remove white connected components whose area is below a threshold.

    The threshold is relative to the TOTAL WHITE AREA of the mask (not the
    full image area): a component is kept iff
    ``component_area >= area_threshold * count_nonzero(mask)``.
    (The original code computed ``h * w`` and then immediately overwrote it
    with the white-pixel count; the dead assignment is removed here and the
    docstring corrected to match the actual behavior.)

    Args:
        pil_mask: binary PIL image, mode 'L' or '1' (0/255).
        area_threshold: minimum component area as a fraction of the total
            white area (default 10%).

    Returns:
        The filtered mask as a PIL image ('L' mode, values 0/255).
    """
    # 1. Binarize to strict 0/255.
    mask = np.array(pil_mask.convert('L'))
    _, bin_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    # 2. Label components; stats provides each component's pixel area directly,
    #    avoiding a full-image boolean scan just to count pixels.
    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(bin_mask, connectivity=8)
    total_area = np.count_nonzero(bin_mask)
    min_area = area_threshold * total_area
    # 3. Keep only the sufficiently large components (label 0 is background).
    output = np.zeros_like(bin_mask)
    for lbl in range(1, num_labels):
        if stats[lbl, cv2.CC_STAT_AREA] >= min_area:
            output[labels == lbl] = 255
    # 4. Back to PIL.
    return Image.fromarray(output)
def is_binary_255(t: torch.Tensor) -> bool:
    """
    Check whether ``t`` contains no values other than 0 and 255.

    Returns True when the set of distinct values is exactly {0}, {255}, or
    {0, 255}; False otherwise (including for an empty tensor).
    """
    observed = torch.unique(t)
    # torch.unique returns sorted distinct values, so direct comparison works.
    candidates = ([0], [255], [0, 255])
    return any(
        torch.equal(observed, torch.tensor(vals, dtype=t.dtype))
        for vals in candidates
    )
def get_weight(mask_u_ds, weight_type='log'):
    """
    Convert a strictly binary 0/255 mask into a per-pixel weight map.

    White (255) pixels receive a weight >= 1 that grows as the white region
    shrinks (rarer foreground -> larger weight); black (0) pixels get 1.0.

    Args:
        mask_u_ds: mask convertible via np.array() to an (H, W) array whose
            values are exactly 0 and 255 (e.g. a binary PIL image).
        weight_type: 'log' -> log2(total / white) + 1;
                     'exp' -> 2 ** (sqrt(total / white) - 1).

    Returns:
        Float tensor of shape (1, H, W) holding the per-pixel weights.

    Raises:
        NotImplementedError: for an unsupported ``weight_type``.
    """
    weight_map = torch.from_numpy(np.array(mask_u_ds)).float()
    assert is_binary_255(weight_map), "is_binary_255(mask_u_ds_tensor)"
    foreground = weight_map.bool()
    # ratio = total pixels / white pixels; >= 1 whenever any pixel is white.
    ratio = foreground.numel() / foreground.sum()
    if weight_type == 'log':
        weight = torch.log2(ratio) + 1
    elif weight_type == 'exp':
        weight = 2 ** (ratio ** 0.5 - 1)
    else:
        raise NotImplementedError(f'Support log | exp, but found {weight_type}')
    weight = torch.round(weight, decimals=6)
    assert weight >= 1, \
        f"weight >= 1 but {weight}, {foreground.shape}, mask_u_ds_tensor_bool.numel(): {foreground.numel()}, mask_u_ds_tensor_bool.sum(): {foreground.sum()}"
    # White pixels take the computed weight, black pixels take 1.0.
    weight_map[weight_map == 255] = weight
    weight_map[weight_map == 0] = 1.0
    return weight_map.unsqueeze(0)  # (H, W) -> (1, H, W)
def get_weight_mask(pil_pixel_values, prompt=None, weight_type='log', need_weight='true'):
    """
    Build a downsampled change mask and per-pixel loss weights for training.

    Compares each reference image (all entries but the last) against the
    target image (the last entry) to get change masks, intersects them,
    downsamples by the VAE factor (8), closes small holes, and converts the
    result into a weight map via get_weight.

    Args:
        pil_pixel_values: list of PIL images; [-1] is the target, the rest
            are reference images (may be absent for t2i-style samples).
        prompt: optional text, used only in the error message below.
        need_weight: string flag; 'false' (case-insensitive) short-circuits
            to an all-white mask, i.e. uniform weighting.

    Returns:
        (mask_intersect_ds, weight): downsampled binary mask (PIL image) and
        the (1, H, W) weight tensor from get_weight.

    Raises:
        ValueError: when the intersected mask is non-empty but covers less
            than ``area_threshold`` of the image.
    """
    # area_threshold = 1/64
    area_threshold = 0.001
    # Adaptive kernel sizing below is disabled; a fixed 5x5 kernel is used.
    # base_kernel_size_factor = (5 / 448) ** 2
    # if len(pil_pixel_values) > 0:
    # w, h = pil_pixel_values[-1].size
    # kernel_size = max(int((base_kernel_size_factor * h * w) ** 0.5), 3)
    # else:
    kernel_size = 5
    if need_weight.lower() == 'false':
        # Uniform weighting: run the normal tail of the pipeline on an
        # all-white mask instead of a computed change mask.
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
        mask_intersect_ds = downsample_mask_pytorch(mask_intersect, factor=8) # factor is downsample ratio of vae
        mask_intersect_ds = close_small_holes(mask_intersect_ds, kernel_size=kernel_size)
        weight = get_weight(mask_intersect_ds, weight_type)
        return mask_intersect_ds, weight
    filtered_masks = []
    for ii, j in enumerate(pil_pixel_values[:-1]):
        # each reference image will compare with target image to get mask
        mask = get_mask(j, pil_pixel_values[-1], threshold=18)
        # fill small holes
        fill_mask = close_small_holes(mask, kernel_size=kernel_size)
        # del small components
        filtered_mask = filter_small_components(fill_mask, area_threshold=0.3)
        # filtered_mask = fill_mask
        filtered_masks.append(filtered_mask)
    if len(filtered_masks) == 0:
        # t2i task do not have reference image
        assert len(pil_pixel_values) == 1, "len(pil_pixel_values) == 1"
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
    else:
        mask_intersect = intersect_masks_np(filtered_masks)
    # Fraction of white pixels; a mode-'1' image converts to a bool array,
    # so the float sum counts white pixels directly.
    mask_intersect_area_ratio = np.array(mask_intersect).astype(np.float32).sum() / np.prod(np.array(mask_intersect).shape)
    # print(mask_intersect_area_ratio)
    if mask_intersect_area_ratio < area_threshold:
        if mask_intersect_area_ratio == 0.0:
            # mask_intersect_area_ratio == 0 mean reconstruct data in stage 1
            assert len(pil_pixel_values) == 2, "len(pil_pixel_values) == 2"
            mask_intersect = create_all_white_like(pil_pixel_values[-1])
        else:
            # Debug aid, kept disabled:
            # concat_images_row(pil_pixel_values + [mask_intersect], bg_color=(255,255,255)).show()
            raise ValueError(f'TOO SMALL mask_intersect_area_ratio: {mask_intersect_area_ratio}, prompt: {prompt}')
    mask_intersect_ds = downsample_mask_pytorch(mask_intersect, factor=8) # factor is downsample ratio of vae
    mask_intersect_ds = close_small_holes(mask_intersect_ds, kernel_size=kernel_size)
    weight = get_weight(mask_intersect_ds, weight_type)
    return mask_intersect_ds, weight
def get_weight_mask_test(pil_pixel_values, prompt=None, weight_type='log'):
    """
    Diagnostic variant of get_weight_mask.

    Runs the same mask-building pipeline (per-reference change mask, hole
    closing, small-component filtering, intersection) but returns only the
    intersected mask's white-area ratio instead of a (mask, weight) pair.

    Args:
        pil_pixel_values: list of PIL images; [-1] is the target, the rest
            are reference images.
        prompt: unused here; kept for signature parity with get_weight_mask.
        weight_type: unused here; kept for signature parity.

    Returns:
        Float: white pixels / total pixels of the intersected mask.
    """
    # NOTE(review): area_threshold is computed but never used in this
    # variant — the ratio is returned without being checked against it.
    area_threshold = 1/64
    # Kernel scales with image area, floored at 3 (5 at the 448x448 reference size).
    base_kernel_size_factor = (5 / 448) ** 2
    if len(pil_pixel_values) > 0:
        w, h = pil_pixel_values[-1].size
        kernel_size = max(int((base_kernel_size_factor * h * w) ** 0.5), 3)
    else:
        kernel_size = 5
    filtered_masks = []
    for ii, j in enumerate(pil_pixel_values[:-1]):
        # each reference image will compare with target image to get mask
        mask = get_mask(j, pil_pixel_values[-1], threshold=18)
        # fill small holes
        fill_mask = close_small_holes(mask, kernel_size=kernel_size)
        # del small components
        filtered_mask = filter_small_components(fill_mask, area_threshold=1/64)
        # filtered_mask = fill_mask
        filtered_masks.append(filtered_mask)
    if len(filtered_masks) == 0:
        # t2i task do not have reference image
        assert len(pil_pixel_values) == 1, "len(pil_pixel_values) == 1"
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
    else:
        mask_intersect = intersect_masks_np(filtered_masks)
    # Fraction of white pixels; a mode-'1' image converts to a bool array,
    # so the float sum counts white pixels directly.
    mask_intersect_area_ratio = np.array(mask_intersect).astype(np.float32).sum() / np.prod(np.array(mask_intersect).shape)
    return mask_intersect_area_ratio