Spaces:
Paused
Paused
from typing import Callable, Tuple | |
import torch | |
def create_gradient_mask(shape: Tuple, feather: int, device="cpu") -> torch.Tensor:
    """
    Create a gradient mask for smooth blending of tiles.

    The mask starts as all ones. For every step ``s`` in ``range(feather)`` the
    s-th row and column counted from each border is multiplied by
    ``(s + 1) / feather``, so the weights ramp up linearly from the borders
    towards the interior. Corners receive the product of the row and column
    factors.

    Args:
        shape (tuple): Shape of the mask (batch, channels, height, width)
        feather (int): Width of the feathered edge. Values <= 0 produce an
            all-ones mask. If it exceeds the mask height or width, the ramp
            is truncated to the rows/columns that actually exist instead of
            indexing out of range.
        device: Device on which the mask is allocated (default "cpu").

    Returns:
        torch.Tensor: Gradient mask of the requested shape
    """
    # Allocate directly on the target device rather than creating the tensor
    # on the default device and copying it over.
    mask = torch.ones(shape, device=device)
    _, _, h, w = shape

    # Feather the top and bottom borders. The step count is capped at the
    # mask height so a feather wider than the mask cannot raise IndexError.
    for step in range(min(feather, h)):
        factor = (step + 1) / feather
        mask[:, :, step, :] *= factor
        mask[:, :, h - 1 - step, :] *= factor

    # Feather the left and right borders, capped at the mask width.
    for step in range(min(feather, w)):
        factor = (step + 1) / feather
        mask[:, :, :, step] *= factor
        mask[:, :, :, w - 1 - step] *= factor

    return mask
def _tile_starts(length: int, tile: int, step: int) -> list:
    """Return the start offsets of the tiles needed to cover ``length`` pixels."""
    starts = []
    for pos in range(0, length, step):
        end = min(pos + tile, length)
        # Shift a tile that would overrun the border back inside the image so
        # it keeps the full tile size whenever the image is large enough.
        starts.append(max(0, end - tile))
        if end >= length:
            # This tile already reaches the border; any further position would
            # be clamped onto the very same tile.
            break
    return starts


def tiled_upscale(
    samples: torch.Tensor,
    function: Callable,
    scale: int,
    tile_width: int = 512,
    tile_height: int = 512,
    overlap: int = 8,
) -> torch.Tensor:
    """
    Apply a scaling function to image samples in a tiled manner.

    The input is cut into overlapping tiles, each tile is passed through
    ``function``, and the results are accumulated into the output weighted by
    a feathered mask. The accumulated values are finally divided by the
    accumulated mask weights, which blends the overlapping regions.

    Args:
        samples (torch.Tensor): Input tensor of shape (batch_size, channels, height, width)
        function (Callable): The scaling function to apply to each tile. It
            receives a (batch, channels, tile_h, tile_w) slice of ``samples``
            and must return a 4-D tensor.
        scale (int): Factor by which to upscale the image
        tile_width (int): Width of each tile
        tile_height (int): Height of each tile
        overlap (int): Overlap between tiles (in input pixels)

    Returns:
        torch.Tensor: Upscaled and processed output tensor

    Raises:
        ValueError: If ``overlap`` is not smaller than both tile dimensions,
            since the tiling could then never advance across the image.
    """
    step_x = tile_width - overlap
    step_y = tile_height - overlap
    if step_x <= 0 or step_y <= 0:
        raise ValueError(
            "overlap must be smaller than tile_width and tile_height "
            f"(got overlap={overlap}, tile_width={tile_width}, tile_height={tile_height})"
        )

    _batch, _channels, height, width = samples.shape
    out_height, out_width = round(height * scale), round(width * scale)
    output_device = samples.device

    # The feather width is expressed in output pixels; keep it an integer even
    # if a fractional scale is supplied.
    feather = int(round(overlap * scale))

    # Tile origins are computed once per axis and never contain the same tile
    # twice, so `function` is not re-run on an identical tile and no tile is
    # counted twice in the blend.
    y_starts = _tile_starts(height, tile_height, step_y)
    x_starts = _tile_starts(width, tile_width, step_x)

    # Accumulators are allocated lazily so that their batch/channel sizes
    # follow what `function` actually returns instead of a fixed (1, 3).
    out = None
    out_div = None
    # Masks only depend on the processed tile shape; reuse them across tiles.
    mask_cache = {}

    # Process the image in tiles
    for y in y_starts:
        y_end = min(y + tile_height, height)
        for x in x_starts:
            x_end = min(x + tile_width, width)

            # Extract and process the tile
            tile = samples[:, :, y:y_end, x:x_end]
            processed_tile = function(tile).to(output_device)

            # Calculate the position in the output tensor
            out_y, out_x = round(y * scale), round(x * scale)
            out_h, out_w = processed_tile.shape[2:]

            if out is None:
                out_batch, out_channels = processed_tile.shape[:2]
                out = torch.zeros(
                    (out_batch, out_channels, out_height, out_width),
                    device=output_device,
                )
                out_div = torch.zeros_like(out)

            # Create (or reuse) a feathered mask for smooth blending. The
            # feather is clamped to the tile size so tiles smaller than the
            # feather width are still handled.
            mask_shape = tuple(processed_tile.shape)
            mask = mask_cache.get(mask_shape)
            if mask is None:
                mask = create_gradient_mask(
                    mask_shape, min(feather, out_h, out_w), device=output_device
                )
                mask_cache[mask_shape] = mask

            # Add the processed tile to the output
            out[:, :, out_y : out_y + out_h, out_x : out_x + out_w] += processed_tile * mask
            out_div[:, :, out_y : out_y + out_h, out_x : out_x + out_w] += mask

    if out is None:
        # No tiles were produced (zero-sized input): nothing to normalize.
        return torch.zeros(
            (_batch, _channels, out_height, out_width), device=output_device
        )

    # Normalize the output by the accumulated mask weights
    return out / out_div