from dataclasses import dataclass, field
from enum import IntEnum
from functools import partial
from math import dist
from typing import List, Optional, Tuple, Union

import cv2
import numpy as np
import torch
from PIL import Image, ImageChops, ImageDraw, ImageFilter, ImageOps
from torchvision.transforms.functional import to_pil_image

from diffusers import StableDiffusionInpaintPipeline


MASK_MERGE_INVERT = ["None", "Merge", "Merge and Invert"]


def adetailer(sd_pipeline, yolodetector, images: List[Image.Image], prompt, negative_prompt, seed=42):
    resolution = 512

    processed_input_imgs = []
    for input_image in images:
        pred = ultralytics_predict(yolodetector_model=yolodetector, image=input_image)
        masks = pred_preprocessing(pred)
        for mask in masks:
            # Feather the mask, find the bounding box of the masked area, and
            # expand it to match the aspect ratio of the processing resolution.
            blurred_mask = mask.filter(ImageFilter.GaussianBlur(8))
            crop_region = get_crop_region(np.array(blurred_mask))
            crop_region = expand_crop_region(crop_region, resolution, resolution, mask.width, mask.height)
            x1, y1, x2, y2 = crop_region
            paste_to = (x1, y1, x2 - x1, y2 - y1)
            image_mask = blurred_mask.crop(crop_region)
            image_mask = image_mask.resize((resolution, resolution), Image.LANCZOS)

            # The overlay keeps everything outside the mask, so the inpainted
            # patch can later be composited back without touching the rest of
            # the image.
            image_masked = Image.new('RGBa', (input_image.width, input_image.height))
            image_masked.paste(input_image.convert("RGBA").convert("RGBa"),
                               mask=ImageOps.invert(blurred_mask.convert('L')))
            overlay_image = image_masked.convert('RGBA')

            patch_input_img = input_image.crop(crop_region)
            patch_input_img = patch_input_img.resize((resolution, resolution), Image.LANCZOS)
            # Keep the mask together with its patch so each inpainting call
            # below uses the mask that belongs to it.
            processed_input_imgs.append((patch_input_img, image_mask, paste_to, overlay_image))

    denoising_strength = 0.4

    pipe = StableDiffusionInpaintPipeline(
        vae=sd_pipeline.vae,
        text_encoder=sd_pipeline.text_encoder,
        tokenizer=sd_pipeline.tokenizer,
        unet=sd_pipeline.unet,
        scheduler=sd_pipeline.scheduler,
        requires_safety_checker=False,
        safety_checker=None,
        feature_extractor=sd_pipeline.feature_extractor,
    ).to('cuda')

    generator = torch.Generator(device="cuda").manual_seed(seed)

    inpaint_images = []
    for patch_input_img, image_mask, paste_to, overlay_image in processed_input_imgs:
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=[patch_input_img],
            mask_image=image_mask,
            num_inference_steps=30,
            strength=denoising_strength,
            generator=generator,
        ).images[0]

        inpaint_images.append(apply_overlay(out, paste_to, overlay_image))

    return inpaint_images


def get_crop_region(mask, pad=0):
    """Finds a rectangular region that contains all masked areas in an image.
    Returns (x1, y1, x2, y2) coordinates of the rectangle.
    For example, if a user has painted the top-right part of a 512x512 image,
    the result may be (256, 0, 512, 256)."""

    h, w = mask.shape

    crop_left = 0
    for i in range(w):
        if not (mask[:, i] == 0).all():
            break
        crop_left += 1

    crop_right = 0
    for i in reversed(range(w)):
        if not (mask[:, i] == 0).all():
            break
        crop_right += 1

    crop_top = 0
    for i in range(h):
        if not (mask[i] == 0).all():
            break
        crop_top += 1

    crop_bottom = 0
    for i in reversed(range(h)):
        if not (mask[i] == 0).all():
            break
        crop_bottom += 1

    return (
        int(max(crop_left - pad, 0)),
        int(max(crop_top - pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h)),
    )


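# Illustrative check of get_crop_region (values assume a mask that is non-zero
# only in rows 10..20 and columns 30..40 of a 64x64 array):
#
#   m = np.zeros((64, 64), dtype=np.uint8)
#   m[10:21, 30:41] = 255
#   assert get_crop_region(m) == (30, 10, 41, 21)

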
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
    """Expands a crop region returned by get_crop_region() to match the aspect
    ratio of the resolution it will be processed at; returns the expanded region.
    For example, if the user drew a mask in a 128x32 region and the processing
    dimensions are 512x512, the region will be expanded to 128x128."""

    x1, y1, x2, y2 = crop_region

    ratio_crop_region = (x2 - x1) / (y2 - y1)
    ratio_processing = processing_width / processing_height

    if ratio_crop_region > ratio_processing:
        # Region is wider than the processing aspect ratio: grow it vertically,
        # shifting back inside the image if it spills over an edge.
        desired_height = (x2 - x1) / ratio_processing
        desired_height_diff = int(desired_height - (y2 - y1))
        y1 -= desired_height_diff // 2
        y2 += desired_height_diff - desired_height_diff // 2
        if y2 >= image_height:
            diff = y2 - image_height
            y2 -= diff
            y1 -= diff
        if y1 < 0:
            y2 -= y1
            y1 = 0
        if y2 >= image_height:
            y2 = image_height
    else:
        # Region is taller than the processing aspect ratio: grow it horizontally.
        desired_width = (y2 - y1) * ratio_processing
        desired_width_diff = int(desired_width - (x2 - x1))
        x1 -= desired_width_diff // 2
        x2 += desired_width_diff - desired_width_diff // 2
        if x2 >= image_width:
            diff = x2 - image_width
            x2 -= diff
            x1 -= diff
        if x1 < 0:
            x2 -= x1
            x1 = 0
        if x2 >= image_width:
            x2 = image_width

    return x1, y1, x2, y2


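# Illustrative check of expand_crop_region (a 128x32 box centered in a 512x512
# image, processed at 512x512, grows vertically to 128x128 around its center):
#
#   assert expand_crop_region((192, 240, 320, 272), 512, 512, 512, 512) == (192, 192, 320, 320)

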
@dataclass
class PredictOutput:
    bboxes: List[List[Union[int, float]]] = field(default_factory=list)
    masks: List[Image.Image] = field(default_factory=list)
    preview: Optional[Image.Image] = None


def create_mask_from_bbox(
    bboxes: List[List[float]], shape: Tuple[int, int]
) -> List[Image.Image]:
    """
    Parameters
    ----------
    bboxes: List[List[float]]
        list of [x1, y1, x2, y2] bounding boxes
    shape: Tuple[int, int]
        shape of the image (width, height)

    Returns
    -------
    masks: List[Image.Image]
        A list of masks
    """
    masks = []
    for bbox in bboxes:
        mask = Image.new("L", shape, 0)
        mask_draw = ImageDraw.Draw(mask)
        mask_draw.rectangle(bbox, fill=255)
        masks.append(mask)
    return masks


def ultralytics_predict(
    yolodetector_model,
    image: Image.Image,
    confidence: float = 0.5,
    device: str = "cuda",
) -> PredictOutput:
    # The detector is assumed to return (bboxes, scores), where bboxes[0] holds
    # the [x1, y1, x2, y2] boxes for the single input image. `device` is kept
    # for API compatibility; the model is expected to already live on it.
    bboxes, _ = yolodetector_model.predict(np.array(image), conf_thres=confidence, iou_thres=0.5)
    masks = create_mask_from_bbox(bboxes[0], image.size)

    return PredictOutput(bboxes=bboxes[0], masks=masks, preview=image)


def mask_to_pil(masks, shape: Tuple[int, int]) -> List[Image.Image]:
    """
    Parameters
    ----------
    masks: torch.Tensor, dtype=torch.float32, shape=(N, H, W).
        The device can be CUDA, but `to_pil_image` takes care of that.
    shape: Tuple[int, int]
        (width, height) of the original image
    """
    n = masks.shape[0]
    return [to_pil_image(masks[i], mode="L").resize(shape) for i in range(n)]


class MergeInvert(IntEnum):
    NONE = 0
    MERGE = 1
    MERGE_INVERT = 2


def offset(img: Image.Image, x: int = 0, y: int = 0) -> Image.Image:
    """
    The offset function takes an image and offsets it by a given x(→) and y(↑) value.

    Parameters
    ----------
    img: Image.Image
        the mask image to offset
    x: int
        →
    y: int
        ↑

    Returns
    -------
    PIL.Image.Image
        A new image that is offset by x and y
    """
    return ImageChops.offset(img, x, -y)


def is_all_black(img: Image.Image) -> bool:
    # Expects a single-channel ("L") mask.
    arr = np.array(img)
    return cv2.countNonZero(arr) == 0


def _dilate(arr: np.ndarray, value: int) -> np.ndarray:
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
    return cv2.dilate(arr, kernel, iterations=1)


def _erode(arr: np.ndarray, value: int) -> np.ndarray:
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (value, value))
    return cv2.erode(arr, kernel, iterations=1)


def dilate_erode(img: Image.Image, value: int) -> Image.Image:
    """
    The dilate_erode function takes an image and a value.
    If the value is positive, it dilates the image by that amount.
    If the value is negative, it erodes the image by that amount.

    Parameters
    ----------
    img: PIL.Image.Image
        the image to be processed
    value: int
        kernel size of dilation or erosion

    Returns
    -------
    PIL.Image.Image
        The image that has been dilated or eroded
    """
    if value == 0:
        return img

    arr = np.array(img)
    arr = _dilate(arr, value) if value > 0 else _erode(arr, -value)

    return Image.fromarray(arr)


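# Illustrative check of dilate_erode (a single white pixel grows into a block
# the size of the rectangular kernel):
#
#   m = Image.new("L", (8, 8), 0)
#   m.putpixel((4, 4), 255)
#   assert int(np.array(dilate_erode(m, 4)).sum()) == 255 * 16

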
def mask_preprocess(
    masks: List[Image.Image],
    kernel: int = 0,
    x_offset: int = 0,
    y_offset: int = 0,
    merge_invert: Union[int, 'MergeInvert', str] = MergeInvert.NONE,
) -> List[Image.Image]:
    """
    The mask_preprocess function takes a list of masks and preprocesses them.
    It dilates and erodes the masks, and offsets them by x_offset and y_offset.

    Parameters
    ----------
    masks: List[Image.Image]
        A list of masks
    kernel: int
        kernel size of dilation or erosion
    x_offset: int
        →
    y_offset: int
        ↑

    Returns
    -------
    List[Image.Image]
        A list of processed masks
    """
    if not masks:
        return []

    if x_offset != 0 or y_offset != 0:
        masks = [offset(m, x_offset, y_offset) for m in masks]

    if kernel != 0:
        masks = [dilate_erode(m, kernel) for m in masks]
        masks = [m for m in masks if not is_all_black(m)]

    return mask_merge_invert(masks, mode=merge_invert)


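# mask_merge and mask_invert are used by mask_merge_invert below but are not
# defined elsewhere in this file. The following are minimal sketches that match
# how they are called: merge folds all masks into one via bitwise OR, and
# invert flips each mask.


def mask_merge(masks: List[Image.Image]) -> List[Image.Image]:
    arrs = [np.array(m) for m in masks]
    merged = arrs[0]
    for arr in arrs[1:]:
        merged = cv2.bitwise_or(merged, arr)
    return [Image.fromarray(merged)]


def mask_invert(masks: List[Image.Image]) -> List[Image.Image]:
    return [ImageChops.invert(m) for m in masks]

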
def mask_merge_invert(
    masks: List[Image.Image], mode: Union[int, 'MergeInvert', str]
) -> List[Image.Image]:
    if isinstance(mode, str):
        mode = MASK_MERGE_INVERT.index(mode)

    if mode == MergeInvert.NONE or not masks:
        return masks

    if mode == MergeInvert.MERGE:
        return mask_merge(masks)

    if mode == MergeInvert.MERGE_INVERT:
        merged = mask_merge(masks)
        return mask_invert(merged)

    raise RuntimeError(f"Invalid merge-invert mode: {mode!r}")


def bbox_area(bbox: List[float]):
    return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])


def filter_by_ratio(pred: PredictOutput, low: float, high: float) -> PredictOutput:
    def is_in_ratio(bbox: List[float], low: float, high: float, orig_area: int) -> bool:
        area = bbox_area(bbox)
        return low <= area / orig_area <= high

    if not pred.bboxes:
        return pred

    w, h = pred.preview.size
    orig_area = w * h
    items = len(pred.bboxes)
    idx = [i for i in range(items) if is_in_ratio(pred.bboxes[i], low, high, orig_area)]
    pred.bboxes = [pred.bboxes[i] for i in idx]
    pred.masks = [pred.masks[i] for i in idx]
    return pred


class SortBy(IntEnum):
    NONE = 0
    LEFT_TO_RIGHT = 1
    CENTER_TO_EDGE = 2
    AREA = 3


def _key_left_to_right(bbox: List[float]) -> float:
    """
    Left to right

    Parameters
    ----------
    bbox: list[float]
        list of [x1, y1, x2, y2]
    """
    return bbox[0]


def _key_center_to_edge(bbox: List[float], *, center: Tuple[float, float]) -> float:
    """
    Center to edge

    Parameters
    ----------
    bbox: list[float]
        list of [x1, y1, x2, y2]
    center: Tuple[float, float]
        center of the image
    """
    bbox_center = ((bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2)
    return dist(center, bbox_center)


def _key_area(bbox: List[float]) -> float:
    """
    Large to small

    Parameters
    ----------
    bbox: list[float]
        list of [x1, y1, x2, y2]
    """
    return -bbox_area(bbox)


def sort_bboxes(
    pred: PredictOutput, order: Union[int, 'SortBy'] = SortBy.NONE
) -> PredictOutput:
    if order == SortBy.NONE or len(pred.bboxes) <= 1:
        return pred

    if order == SortBy.LEFT_TO_RIGHT:
        key = _key_left_to_right
    elif order == SortBy.CENTER_TO_EDGE:
        width, height = pred.preview.size
        center = (width / 2, height / 2)
        key = partial(_key_center_to_edge, center=center)
    elif order == SortBy.AREA:
        key = _key_area
    else:
        raise RuntimeError(f"Invalid sort order: {order!r}")

    items = len(pred.bboxes)
    idx = sorted(range(items), key=lambda i: key(pred.bboxes[i]))
    pred.bboxes = [pred.bboxes[i] for i in idx]
    pred.masks = [pred.masks[i] for i in idx]
    return pred


def filter_k_largest(pred: PredictOutput, k: int = 0) -> PredictOutput:
    # k == 0 disables filtering and keeps every detection.
    if not pred.bboxes or k == 0:
        return pred
    areas = [bbox_area(bbox) for bbox in pred.bboxes]
    idx = np.argsort(areas)[-k:]
    pred.bboxes = [pred.bboxes[i] for i in idx]
    pred.masks = [pred.masks[i] for i in idx]
    return pred


def pred_preprocessing(pred: PredictOutput) -> List[Image.Image]:
    pred = filter_by_ratio(pred, low=0.0, high=1.0)
    pred = filter_k_largest(pred, k=0)
    pred = sort_bboxes(pred, SortBy.AREA)
    return mask_preprocess(
        pred.masks,
        kernel=4,
        x_offset=0,
        y_offset=0,
        merge_invert="None",
    )


def apply_overlay(image, paste_loc, overlay):
    if overlay is None:
        return image

    if paste_loc is not None:
        x, y, w, h = paste_loc
        base_image = Image.new('RGBA', (overlay.width, overlay.height))
        image = image.resize((w, h), Image.LANCZOS)
        base_image.paste(image, (x, y))
        image = base_image

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
    image = image.convert('RGB')

    return image


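# Illustrative usage: a minimal sketch. The model id, the detector object, and
# the file names below are assumptions, not part of this module; any detector
# whose .predict() matches the interface assumed in ultralytics_predict works.
#
#   from diffusers import StableDiffusionPipeline
#
#   sd_pipeline = StableDiffusionPipeline.from_pretrained(
#       "runwayml/stable-diffusion-v1-5", safety_checker=None
#   ).to("cuda")
#   yolodetector = ...  # face/person detector with the assumed .predict() API
#   images = [Image.open("portrait.png").convert("RGB")]
#   results = adetailer(
#       sd_pipeline, yolodetector, images,
#       prompt="detailed face, high quality",
#       negative_prompt="blurry, deformed",
#       seed=42,
#   )
#   results[0].save("portrait_adetailer.png")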