File size: 3,015 Bytes
a2919a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from typing import Callable, Tuple

import torch


def _edge_ramp(length: int, feather: int, device, dtype) -> torch.Tensor:
    """Return 1-D weights along one axis: a ramp in from the start times a ramp in from the end."""
    idx = torch.arange(length, device=device, dtype=dtype)
    # The k-th index from an edge (0-based) gets (k + 1) / feather, capped at 1 once
    # it is past the feather zone. Multiplying both directions reproduces the
    # "scale row k and row length-1-k by (k+1)/feather" rule, including the case
    # where the two ramps overlap on a short axis.
    from_start = ((idx + 1) / feather).clamp(max=1.0)
    from_end = ((length - idx) / feather).clamp(max=1.0)
    return from_start * from_end


def create_gradient_mask(shape: Tuple, feather: int, device="cpu") -> torch.Tensor:
    """
    Create a gradient mask for smooth blending of tiles.

    The mask is 1 in the interior and ramps linearly down toward every edge of the
    last two (height, width) dimensions. The k-th row or column from an edge (0-based)
    is scaled by (k + 1) / feather for k < feather. Row and column ramps combine
    multiplicatively, so corners receive the product of both. When an axis is shorter
    than 2 * feather, the ramps from opposite edges overlap and are multiplied together.
    Feather widths larger than the mask are allowed. Every entry is strictly positive
    whenever feather > 0.

    Args:
        shape (tuple): Shape of the mask (batch, channels, height, width)
        feather (int): Width of the feathered edge; values <= 0 disable feathering
        device: Device on which the mask is allocated

    Returns:
        torch.Tensor: Gradient mask of the requested shape (default dtype)
    """
    # Allocate directly on the target device instead of building on CPU and copying.
    mask = torch.ones(shape, device=device)
    if feather <= 0:
        return mask
    _, _, h, w = shape
    row_weights = _edge_ramp(h, feather, device, mask.dtype)
    col_weights = _edge_ramp(w, feather, device, mask.dtype)
    # (h, 1) * (1, w) -> (h, w), broadcast over the batch and channel dimensions.
    mask *= row_weights[:, None] * col_weights[None, :]
    return mask


def _tile_starts(length: int, tile: int, step: int) -> list:
    """Return the ordered, de-duplicated start offsets of tiles covering ``[0, length)``."""
    # Each stride position is snapped back so its tile ends inside the image: the
    # trailing tile is shifted left instead of being truncated. Snapping can map
    # several stride positions onto the same start. dict.fromkeys drops those
    # repeats while keeping order, so no tile is processed (or weighted) twice.
    return list(
        dict.fromkeys(
            max(0, min(pos + tile, length) - tile) for pos in range(0, length, step)
        )
    )


def tiled_upscale(
    samples: torch.Tensor,
    function: Callable,
    scale: int,
    tile_width: int = 512,
    tile_height: int = 512,
    overlap: int = 8,
) -> torch.Tensor:
    """
    Apply a scaling function to image samples in a tiled manner.

    The input is cut into tiles of at most ``tile_height`` x ``tile_width`` pixels,
    and neighbouring tiles overlap by ``overlap`` pixels. Each tile is passed through
    ``function`` and weighted by a feathered mask from ``create_gradient_mask``. The
    weighted tiles are summed into the output, and the sum is divided by the summed
    weights so that overlapping regions are blended.

    Args:
        samples (torch.Tensor): Input tensor of shape (batch_size, channels, height, width)
        function (Callable): The scaling function to apply to each tile. It is expected
            to return a 4-D tensor whose spatial size is the tile size times ``scale``.
        scale (int): Factor by which to upscale the image
        tile_width (int): Width of each tile
        tile_height (int): Height of each tile
        overlap (int): Overlap between tiles; must satisfy
            0 <= overlap < min(tile_width, tile_height)

    Returns:
        torch.Tensor: Upscaled and processed output tensor on ``samples.device``. Its
        batch and channel dimensions are taken from ``function``'s output.

    Raises:
        ValueError: If ``overlap`` is negative or not smaller than both tile sizes
            (such values would leave the stride non-positive or gaps between tiles).
    """
    if not 0 <= overlap < min(tile_width, tile_height):
        raise ValueError(
            "overlap must satisfy 0 <= overlap < min(tile_width, tile_height); "
            f"got overlap={overlap}, tile_width={tile_width}, tile_height={tile_height}"
        )

    batch, channels, height, width = samples.shape
    out_height, out_width = round(height * scale), round(width * scale)
    output_device = samples.device
    # Feather width in output pixels; rounded so a fractional scale still yields an int.
    feather = round(overlap * scale)

    row_starts = _tile_starts(height, tile_height, tile_height - overlap)
    col_starts = _tile_starts(width, tile_width, tile_width - overlap)

    # Accumulators are allocated lazily so their batch/channel dims match function's output.
    out = out_div = None
    # The mask depends only on the tile shape; edge tiles may differ, so cache per shape.
    masks = {}

    for y in row_starts:
        y_end = min(y + tile_height, height)
        out_y = round(y * scale)
        for x in col_starts:
            x_end = min(x + tile_width, width)
            processed_tile = function(samples[:, :, y:y_end, x:x_end]).to(output_device)

            if out is None:
                out = torch.zeros(
                    (*processed_tile.shape[:2], out_height, out_width),
                    device=output_device,
                )
                out_div = torch.zeros_like(out)

            tile_shape = tuple(processed_tile.shape)
            mask = masks.get(tile_shape)
            if mask is None:
                mask = masks[tile_shape] = create_gradient_mask(
                    tile_shape, feather, device=output_device
                )

            out_x = round(x * scale)
            out_h, out_w = tile_shape[2:]
            region = (..., slice(out_y, out_y + out_h), slice(out_x, out_x + out_w))
            out[region] += processed_tile * mask
            out_div[region] += mask

    if out is None:
        # Zero-height or zero-width input: no tile was produced, so there is nothing to blend.
        return torch.zeros((batch, channels, out_height, out_width), device=output_device)

    # Normalize by the accumulated blending weights. This assumes every output pixel was
    # covered by some tile, i.e. function returns tiles of size tile * scale.
    return out / out_div