Spaces:
Running
Running
File size: 6,122 Bytes
613c9ab |
|
from PIL import Image, ImageFilter
import torch
import math
from nodes import common_ksampler, VAEEncode, VAEDecode, VAEDecodeTiled
from utils import pil_to_tensor, tensor_to_pil, get_crop_region, expand_crop, crop_cond
from modules import shared
# Compatibility shim for older Pillow versions that have no ``Image.Resampling``:
# alias the ``Image`` module under that name so that lookups such as
# ``Image.Resampling.LANCZOS`` used further down still resolve.
if not hasattr(Image, 'Resampling'):
    Image.Resampling = Image
class StableDiffusionProcessing:
    """Parameter bundle shared by the USDU script and the ComfyUI sampler.

    Holds the image being processed, the mask settings read by
    ``process_images``, the sampler inputs, and the VAE node instances used
    to encode and decode tiles.
    """

    def __init__(
        self,
        init_img,
        model,
        positive,
        negative,
        vae,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
        denoise,
        upscale_by,
        uniform_tile_mode,
        tiled_decode,
    ):
        """Store every input and derive the size fields from ``init_img``.

        ``init_img`` must expose ``width`` and ``height`` attributes. All
        other arguments are stored unchanged under attributes of the same
        name.
        """
        img_width = init_img.width
        img_height = init_img.height

        # Variables used by the USDU script
        self.init_images = [init_img]
        # No mask yet; process_images calls ``.convert`` on it, so it must be
        # assigned before that function runs.
        self.image_mask = None
        self.mask_blur = 0
        self.inpaint_full_res_padding = 0
        self.width = img_width
        self.height = img_height

        # ComfyUI Sampler inputs
        self.model = model
        self.positive = positive
        self.negative = negative
        self.vae = vae
        self.seed = seed
        self.steps = steps
        self.cfg = cfg
        self.sampler_name = sampler_name
        self.scheduler = scheduler
        self.denoise = denoise

        # Variables used only by this script
        self.init_size = (img_width, img_height)
        self.upscale_by = upscale_by
        self.uniform_tile_mode = uniform_tile_mode
        self.tiled_decode = tiled_decode
        self.vae_decoder = VAEDecode()
        self.vae_encoder = VAEEncode()
        self.vae_decoder_tiled = VAEDecodeTiled()

        # Other A1111 variables required by the USDU script but currently
        # unused in this script
        self.extra_generation_params = {}
class Processed:
    """Container for the result returned by ``process_images``."""

    def __init__(
        self,
        p: StableDiffusionProcessing,
        images: list,
        seed: int,
        info: str,
    ):
        """Keep the output images, the seed and the info value.

        ``p`` is accepted but not stored or otherwise used.
        """
        self.images, self.seed, self.info = images, seed, info

    def infotext(self, p: StableDiffusionProcessing, index):
        """Always return ``None``; both arguments are ignored."""
        return None
def fix_seed(p: StableDiffusionProcessing):
    """Intentionally do nothing; ``p`` (including ``p.seed``) is left untouched."""
    return None
def process_images(p: StableDiffusionProcessing) -> Processed:
    """Re-sample the masked tile of every image in ``shared.batch``.

    This is where the main image generation happens in A1111. The white
    region of ``p.image_mask`` outlines the tile. For each image in
    ``shared.batch`` the tile is cropped out, resized to the processing size
    if needed, VAE-encoded, sampled with ``common_ksampler``, decoded, and
    blended back into the image using the (optionally blurred) mask as the
    alpha channel. ``shared.batch`` is updated in place.

    Returns:
        A ``Processed`` wrapping only the first image of ``shared.batch``,
        with ``p.seed`` as its seed and ``None`` as its info.
    """
    # Setup
    mask = p.image_mask.convert('L')
    first_image = p.init_images[0]

    # Locate the white region of the mask outlining the tile and add padding
    crop_region = get_crop_region(mask, p.inpaint_full_res_padding)
    left, top, right, bottom = crop_region
    crop_width = right - left
    crop_height = bottom - top

    if p.uniform_tile_mode:
        # Expand the crop region to match the processing size ratio; the
        # tiles are later resized to the processing size itself.
        crop_ratio = crop_width / crop_height
        p_ratio = p.width / p.height
        if crop_ratio > p_ratio:
            target_width, target_height = crop_width, round(crop_width / p_ratio)
        else:
            target_width, target_height = round(crop_height * p_ratio), crop_height
        crop_region, _ = expand_crop(
            crop_region, mask.width, mask.height, target_width, target_height)
        tile_size = p.width, p.height
    else:
        # Use the smallest size (rounded up to a multiple of 8) that fits the
        # mask. This minimizes the tile size but may produce image sizes the
        # model was not trained on.
        target_width = math.ceil(crop_width / 8) * 8
        target_height = math.ceil(crop_height / 8) * 8
        crop_region, tile_size = expand_crop(
            crop_region, mask.width, mask.height, target_width, target_height)

    # Blur the mask (only after the crop region has been derived from it)
    if p.mask_blur > 0:
        mask = mask.filter(ImageFilter.GaussianBlur(p.mask_blur))

    # Crop the images to get the tiles that will be used for generation
    tiles = [image.crop(crop_region) for image in shared.batch]

    # Assume the same size for all images in the batch; remember it so the
    # sampled tiles can be scaled back before compositing.
    initial_tile_size = tiles[0].size

    # Resize if necessary
    tiles = [
        tile if tile.size == tile_size else tile.resize(tile_size, Image.Resampling.LANCZOS)
        for tile in tiles
    ]

    # Crop conditioning
    cond_positive = crop_cond(p.positive, crop_region, p.init_size, first_image.size, tile_size)
    cond_negative = crop_cond(p.negative, crop_region, p.init_size, first_image.size, tile_size)

    # Encode the image
    tile_tensors = [pil_to_tensor(tile) for tile in tiles]
    batched_tiles = torch.cat(tile_tensors, dim=0)
    (latent,) = p.vae_encoder.encode(p.vae, batched_tiles)

    # Generate samples
    (samples,) = common_ksampler(
        p.model, p.seed, p.steps, p.cfg, p.sampler_name, p.scheduler,
        cond_positive, cond_negative, latent, denoise=p.denoise)

    # Decode the sample
    if p.tiled_decode:
        print("[USDU] Using tiled decode")
        (decoded,) = p.vae_decoder_tiled.decode(p.vae, samples, 512)  # Default tile size is 512
    else:
        (decoded,) = p.vae_decoder.decode(p.vae, samples)

    # Convert the sample to PIL images
    sampled_tiles = [tensor_to_pil(decoded, index) for index in range(len(decoded))]

    paste_origin = crop_region[:2]
    for index, sampled in enumerate(sampled_tiles):
        base_image = shared.batch[index]

        # Resize back to the original size
        if sampled.size != initial_tile_size:
            sampled = sampled.resize(initial_tile_size, Image.Resampling.LANCZOS)

        # Put the tile into position on an otherwise empty RGBA canvas
        tile_layer = Image.new('RGBA', base_image.size)
        tile_layer.paste(sampled, paste_origin)

        # Add the mask as an alpha channel
        # Must make a copy due to the possibility of an edge becoming black
        masked_copy = tile_layer.copy()
        masked_copy.putalpha(mask)
        tile_layer.paste(masked_copy, tile_layer)

        # Add back the tile to the initial image according to the mask in the alpha channel
        blended = base_image.convert('RGBA')
        blended.alpha_composite(tile_layer)

        # Convert back to RGB
        shared.batch[index] = blended.convert('RGB')

    return Processed(p, [shared.batch[0]], p.seed, None)
|