# UniWorld-V1 / univa/utils/get_mask.py
# LinB203
# init
# 0c8d55e
from typing import List
import numpy as np
import cv2
from PIL import Image, ImageChops
import torch
import torch.nn.functional as F
def concat_images_row(images: List[Image.Image], bg_color=(0, 0, 0)) -> Image.Image:
    """
    Concatenate multiple PIL images horizontally into a single row.

    Args:
        images: images to concatenate, left to right.
        bg_color: canvas background color, black by default; use (0, 0, 0, 0)
            when the inputs carry an alpha channel.

    Returns:
        A new image with every input pasted side by side.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if not images:
        raise ValueError("images 列表不能为空")
    # Work in RGBA when any input has an alpha channel, otherwise RGB.
    has_alpha = any(img.mode.endswith("A") for img in images)
    target_mode = "RGBA" if has_alpha else "RGB"
    # Canvas: as wide as all inputs combined, as tall as the tallest input.
    sizes = [img.size for img in images]
    canvas_w = sum(w for w, _ in sizes)
    canvas_h = max(h for _, h in sizes)
    canvas = Image.new(target_mode, (canvas_w, canvas_h), bg_color)
    # Paste each image left-to-right, converting modes on the fly.
    cursor = 0
    for img in images:
        if img.mode != target_mode:
            img = img.convert(target_mode)
        # In RGBA mode the image doubles as its own paste mask (keeps alpha).
        canvas.paste(img, (cursor, 0), img if target_mode == "RGBA" else None)
        cursor += img.width
    return canvas
def downsample_mask_pytorch(pil_mask: Image.Image, factor: int) -> Image.Image:
    """
    Downsample a binary mask with max-pooling: any white pixel inside a
    ``factor`` x ``factor`` block makes the corresponding output pixel white.

    Args:
        pil_mask: binary PIL image (mode '1' or 'L', values 0/255).
        factor: downsampling ratio (used as both kernel size and stride).

    Returns:
        The downsampled binary PIL image (mode '1').
    """
    # [H, W] uint8 -> [1, 1, H, W] float tensor with values in {0.0, 1.0}.
    gray = np.array(pil_mask.convert('L'), dtype=np.uint8)
    batched = torch.from_numpy(gray).float().div_(255.0).unsqueeze(0).unsqueeze(0)
    # Max-pool so a single white pixel whitens its whole block.
    pooled = F.max_pool2d(batched, kernel_size=factor, stride=factor)
    # Back to a 0/255 uint8 array, then to a binary PIL image.
    result = (pooled[0, 0] > 0).to(torch.uint8).mul_(255).cpu().numpy()
    return Image.fromarray(result, mode='L').convert('1')
def create_all_white_like(pil_img: Image.Image) -> Image.Image:
    """Return an all-white binary image (mode '1') with the same size as ``pil_img``."""
    width, height = pil_img.size
    # NumPy arrays are (rows, cols), i.e. (height, width).
    white = np.full((height, width), 255, dtype=np.uint8)
    return Image.fromarray(white, mode='L').convert('1')
def union_masks_np(masks: List[Image.Image]) -> Image.Image:
    """
    Compute the pixel-wise union (logical OR) of a list of binary PIL images
    (mode '1' or 'L') and return it as a binary (mode '1') image.

    Raises:
        ValueError: if ``masks`` is empty.
    """
    if not masks:
        raise ValueError("输入的 masks 列表不能为空")
    # Threshold each mask at 127 to obtain boolean arrays.
    bool_masks = [
        (np.array(m.convert('L'), dtype=np.uint8) > 127).astype(np.bool_)
        for m in masks
    ]
    # Pixel-wise logical OR across all masks.
    merged = np.logical_or.reduce(bool_masks)
    # Boolean -> 0/255 uint8 -> binary PIL image.
    return Image.fromarray(merged.astype(np.uint8) * 255, mode='L').convert('1')
def intersect_masks_np(masks: List[Image.Image]) -> Image.Image:
    """
    Compute the pixel-wise intersection (logical AND) of a list of binary PIL
    images (mode '1' or 'L') and return it as a binary (mode '1') image.

    Raises:
        ValueError: if ``masks`` is empty.
    """
    if not masks:
        raise ValueError("输入的 masks 列表不能为空")
    # Threshold each mask at 127 to obtain boolean arrays.
    bool_masks = [
        (np.array(m.convert('L'), dtype=np.uint8) > 127).astype(np.bool_)
        for m in masks
    ]
    # Pixel-wise logical AND across all masks.
    merged = np.logical_and.reduce(bool_masks)
    # Boolean -> 0/255 uint8 -> binary PIL image.
    return Image.fromarray(merged.astype(np.uint8) * 255, mode='L').convert('1')
def close_small_holes(pil_mask, kernel_size=5):
    """
    Fill small black holes in a binary mask via morphological closing.

    Args:
        pil_mask: binary PIL image.
        kernel_size: structuring-element size; larger values fill larger
            holes, usually an odd number.

    Returns:
        The closed mask as a PIL.Image (values 0/255).
    """
    # Binarize to strict 0/255 first.
    gray = np.array(pil_mask.convert('L'))
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
    # Elliptical structuring element of the requested size.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
    # Closing = dilation followed by erosion.
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse)
    return Image.fromarray(closed)
def get_mask(src_image, tgt_image, threshold=1):
    """
    Build a change mask between two images: pixels whose absolute difference
    is at least ``threshold`` become foreground (white), the rest background
    (black).
    """
    delta = ImageChops.difference(src_image, tgt_image)
    gray = delta.convert("L")
    # Per-pixel thresholding, then collapse to a 1-bit mask.
    return gray.point(lambda v: 255 if v >= threshold else 0).convert("1")
def filter_small_components(pil_mask, area_threshold=0.10):
    """
    Remove connected white components whose area falls below a threshold.

    Args:
        pil_mask: PIL.Image, mode 'L' or '1' (0/255 binary mask).
        area_threshold: minimum component area as a fraction of the total
            WHITE area of the mask (not of the full image).

    Returns:
        The filtered mask as a PIL.Image (values 0/255).
    """
    # 1. Binarize to strict 0/255.
    mask = np.array(pil_mask.convert('L'))
    _, bin_mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    # 2. Label connected components (8-connectivity; label 0 is background).
    num_labels, labels = cv2.connectedComponents(bin_mask, connectivity=8)
    # Reference area: the total number of white pixels in the mask.
    # (The original also computed h*w here, but that value was immediately
    # overwritten — the effective reference has always been the white area.)
    total_area = np.count_nonzero(bin_mask)
    # 3. Keep only components that are large enough.
    output = np.zeros_like(bin_mask)
    for lbl in range(1, num_labels):
        comp_mask = labels == lbl
        if comp_mask.sum() >= area_threshold * total_area:
            output[comp_mask] = 255
    # 4. Back to PIL.
    return Image.fromarray(output)
def is_binary_255(t: torch.Tensor) -> bool:
    """Return True iff the tensor contains no values other than 0 and 255."""
    vals = torch.unique(t)  # sorted unique values
    # Accept exactly {0}, {255}, or {0, 255}.
    candidates = (
        torch.tensor([0], dtype=t.dtype),
        torch.tensor([255], dtype=t.dtype),
        torch.tensor([0, 255], dtype=t.dtype),
    )
    return any(torch.equal(vals, candidate) for candidate in candidates)
def get_weight(mask_u_ds, weight_type='log'):
    """
    Convert a downsampled binary mask into a per-pixel loss-weight map.

    White pixels (255) receive a weight >= 1 that grows as the white region
    shrinks relative to the whole mask; black pixels (0) receive weight 1.

    Args:
        mask_u_ds: binary 0/255 mask; anything ``np.array`` accepts.
        weight_type: 'log' -> log2(total/white) + 1;
                     'exp' -> 2 ** (sqrt(total/white) - 1).

    Returns:
        Float tensor of shape (1, H, W) holding the weights.

    Raises:
        NotImplementedError: for an unknown ``weight_type``.
    """
    weights = torch.from_numpy(np.array(mask_u_ds)).float()
    assert is_binary_255(weights), "is_binary_255(mask_u_ds_tensor)"
    as_bool = weights.bool()
    # Ratio of total pixel count to white pixel count (>= 1 when any white).
    ratio = as_bool.numel() / as_bool.sum()
    if weight_type == 'log':
        fg_weight = torch.log2(ratio) + 1
    elif weight_type == 'exp':
        fg_weight = 2 ** (ratio**0.5 - 1)
    else:
        raise NotImplementedError(f'Support log | exp, but found {weight_type}')
    # Round for numerical stability of the stored weight.
    fg_weight = torch.round(fg_weight, decimals=6)
    assert fg_weight >= 1, \
        f"weight >= 1 but {fg_weight}, {as_bool.shape}, mask_u_ds_tensor_bool.numel(): {as_bool.numel()}, mask_u_ds_tensor_bool.sum(): {as_bool.sum()}"
    # Map foreground to the computed weight and background to 1.0.
    weights[weights == 255] = fg_weight
    weights[weights == 0] = 1.0
    return weights.unsqueeze(0)  # h w -> 1 h w
def get_weight_mask(pil_pixel_values, prompt=None, weight_type='log', need_weight='true'):
    """
    Build the downsampled edit-region mask and its loss-weight map.

    The last entry of ``pil_pixel_values`` is the target image; every earlier
    entry is a reference. Each reference is diffed against the target
    (get_mask), cleaned (close_small_holes, filter_small_components), and the
    per-reference masks are intersected. The result is downsampled by 8 (the
    VAE ratio) and converted into weights via get_weight().

    Args:
        pil_pixel_values: list of PIL images, target image last.
        prompt: optional prompt text; only used in the error message below.
        weight_type: 'log' | 'exp', forwarded to get_weight().
        need_weight: string flag; 'false' skips mask computation and uses an
            all-white mask instead.

    Returns:
        (mask_intersect_ds, weight): downsampled binary PIL mask and the
        (1, h, w) float weight tensor from get_weight().

    Raises:
        ValueError: if the intersected mask has a non-zero white-area ratio
            below ``area_threshold``.
    """
    # area_threshold = 1/64
    area_threshold = 0.001
    # base_kernel_size_factor = (5 / 448) ** 2
    # if len(pil_pixel_values) > 0:
    #     w, h = pil_pixel_values[-1].size
    #     kernel_size = max(int((base_kernel_size_factor * h * w) ** 0.5), 3)
    # else:
    kernel_size = 5
    if need_weight.lower() == 'false':
        # Weighting disabled: treat the whole target image as foreground.
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
        mask_intersect_ds = downsample_mask_pytorch(mask_intersect, factor=8) # factor is downsample ratio of vae
        mask_intersect_ds = close_small_holes(mask_intersect_ds, kernel_size=kernel_size)
        weight = get_weight(mask_intersect_ds, weight_type)
        return mask_intersect_ds, weight
    filtered_masks = []
    for ii, j in enumerate(pil_pixel_values[:-1]):
        # each reference image will compare with target image to get mask
        mask = get_mask(j, pil_pixel_values[-1], threshold=18)
        # fill small holes
        fill_mask = close_small_holes(mask, kernel_size=kernel_size)
        # del small components
        filtered_mask = filter_small_components(fill_mask, area_threshold=0.3)
        # filtered_mask = fill_mask
        filtered_masks.append(filtered_mask)
    if len(filtered_masks) == 0:
        # t2i task do not have reference image
        assert len(pil_pixel_values) == 1, "len(pil_pixel_values) == 1"
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
    else:
        mask_intersect = intersect_masks_np(filtered_masks)
    # Fraction of white pixels in the intersected mask (mode '1' -> bool array).
    mask_intersect_area_ratio = np.array(mask_intersect).astype(np.float32).sum() / np.prod(np.array(mask_intersect).shape)
    # print(mask_intersect_area_ratio)
    if mask_intersect_area_ratio < area_threshold:
        if mask_intersect_area_ratio == 0.0:
            # mask_intersect_area_ratio == 0 mean reconstruct data in stage 1
            assert len(pil_pixel_values) == 2, "len(pil_pixel_values) == 2"
            mask_intersect = create_all_white_like(pil_pixel_values[-1])
        else:
            # concat_images_row(pil_pixel_values + [mask_intersect], bg_color=(255,255,255)).show()
            raise ValueError(f'TOO SMALL mask_intersect_area_ratio: {mask_intersect_area_ratio}, prompt: {prompt}')
    mask_intersect_ds = downsample_mask_pytorch(mask_intersect, factor=8) # factor is downsample ratio of vae
    mask_intersect_ds = close_small_holes(mask_intersect_ds, kernel_size=kernel_size)
    weight = get_weight(mask_intersect_ds, weight_type)
    return mask_intersect_ds, weight
def get_weight_mask_test(pil_pixel_values, prompt=None, weight_type='log'):
    """
    Diagnostic variant of get_weight_mask(): runs the same per-reference mask
    pipeline (diff -> close holes -> drop small components -> intersect) and
    returns only the white-area ratio of the intersected mask.

    Unlike get_weight_mask(), the closing kernel scales with the target image
    resolution and the component threshold is 1/64.

    Args:
        pil_pixel_values: list of PIL images, target image last.
        prompt: unused.
        weight_type: unused.

    Returns:
        float: white pixels / total pixels of the intersected mask.
    """
    area_threshold = 1/64  # NOTE(review): unused in this function
    base_kernel_size_factor = (5 / 448) ** 2
    if len(pil_pixel_values) > 0:
        w, h = pil_pixel_values[-1].size
        # Scale the closing kernel with image area, at least 3 px.
        kernel_size = max(int((base_kernel_size_factor * h * w) ** 0.5), 3)
    else:
        kernel_size = 5
    filtered_masks = []
    for ii, j in enumerate(pil_pixel_values[:-1]):
        # each reference image will compare with target image to get mask
        mask = get_mask(j, pil_pixel_values[-1], threshold=18)
        # fill small holes
        fill_mask = close_small_holes(mask, kernel_size=kernel_size)
        # del small components
        filtered_mask = filter_small_components(fill_mask, area_threshold=1/64)
        # filtered_mask = fill_mask
        filtered_masks.append(filtered_mask)
    if len(filtered_masks) == 0:
        # t2i task do not have reference image
        assert len(pil_pixel_values) == 1, "len(pil_pixel_values) == 1"
        mask_intersect = create_all_white_like(pil_pixel_values[-1])
    else:
        mask_intersect = intersect_masks_np(filtered_masks)
    # Fraction of white pixels in the intersected mask (mode '1' -> bool array).
    mask_intersect_area_ratio = np.array(mask_intersect).astype(np.float32).sum() / np.prod(np.array(mask_intersect).shape)
    return mask_intersect_area_ratio