# test_gradio/image_gen_aux/utils/tiling_utils.py
# amos1088's picture
# uuu
# a2919a7
from typing import Callable, Tuple
import torch
def create_gradient_mask(shape: Tuple, feather: int, device="cpu") -> torch.Tensor:
    """
    Create a gradient mask for smooth blending of tiles.

    The mask is 1.0 in the interior and ramps linearly towards each of the four
    borders: the k-th row/column counted from a border (k = 0 is the outermost
    one) is scaled by ``(k + 1) / feather``. Corner pixels receive the product
    of their row and column factors. The mask is never zero, so it is safe to
    use as a normalisation weight.

    Args:
        shape (tuple): Shape of the mask (batch, channels, height, width)
        feather (int): Width of the feathered edge. Values <= 0 disable
            feathering (the mask is all ones). A feather wider than the mask
            is applied only to the rows/columns that exist instead of
            indexing out of range. Non-integer values are rounded.
        device: Device on which the mask is allocated.

    Returns:
        torch.Tensor: Gradient mask of the requested shape (default dtype)
    """
    _, _, h, w = shape
    # Callers may pass ``overlap * scale`` with a float scale; ``range``-style
    # integer semantics are restored by rounding once here.
    feather = round(feather)

    # Allocate directly on the target device rather than on CPU + copy.
    mask = torch.ones(shape, device=device)
    if feather <= 0:
        return mask

    def edge_profile(length: int) -> torch.Tensor:
        """Return the 1-D weight profile along one spatial axis."""
        steps = min(feather, length)  # never index past the axis
        profile = torch.ones(length, device=device)
        ramp = torch.arange(1, steps + 1, device=device, dtype=profile.dtype) / feather
        # Leading edge ramps up, trailing edge ramps down. When the two ramps
        # overlap (feather > length / 2) both factors apply, matching the
        # step-by-step multiplication of the original loop.
        profile[:steps] *= ramp
        profile[length - steps:] *= ramp.flip(0)
        return profile

    # Broadcast the row and column profiles over batch and channel dims.
    mask *= edge_profile(h).view(1, 1, h, 1)
    mask *= edge_profile(w).view(1, 1, 1, w)
    return mask
def tiled_upscale(
    samples: torch.Tensor,
    function: Callable,
    scale: int,
    tile_width: int = 512,
    tile_height: int = 512,
    overlap: int = 8,
) -> torch.Tensor:
    """
    Apply a scaling function to image samples in a tiled manner.

    The input is cut into overlapping tiles, each tile is passed through
    ``function``, and the results are blended into one output using feathered
    weights (accumulated weighted sum divided by accumulated weights).

    Args:
        samples (torch.Tensor): Input tensor of shape (batch_size, channels, height, width)
        function (Callable): The scaling function to apply to each tile. It is
            expected to return a 4-D tensor whose spatial size is the tile
            size multiplied by ``scale``.
        scale (int): Factor by which to upscale the image
        tile_width (int): Width of each tile
        tile_height (int): Height of each tile
        overlap (int): Overlap between tiles; must satisfy
            ``0 <= overlap < min(tile_width, tile_height)``

    Returns:
        torch.Tensor: Upscaled and processed output tensor on the device of
        ``samples``. Batch and channel sizes are taken from the processed
        tiles; accumulation uses the default floating point dtype.

    Raises:
        ValueError: If the tile sizes are not positive or ``overlap`` is
            negative or not smaller than the tile size.
    """
    if tile_width <= 0 or tile_height <= 0:
        raise ValueError("tile_width and tile_height must be positive")
    if overlap < 0 or overlap >= tile_width or overlap >= tile_height:
        # A non-positive stride would either raise inside ``range`` or skip
        # every tile and yield a 0 / 0 (NaN) image; a negative overlap leaves
        # uncovered gaps between tiles.
        raise ValueError("overlap must be non-negative and smaller than the tile size")

    batch, channels, height, width = samples.shape
    out_height, out_width = round(height * scale), round(width * scale)
    output_device = samples.device
    # Feather width in output pixels; rounded like every other scaled size.
    feather = round(overlap * scale)

    def tile_origins(extent: int, tile: int) -> list:
        """Return the start offsets of the tiles along one axis."""
        origins = []
        for start in range(0, extent, tile - overlap):
            end = min(start + tile, extent)
            # Shift the last tile back so it keeps the full tile size.
            origins.append(max(0, end - tile))
            if end >= extent:
                # Every later start would clamp to this same tile; stopping
                # here avoids re-running ``function`` on identical input.
                break
        return origins

    x_origins = tile_origins(width, tile_width)
    y_origins = tile_origins(height, tile_height)

    out = None  # weighted sum of processed tiles
    out_div = None  # sum of weights, used for normalisation
    mask = None  # cached blending mask, reused while the tile shape repeats

    # Process the image in tiles
    for y in y_origins:
        y_end = min(y + tile_height, height)
        for x in x_origins:
            x_end = min(x + tile_width, width)

            # Extract and process the tile
            tile = samples[:, :, y:y_end, x:x_end]
            processed_tile = function(tile).to(output_device)

            if out is None:
                # Size the accumulators from the first result so any batch
                # size / channel count produced by ``function`` is supported.
                out = torch.zeros(
                    (processed_tile.shape[0], processed_tile.shape[1], out_height, out_width),
                    device=output_device,
                )
                out_div = torch.zeros_like(out)

            # Calculate the position in the output tensor
            out_y, out_x = round(y * scale), round(x * scale)
            out_h, out_w = processed_tile.shape[2:]

            # Create a feathered mask for smooth blending (only when needed)
            if mask is None or mask.shape != processed_tile.shape:
                mask = create_gradient_mask(processed_tile.shape, feather, device=output_device)

            # Add the processed tile to the output
            out[:, :, out_y : out_y + out_h, out_x : out_x + out_w] += processed_tile * mask
            out_div[:, :, out_y : out_y + out_h, out_x : out_x + out_w] += mask

    if out is None:
        # Empty spatial extent: there was nothing to process.
        return torch.zeros((batch, channels, out_height, out_width), device=output_device)

    # Normalize the output
    return out / out_div