File size: 4,113 Bytes
241b6a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from pathlib import Path
from typing import Optional

import numpy as np
import rasterio as rio
from rasterio.profiles import Profile

from .model_utils import channel_norm


def get_patch(
    input_array: np.ndarray,
    index: tuple,
    no_data_value: Optional[int] = 0,
) -> tuple[Optional[np.ndarray], Optional[tuple[int, int, int, int]]]:
    """Extract a patch from a 3D array and normalize it.

    The window described by ``index`` is cut out of the last two axes of
    ``input_array`` and converted to ``float32``. Windows that carry no usable
    data are skipped. Windows that contain zeros and have an all-zero edge are
    moved one row/column at a time away from that edge (as long as the opposite
    edge holds data and the array border is not reached) so that the patch
    covers fewer zero-filled pixels.

    Args:
        input_array: Array with exactly three dimensions; the window is taken
            from axes 1 and 2, axis 0 is kept whole.
        index: ``(top, bottom, left, right)`` slice bounds of the window.
        no_data_value: Value marking missing data. It is forwarded to
            ``channel_norm``. When it is ``None`` only the all-zero check is
            applied.

    Returns:
        ``(None, None)`` when the window is empty, all zeros, or consists
        entirely of ``no_data_value``. Otherwise a tuple of the result of
        ``channel_norm(patch, no_data_value)`` and the ``(top, bottom, left,
        right)`` bounds of the patch that was actually read.

    Raises:
        AssertionError: If ``input_array`` does not have three dimensions.
    """
    assert input_array.ndim == 3, "Input array must have 3 dimensions"

    top, bottom, left, right = index
    patch = input_array[:, top:bottom, left:right].astype(np.float32)

    # Test for "nothing but zeros" element-wise. Comparing the sum with zero
    # would also reject patches whose positive and negative values cancel out.
    if not patch.any():
        return None, None

    # A patch made up solely of the no-data value has nothing to predict on.
    if no_data_value is not None and np.all(patch == no_data_value):
        return None, None

    if np.any(patch == 0):
        max_bottom, max_right = input_array.shape[1:3]

        # Shift vertically only if at least one of the two horizontal edges
        # holds data; this also guarantees the loops below terminate before
        # the temporary patch runs out of rows.
        if np.any(patch[:, 0, :]) or np.any(patch[:, -1, :]):
            while not np.any(patch[:, 0, :]) and bottom < max_bottom:  # check top row
                patch = patch[:, 1:, :]
                top += 1
                bottom += 1

            while not np.any(patch[:, -1, :]) and top > 0:  # check bottom row
                patch = patch[:, :-1, :]
                bottom -= 1
                top -= 1

        # Same procedure for the vertical edges.
        if np.any(patch[:, :, 0]) or np.any(patch[:, :, -1]):
            while not np.any(patch[:, :, 0]) and right < max_right:  # check left column
                patch = patch[:, :, 1:]
                left += 1
                right += 1

            while not np.any(patch[:, :, -1]) and left > 0:  # check right column
                patch = patch[:, :, :-1]
                right -= 1
                left -= 1

        # The loops above only trimmed a temporary view to find the new
        # position; read the full-size window again at the shifted bounds.
        patch = input_array[:, top:bottom, left:right].astype(np.float32)

    # Slicing silently cuts a window that runs past the array border, so derive
    # bottom and right from the shape of the patch that was actually read.
    index = (top, top + patch.shape[1], left, left + patch.shape[2])
    return channel_norm(patch, no_data_value), index


def mask_prediction(
    scene: np.ndarray, pred_tracker_np: np.ndarray, no_data_value: int = 0
) -> np.ndarray:
    """Zero out predictions at every pixel where the scene has no data.

    A pixel is kept only when none of the entries along the first axis of
    ``scene`` equals ``no_data_value`` at that position. ``pred_tracker_np`` is
    modified in place and the same array object is returned.
    """
    assert scene.ndim == 3, "Scene must have 3 dimensions"
    assert pred_tracker_np.ndim == 3, "Prediction tracker must have 3 dimensions"
    same_footprint = scene.shape[1:] == pred_tracker_np.shape[1:]
    assert same_footprint, "Scene and prediction tracker must have the same shape"

    # A pixel is invalid as soon as any band carries the no-data value.
    has_no_data = np.any(scene == no_data_value, axis=0)
    keep = np.logical_not(has_no_data).astype(np.uint8)

    pred_tracker_np *= keep
    return pred_tracker_np


def make_patch_indexes(
    array_width: int,
    array_height: int,
    patch_size: int = 1000,
    patch_overlap: int = 300,
) -> list[tuple[int, int, int, int]]:
    """Create a list of patch indexes for a given shape and patch size.

    Square windows of ``patch_size`` are laid out row by row with a stride of
    ``patch_size - patch_overlap``. The last window along each axis is aligned
    with the array border so that the whole array is covered without any
    window running past it. Each window appears exactly once.

    Args:
        array_width: Size of the array along the horizontal axis.
        array_height: Size of the array along the vertical axis.
        patch_size: Edge length of every window.
        patch_overlap: Number of pixels neighbouring windows share.

    Returns:
        ``(top, bottom, left, right)`` tuples, ordered by ``top`` and then by
        ``left``.

    Raises:
        AssertionError: If the overlap is negative or not smaller than the
            patch size, or if the patch does not fit into the array.
    """
    assert patch_size > patch_overlap, "Patch size must be greater than patch overlap"
    assert patch_overlap >= 0, "Patch overlap must be greater than or equal to 0"
    assert patch_size > 0, "Patch size must be greater than 0"
    assert (
        patch_size <= array_width
    ), "Patch size must be less than or equal to array width"
    assert (
        patch_size <= array_height
    ), "Patch size must be less than or equal to array height"

    stride = patch_size - patch_overlap

    tops = _patch_starts(array_height, patch_size, stride)
    lefts = _patch_starts(array_width, patch_size, stride)

    return [
        (top, top + patch_size, left, left + patch_size)
        for top in tops
        for left in lefts
    ]


def _patch_starts(length: int, patch_size: int, stride: int) -> list[int]:
    """Return the unique, ascending window start offsets along one axis."""
    last_start = length - patch_size
    starts = list(range(0, last_start + 1, stride))
    # Clamping every stride position beyond ``last_start`` would repeat the
    # border-aligned window; add it once, and only if the stride missed it.
    if not starts or starts[-1] != last_start:
        starts.append(last_start)
    return starts


def save_prediction(
    output_path: Path, export_profile: Profile, pred_tracker_np: np.ndarray
) -> None:
    """Write the prediction array to ``output_path`` with rasterio.

    The dataset is opened in write mode with ``export_profile`` unpacked as
    keyword arguments, and ``pred_tracker_np`` is passed to its ``write``
    method. The dataset is closed when the ``with`` block exits.
    """
    writer = rio.open(output_path, "w", **export_profile)
    with writer as raster_out:
        raster_out.write(pred_tracker_np)