from typing import List

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image, ImageChops


def concat_images_row(images: List[Image.Image], bg_color=(0, 0, 0)) -> Image.Image:
    """Concatenate PIL images horizontally on a single canvas.

    Args:
        images: Images to concatenate, left to right.
        bg_color: Canvas background color (default black); use a 4-tuple
            such as (0, 0, 0, 0) when the inputs carry an alpha channel.

    Returns:
        A new image of width ``sum(widths)`` and height ``max(heights)``.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if not images:
        raise ValueError("images 列表不能为空")
    # Work in RGBA when any input has an alpha channel, otherwise RGB.
    modes = {img.mode for img in images}
    mode = "RGBA" if any(m.endswith("A") for m in modes) else "RGB"
    widths, heights = zip(*(img.size for img in images))
    canvas = Image.new(mode, (sum(widths), max(heights)), bg_color)
    x_offset = 0
    for img in images:
        if img.mode != mode:
            img = img.convert(mode)
        # Use the image itself as the paste mask so alpha is respected.
        canvas.paste(img, (x_offset, 0), img if mode == "RGBA" else None)
        x_offset += img.width
    return canvas


def downsample_mask_pytorch(pil_mask: Image.Image, factor: int) -> Image.Image:
    """Downsample a binary mask with max-pooling, keeping any white pixel.

    Each ``factor`` x ``factor`` block of input pixels maps to one output
    pixel that is white if *any* pixel in the block was white.

    Args:
        pil_mask: Binary PIL image (mode '1' or 'L', values 0/255).
        factor: Downsampling factor (used as both kernel size and stride).

    Returns:
        The downsampled binary mask as a mode-'1' PIL image.
    """
    arr = np.array(pil_mask.convert('L'), dtype=np.uint8)
    # Shape [1, 1, H, W] float tensor with values in {0.0, 1.0}.
    tensor = torch.from_numpy(arr).float().div_(255.0).unsqueeze(0).unsqueeze(0)
    pooled = F.max_pool2d(tensor, kernel_size=factor, stride=factor)
    # Back to 0/255 uint8 and then to a binary PIL image.
    out = (pooled.squeeze(0).squeeze(0) > 0).to(torch.uint8).mul_(255).cpu().numpy()
    return Image.fromarray(out, mode='L').convert('1')


def create_all_white_like(pil_img: Image.Image) -> Image.Image:
    """Return an all-white binary (mode '1') image with the size of ``pil_img``."""
    w, h = pil_img.size
    # numpy arrays are indexed (rows, cols), i.e. (H, W).
    white_array = np.full((h, w), 255, dtype=np.uint8)
    return Image.fromarray(white_array, mode='L').convert('1')


def _masks_to_bool_arrays(masks: List[Image.Image]) -> List[np.ndarray]:
    """Convert binary PIL masks to boolean numpy arrays (True == white)."""
    return [
        (np.array(m.convert('L'), dtype=np.uint8) > 127).astype(np.bool_)
        for m in masks
    ]


def union_masks_np(masks: List[Image.Image]) -> Image.Image:
    """Pixel-wise union (logical OR) of binary PIL masks.

    Args:
        masks: Non-empty list of binary PIL images (mode '1' or 'L').

    Returns:
        The union as a mode-'1' PIL image.

    Raises:
        ValueError: If ``masks`` is empty.
    """
    if not masks:
        raise ValueError("输入的 masks 列表不能为空")
    union_bool = np.logical_or.reduce(_masks_to_bool_arrays(masks))
    union_arr = union_bool.astype(np.uint8) * 255
    return Image.fromarray(union_arr, mode='L').convert('1')


def intersect_masks_np(masks: List[Image.Image]) -> Image.Image:
    """Pixel-wise intersection (logical AND) of binary PIL masks.

    Args:
        masks: Non-empty list of binary PIL images (mode '1' or 'L').

    Returns:
        The intersection as a mode-'1' PIL image.

    Raises:
        ValueError: If ``masks`` is empty.
    """
    if not masks:
        raise ValueError("输入的 masks 列表不能为空")
    # NOTE: the original comment here said "logical OR" — the code performs AND.
    intersect_bool = np.logical_and.reduce(_masks_to_bool_arrays(masks))
    intersect_arr = intersect_bool.astype(np.uint8) * 255
    return Image.fromarray(intersect_arr, mode='L').convert('1')


def close_small_holes(pil_mask, kernel_size=5):
    """Fill small black holes in a white mask via morphological closing.

    Args:
        pil_mask: Binary PIL image.
        kernel_size: Structuring-element size; larger values fill larger
            holes. Usually odd.

    Returns:
        The closed mask as a mode-'L' PIL image (values 0/255).
    """
    mask = np.array(pil_mask.convert('L'))
    # Force 0/255 before the morphology pass.
    _, bin_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    closed = cv2.morphologyEx(bin_mask, cv2.MORPH_CLOSE, kernel)
    return Image.fromarray(closed)


def get_mask(src_image, tgt_image, threshold=1):
    """Binary difference mask: white where the two images differ enough.

    Pixels whose grayscale absolute difference is >= ``threshold`` become
    foreground (white); everything else is background (black).
    """
    diff = ImageChops.difference(src_image, tgt_image)
    diff_gray = diff.convert("L")
    mask = diff_gray.point(lambda x: 255 if x >= threshold else 0).convert("1")
    return mask


def filter_small_components(pil_mask, area_threshold=0.10):
    """Remove white connected components below a relative area threshold.

    NOTE(review): the threshold is relative to the *total white area*
    (``np.count_nonzero``), not the full image area — the original
    docstring claimed the latter; the code has always done the former.

    Args:
        pil_mask: PIL image, mode 'L' or '1' (0/255 binary).
        area_threshold: Minimum component area as a fraction of the total
            white area.

    Returns:
        The filtered mask as a PIL image.
    """
    # 1. Binarize to 0/255.
    mask = np.array(pil_mask.convert('L'))
    _, bin_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    # 2. Connected-component labelling (8-connectivity).
    num_labels, labels = cv2.connectedComponents(bin_mask, connectivity=8)
    # Threshold baseline: total white area, not image area.
    total_area = np.count_nonzero(bin_mask)
    # 3. Keep only sufficiently large components (label 0 is background).
    output = np.zeros_like(bin_mask)
    for lbl in range(1, num_labels):
        comp_mask = labels == lbl
        if comp_mask.sum() >= area_threshold * total_area:
            output[comp_mask] = 255
    # 4. Back to PIL.
    return Image.fromarray(output)


def is_binary_255(t: torch.Tensor) -> bool:
    """Return True iff ``t`` contains only the values 0 and/or 255.

    An empty tensor returns False.
    """
    unique_vals = torch.unique(t)
    return any(
        torch.equal(unique_vals, torch.tensor(vals, dtype=t.dtype))
        for vals in ([0], [255], [0, 255])
    )


def get_weight(mask_u_ds, weight_type='log'):
    """Turn a 0/255 binary mask into a per-pixel loss-weight map.

    Foreground (255) pixels receive a weight that grows as the foreground
    shrinks relative to the whole mask; background (0) pixels get 1.0.

    Args:
        mask_u_ds: Binary PIL mask (values 0/255).
        weight_type: 'log' -> log2(total/foreground) + 1;
                     'exp' -> 2 ** (sqrt(total/foreground) - 1).

    Returns:
        Float tensor of shape [1, H, W] holding the per-pixel weights.

    Raises:
        NotImplementedError: For an unknown ``weight_type``.
        AssertionError: If the mask is not 0/255 binary or the weight < 1.
    """
    mask_tensor = torch.from_numpy(np.array(mask_u_ds)).float()
    assert is_binary_255(mask_tensor), "is_binary_255(mask_u_ds_tensor)"
    mask_bool = mask_tensor.bool()
    # Inverse foreground ratio. NOTE(review): this is inf for an all-black
    # mask (sum == 0); the weight assert below would still pass — confirm
    # callers never feed an all-black mask.
    x = mask_bool.numel() / mask_bool.sum()
    if weight_type == 'log':
        weight = torch.log2(x) + 1
    elif weight_type == 'exp':
        weight = 2 ** (x ** 0.5 - 1)
    else:
        raise NotImplementedError(f'Support log | exp, but found {weight_type}')
    weight = torch.round(weight, decimals=6)
    assert weight >= 1, \
        f"weight >= 1 but {weight}, {mask_bool.shape}, mask_u_ds_tensor_bool.numel(): {mask_bool.numel()}, mask_u_ds_tensor_bool.sum(): {mask_bool.sum()}"
    # Foreground -> computed weight, background -> 1.0 (order matters:
    # write 255s first so the 0->1.0 pass cannot clobber them).
    mask_tensor[mask_tensor == 255] = weight
    mask_tensor[mask_tensor == 0] = 1.0
    return mask_tensor.unsqueeze(0)  # h w -> 1 h w


def get_weight_mask(pil_pixel_values, prompt=None, weight_type='log', need_weight='true'):
    """Build a downsampled change mask and per-pixel weight map.

    The last image in ``pil_pixel_values`` is the target; every earlier
    image is a reference diffed against it. Per-reference difference masks
    are cleaned (hole closing, small-component removal) and intersected;
    the intersection is downsampled by the VAE factor (8) and converted to
    a weight map via ``get_weight``.

    Args:
        pil_pixel_values: Reference images followed by the target image.
        prompt: Used only in the error message, for diagnostics.
        weight_type: Forwarded to ``get_weight`` ('log' or 'exp').
        need_weight: The string 'false' (case-insensitive) short-circuits
            to an all-white mask, i.e. uniform weighting.

    Returns:
        Tuple of (downsampled binary mask, weight tensor of shape [1, H, W]).

    Raises:
        ValueError: If the intersection covers too little of the image.
    """
    area_threshold = 0.001
    kernel_size = 5

    if need_weight.lower() == 'false':
        # No weighting requested: all-white mask -> every pixel weight 1.
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
        mask_intersect_ds = downsample_mask_pytorch(mask_intersect, factor=8)  # 8 == VAE downsample ratio
        mask_intersect_ds = close_small_holes(mask_intersect_ds, kernel_size=kernel_size)
        weight = get_weight(mask_intersect_ds, weight_type)
        return mask_intersect_ds, weight

    filtered_masks = []
    for ref_img in pil_pixel_values[:-1]:
        # Each reference image is compared with the target image.
        mask = get_mask(ref_img, pil_pixel_values[-1], threshold=18)
        # Fill small holes, then drop small components.
        fill_mask = close_small_holes(mask, kernel_size=kernel_size)
        filtered_mask = filter_small_components(fill_mask, area_threshold=0.3)
        filtered_masks.append(filtered_mask)

    if len(filtered_masks) == 0:
        # t2i task: no reference image, weight the whole image.
        assert len(pil_pixel_values) == 1, "len(pil_pixel_values) == 1"
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
    else:
        mask_intersect = intersect_masks_np(filtered_masks)

    # White-area / total-area must exceed the (small) threshold.
    intersect_arr = np.array(mask_intersect).astype(np.float32)
    mask_intersect_area_ratio = intersect_arr.sum() / np.prod(intersect_arr.shape)
    if mask_intersect_area_ratio < area_threshold:
        if mask_intersect_area_ratio == 0.0:
            # Ratio 0 means stage-1 reconstruction data (target == reference).
            assert len(pil_pixel_values) == 2, "len(pil_pixel_values) == 2"
            mask_intersect = create_all_white_like(pil_pixel_values[-1])
        else:
            raise ValueError(f'TOO SMALL mask_intersect_area_ratio: {mask_intersect_area_ratio}, prompt: {prompt}')

    mask_intersect_ds = downsample_mask_pytorch(mask_intersect, factor=8)  # 8 == VAE downsample ratio
    mask_intersect_ds = close_small_holes(mask_intersect_ds, kernel_size=kernel_size)
    weight = get_weight(mask_intersect_ds, weight_type)
    return mask_intersect_ds, weight


def get_weight_mask_test(pil_pixel_values, prompt=None, weight_type='log'):
    """Diagnostic variant of ``get_weight_mask``: return only the area ratio.

    Follows the same diff -> close -> filter -> intersect pipeline (with a
    size-adaptive closing kernel) and returns the fraction of the image
    covered by the intersection mask. ``prompt`` and ``weight_type`` are
    accepted for signature parity but unused.
    """
    base_kernel_size_factor = (5 / 448) ** 2
    if len(pil_pixel_values) > 0:
        w, h = pil_pixel_values[-1].size
        # Scale the closing kernel with image area (5px at 448x448 reference).
        kernel_size = max(int((base_kernel_size_factor * h * w) ** 0.5), 3)
    else:
        kernel_size = 5

    filtered_masks = []
    for ref_img in pil_pixel_values[:-1]:
        # Each reference image is compared with the target image.
        mask = get_mask(ref_img, pil_pixel_values[-1], threshold=18)
        fill_mask = close_small_holes(mask, kernel_size=kernel_size)
        filtered_mask = filter_small_components(fill_mask, area_threshold=1 / 64)
        filtered_masks.append(filtered_mask)

    if len(filtered_masks) == 0:
        # t2i task: no reference image.
        assert len(pil_pixel_values) == 1, "len(pil_pixel_values) == 1"
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
    else:
        mask_intersect = intersect_masks_np(filtered_masks)

    intersect_arr = np.array(mask_intersect).astype(np.float32)
    return intersect_arr.sum() / np.prod(intersect_arr.shape)