import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Thread
from typing import Callable, Generator, Optional, Union

import numpy as np
import torch
from rasterio.profiles import Profile
from tqdm.auto import tqdm

from .__version__ import __version__
from .download_models import get_models
from .model_utils import (
    create_gradient_mask,
    default_device,
    get_torch_dtype,
    inference_and_store,
    load_model_from_weights,
)
from .raster_utils import (
    get_patch,
    make_patch_indexes,
    mask_prediction,
    save_prediction,
)


def compile_batches(
    batch_size: int,
    patch_size: int,
    patch_indexes: list[tuple[int, int, int, int]],
    input_array: np.ndarray,
    no_data_value: int,
    inference_device: torch.device,
    inference_dtype: torch.dtype,
) -> Generator[tuple[torch.Tensor, list[tuple[int, int, int, int]]], None, None]:
    """Used to compile batches of patches from the input array and return them as a generator."""

    with ThreadPoolExecutor(max_workers=batch_size) as executor:
        futures = [
            executor.submit(get_patch, input_array, index, no_data_value)
            for index in patch_indexes
        ]

        total_futures = len(futures)
        all_indexes = set()
        index_batch = []
        patch_batch_array = np.zeros(
            (batch_size, input_array.shape[0], patch_size, patch_size), dtype=np.float32
        )

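        # consume patches as their worker threads complete (order is not deterministic),
        # packing them into the batch buffer and skipping any duplicate indexes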
        for index, future in enumerate(as_completed(futures)):
            patch, new_index = future.result()

            if patch is not None and new_index not in all_indexes:
                index_batch.append(new_index)
                patch_batch_array[len(index_batch) - 1] = patch
                all_indexes.add(new_index)

            if len(index_batch) == batch_size or index == total_futures - 1:
                if len(index_batch) == 0:
                    continue
                input_tensor = (
                    torch.tensor(patch_batch_array[: len(index_batch)])
                    .to(inference_device)
                    .to(inference_dtype)
                )
                yield input_tensor, index_batch
                index_batch = []


def run_models_on_array(
    models: list[torch.nn.Module],
    input_array: np.ndarray,
    pred_tracker: torch.Tensor,
    grad_tracker: Union[torch.Tensor, None],
    patch_size: int,
    patch_overlap: int,
    inference_device: torch.device,
    batch_size: int = 2,
    inference_dtype: torch.dtype = torch.float32,
    no_data_value: int = 0,
) -> None:
    """Used to execute the model on the input array, in patches. Predictions are stored in pred_tracker and grad_tracker, updated in place."""
    patch_indexes = make_patch_indexes(
        array_height=input_array.shape[1],
        array_width=input_array.shape[2],
        patch_size=patch_size,
        patch_overlap=patch_overlap,
    )

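    # weighting mask used to down-weight patch edges so overlapping predictions blend smoothly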
    gradient = create_gradient_mask(
        patch_size, patch_overlap, device=inference_device, dtype=inference_dtype
    )

    input_tensor_gen = compile_batches(
        batch_size=batch_size,
        patch_size=patch_size,
        patch_indexes=patch_indexes,
        input_array=input_array,
        no_data_value=no_data_value,
        inference_device=inference_device,
        inference_dtype=inference_dtype,
    )

    for patch_batch, index_batch in input_tensor_gen:
        inference_and_store(
            models=models,
            patch_batch=patch_batch,
            index_batch=index_batch,
            pred_tracker=pred_tracker,
            gradient=gradient,
            grad_tracker=grad_tracker,
        )


def check_patch_size(
    input_array: np.ndarray, no_data_value: int, patch_size: int, patch_overlap: int
) -> tuple[int, int]:
    """Used to check the inputs and adjust the patch size and overlap if necessary."""
    # check the shape of the input array
    if len(input_array.shape) != 3:
        raise ValueError(
            f"Input array must have 3 dimensions, found {len(input_array.shape)}. The input should be in format (bands (red,green,NIR), height, width)."
        )

    # check the width and height are greater than 10 pixels
    if min(input_array.shape[1], input_array.shape[2]) < 10:
        raise ValueError(
            f"Input array must have a width and height greater than 10 pixels, found shape {input_array.shape}. The input should be in format (bands (red,green,NIR), height, width)."
        )
    if min(input_array.shape[1], input_array.shape[2]) < 50:
        warnings.warn(
            f"Input width or height is less than 50 pixels, found shape {input_array.shape}. Such a small image may not provide adequate spatial context for the model."
        )

    # if the input has a lot of no data values and the patch size is larger than half the image size, we reduce the patch size and overlap
    if np.count_nonzero(input_array == no_data_value) / input_array.size > 0.3:
        if patch_size > min(input_array.shape[1], input_array.shape[2]) / 2:
            patch_size = min(input_array.shape[1], input_array.shape[2]) // 2
            if patch_size // 2 < patch_overlap:
                patch_overlap = patch_size // 2

            warnings.warn(
                f"Significant no-data areas detected. Adjusting patch size to {patch_size}px and overlap to {patch_overlap}px to minimize no-data patches."
            )

    # if the patch size is larger than the image size, we reduce the patch size and overlap
    if patch_size > min(input_array.shape[1], input_array.shape[2]):
        patch_size = min(input_array.shape[1], input_array.shape[2])
        if patch_size // 2 < patch_overlap:
            patch_overlap = patch_size // 2
        warnings.warn(
            f"Patch size too large, reducing to {patch_size} and overlap to {patch_overlap}."
        )

    # if the patch overlap is larger than the patch size, raise an error
    if patch_overlap >= patch_size:
        raise ValueError(
            f"Patch overlap {patch_overlap}px must be less than patch size {patch_size}px."
        )
    return patch_overlap, patch_size


def coordinator(
    input_array: np.ndarray,
    models: list[torch.nn.Module],
    inference_dtype: torch.dtype,
    export_confidence: bool,
    softmax_output: bool,
    inference_device: torch.device,
    mosaic_device: torch.device,
    patch_size: int,
    patch_overlap: int,
    batch_size: int,
    profile: Profile = Profile(),
    output_path: Path = Path(""),
    no_data_value: int = 0,
    pbar: Optional[tqdm] = None,
    apply_no_data_mask: bool = False,
    export_to_disk: bool = True,
    save_executor: Optional[ThreadPoolExecutor] = None,
    pred_classes: int = 4,
) -> np.ndarray:
    """Used to coordinate the process of predicting from an input array."""

    patch_overlap, patch_size = check_patch_size(
        input_array, no_data_value, patch_size, patch_overlap
    )

    pred_tracker = torch.zeros(
        (pred_classes, *input_array.shape[1:3]),
        dtype=inference_dtype,
        device=mosaic_device,
    )

    grad_tracker = (
        torch.zeros(input_array.shape[1:3], dtype=inference_dtype, device=mosaic_device)
        if export_confidence
        else None
    )

    run_models_on_array(
        models=models,
        input_array=input_array,
        pred_tracker=pred_tracker,
        grad_tracker=grad_tracker,
        inference_device=inference_device,
        inference_dtype=inference_dtype,
        no_data_value=no_data_value,
        patch_size=patch_size,
        patch_overlap=patch_overlap,
        batch_size=batch_size,
    )

    if export_confidence:
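        # normalise the accumulated scores by the accumulated patch weights so
        # areas covered by overlapping patches are not over-counted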
        pred_tracker_norm = pred_tracker / grad_tracker
        if softmax_output:
            pred_tracker = torch.clip(
                (torch.nn.functional.softmax(pred_tracker_norm, 0) + 0.001),
                0.001,
                0.999,
            )
        else:
            pred_tracker = pred_tracker_norm

        pred_tracker_np = pred_tracker.float().numpy(force=True)

    else:
        pred_tracker_np = (
            torch.argmax(pred_tracker, 0, keepdim=True)
            .numpy(force=True)
            .astype(np.uint8)
        )

    if apply_no_data_mask:
        pred_tracker_np = mask_prediction(input_array, pred_tracker_np, no_data_value)

    if export_to_disk:
        export_profile = profile.copy()
        export_profile.update(
            dtype=pred_tracker_np.dtype,
            count=pred_tracker_np.shape[0],
            compress="lzw",
            nodata=0,
            driver="GTiff",
        )
        # if an executor has been passed, submit the save_prediction function to it to avoid blocking the main thread
        if save_executor:
            save_executor.submit(
                save_prediction, output_path, export_profile, pred_tracker_np
            )
        # otherwise, save the prediction directly
        else:
            save_prediction(output_path, export_profile, pred_tracker_np)

    if pbar:
        pbar.update(1)
    return pred_tracker_np


def collect_models(
    custom_models: Union[list[torch.nn.Module], torch.nn.Module],
    inference_device: torch.device,
    inference_dtype: torch.dtype,
    source: str,
    destination_model_dir: Union[str, Path, None] = None,
) -> list[torch.nn.Module]:
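    """Return the default pretrained models, or move user-supplied custom models to the requested device and dtype."""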
    if not custom_models:
        models = []
        for model_details in get_models(model_dir=destination_model_dir, source=source):
            models.append(
                load_model_from_weights(
                    model_name=model_details["timm_model_name"],
                    weights_path=model_details["Path"],
                    device=inference_device,
                    dtype=inference_dtype,
                )
            )
    else:
        # if not a list, make it a list of models
        if not isinstance(custom_models, list):
            custom_models = [custom_models]

        models = [
            model.to(inference_dtype).to(inference_device) for model in custom_models
        ]
    return models


def predict_from_array(
    input_array: np.ndarray,
    patch_size: int = 1000,
    patch_overlap: int = 300,
    batch_size: int = 1,
    inference_device: Union[str, torch.device] = default_device(),
    mosaic_device: Optional[Union[str, torch.device]] = None,
    inference_dtype: Union[torch.dtype, str] = torch.float32,
    export_confidence: bool = False,
    softmax_output: bool = True,
    no_data_value: int = 0,
    apply_no_data_mask: bool = True,
    custom_models: Union[list[torch.nn.Module], torch.nn.Module] = [],
    pred_classes: int = 4,
    destination_model_dir: Union[str, Path, None] = None,
    model_download_source: str = "google_drive",
) -> np.ndarray:
    """Predict a cloud and cloud shadow mask from a Red, Green and NIR numpy array, with a spatial res between 10 m and 50 m.

    Args:
        input_array (np.ndarray): A numpy array with shape (3, height, width) representing the Red, Green and NIR bands.
        patch_size (int, optional): Size of the patches for inference. Defaults to 1000.
        patch_overlap (int, optional): Overlap between patches for inference. Defaults to 300.
        batch_size (int, optional): Number of patches to process in a batch. Defaults to 1.
        inference_device (Union[str, torch.device], optional): Device to use for inference (e.g., 'cpu', 'cuda', 'mps'). Defaults to the device returned by default_device().
        mosaic_device (Union[str, torch.device], optional): Device to use for mosaicking patches. Defaults to inference device.
        inference_dtype (Union[torch.dtype, str], optional): Data type for inference. Defaults to torch.float32.
        export_confidence (bool, optional): If True, exports confidence maps instead of predicted classes. Defaults to False.
        softmax_output (bool, optional): If True, applies a softmax to the output, only used if export_confidence = True. Defaults to True.
        no_data_value (int, optional): Value within input scenes that specifies no data region. Defaults to 0.
        apply_no_data_mask (bool, optional): If True, applies a no-data mask to the predictions. Defaults to True.
        custom_models (Union[list[torch.nn.Module], torch.nn.Module], optional): A list of custom torch models, or a single model, to use for prediction. Defaults to [].
        pred_classes (int, optional): Number of classes to predict. Defaults to 4; intended for use with custom models.
        destination_model_dir (Union[str, Path, None], optional): Directory in which to save the downloaded model weights. Defaults to None.
        model_download_source (str, optional): Source from which to download the model weights, either "google_drive" or "hugging_face". Defaults to "google_drive".
    Returns:
        np.ndarray: A numpy array with shape (1, height, width), or (pred_classes, height, width) if export_confidence = True, representing the predicted cloud and cloud shadow mask.

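    Example:
        Minimal illustrative sketch (assumes the package is importable as
        "omnicloudmask"; the input is random placeholder data rather than real imagery):

            import numpy as np
            from omnicloudmask import predict_from_array

            # fake Red, Green and NIR bands in shape (3, height, width)
            scene = np.random.rand(3, 1024, 1024).astype(np.float32)

            mask = predict_from_array(scene, batch_size=2)
            # mask.shape == (1, 1024, 1024): one predicted class value per pixel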
    """

    inference_device = torch.device(inference_device)
    if mosaic_device is None:
        mosaic_device = inference_device
    else:
        mosaic_device = torch.device(mosaic_device)

    inference_dtype = get_torch_dtype(inference_dtype)
    # if no custom models are provided, download and load the default models
    models = collect_models(
        custom_models=custom_models,
        inference_device=inference_device,
        inference_dtype=inference_dtype,
        source=model_download_source,
        destination_model_dir=destination_model_dir,
    )

    pred_tracker = coordinator(
        input_array=input_array,
        models=models,
        inference_device=inference_device,
        mosaic_device=mosaic_device,
        inference_dtype=inference_dtype,
        export_confidence=export_confidence,
        softmax_output=softmax_output,
        patch_size=patch_size,
        patch_overlap=patch_overlap,
        batch_size=batch_size,
        no_data_value=no_data_value,
        export_to_disk=False,
        apply_no_data_mask=apply_no_data_mask,
        pred_classes=pred_classes,
    )

    return pred_tracker


def predict_from_load_func(
    scene_paths: Union[list[Path], list[str]],
    load_func: Callable,
    patch_size: int = 1000,
    patch_overlap: int = 300,
    batch_size: int = 1,
    inference_device: Union[str, torch.device] = default_device(),
    mosaic_device: Optional[Union[str, torch.device]] = None,
    inference_dtype: Union[torch.dtype, str] = torch.float32,
    export_confidence: bool = False,
    softmax_output: bool = True,
    no_data_value: int = 0,
    overwrite: bool = True,
    apply_no_data_mask: bool = True,
    output_dir: Optional[Union[Path, str]] = None,
    custom_models: Union[list[torch.nn.Module], torch.nn.Module] = [],
    destination_model_dir: Union[str, Path, None] = None,
    model_download_source: str = "google_drive",
) -> list[Path]:
    """
    Predicts cloud and cloud shadow masks for a list of scenes using a specified loading function.

    Args:
        scene_paths (Union[list[Path], list[str]]): A list of paths to the scene files to be processed.
        load_func (Callable): A function to load the scene data. It should take an input_path parameter and return an (R, G, NIR) numpy array together with a rasterio profile for export; several load functions are provided in data_loaders.py.
        patch_size (int, optional): Size of the patches for inference. Defaults to 1000.
        patch_overlap (int, optional): Overlap between patches for inference. Defaults to 300.
        batch_size (int, optional): Number of patches to process in a batch. Defaults to 1.
        inference_device (Union[str, torch.device], optional): Device to use for inference (e.g., 'cpu', 'cuda', 'mps'). Defaults to the device returned by default_device().
        mosaic_device (Union[str, torch.device], optional): Device to use for mosaicking patches. Defaults to inference device.
        inference_dtype (Union[torch.dtype, str], optional): Data type for inference. Defaults to torch.float32.
        export_confidence (bool, optional): If True, exports confidence maps instead of predicted classes. Defaults to False.
        softmax_output (bool, optional): If True, applies a softmax to the output, only used if export_confidence = True. Defaults to True.
        no_data_value (int, optional): Value within input scenes that specifies no data region. Defaults to 0.
        overwrite (bool, optional): If False, skips scenes that already have a prediction file. Defaults to True.
        apply_no_data_mask (bool, optional): If True, applies a no-data mask to the predictions. Defaults to True.
        output_dir (Optional[Union[Path, str]], optional): Directory to save the prediction files. Defaults to None. If None, the predictions will be saved in the same directory as the input scene.
        custom_models (Union[list[torch.nn.Module], torch.nn.Module], optional): A list of custom torch models, or a single model, to use for prediction. Defaults to [].
        destination_model_dir (Union[str, Path, None], optional): Directory in which to save the downloaded model weights. Defaults to None.
        model_download_source (str, optional): Source from which to download the model weights, either "google_drive" or "hugging_face". Defaults to "google_drive".

    Returns:
        list[Path]: A list of paths to the output prediction files.

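    Example:
        Illustrative sketch (assumes the package is importable as "omnicloudmask"
        and that a Sentinel-2 loader named load_s2 is importable from it, per the
        loaders in data_loaders.py; paths are placeholders):

            from pathlib import Path
            from omnicloudmask import predict_from_load_func, load_s2

            scene_paths = list(Path("path/to/scenes").glob("*.SAFE"))
            pred_paths = predict_from_load_func(scene_paths, load_s2)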
    """
    pred_paths = []
    inf_thread = Thread()
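    # dummy thread so the first is_alive() check is valid; the single-worker
    # executor writes finished predictions to disk without blocking inference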
    save_executor = ThreadPoolExecutor(max_workers=1)

    inference_device = torch.device(inference_device)
    if mosaic_device is None:
        mosaic_device = inference_device
    else:
        mosaic_device = torch.device(mosaic_device)

    inference_dtype = get_torch_dtype(inference_dtype)

    models = collect_models(
        custom_models=custom_models,
        inference_device=inference_device,
        inference_dtype=inference_dtype,
        destination_model_dir=destination_model_dir,
        source=model_download_source,
    )

    pbar = tqdm(
        total=len(scene_paths),
        desc=f"Running inference using {inference_device.type} {str(inference_dtype).split('.')[-1]}",
    )

    for scene_path in scene_paths:
        scene_path = Path(scene_path)
        file_name = f"{scene_path.stem}_OCM_v{__version__.replace('.','_')}.tif"

        if output_dir is None:
            output_path = scene_path.parent / file_name
        else:
            Path(output_dir).mkdir(parents=True, exist_ok=True)
            output_path = Path(output_dir) / file_name

        pred_paths.append(output_path)

        if output_path.exists() and not overwrite:
            pbar.update(1)
            pbar.refresh()
            continue

        input_array, profile = load_func(input_path=scene_path)

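        # wait for the previous scene's inference to finish; loading the next
        # scene (above) overlaps with that inference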
        while inf_thread.is_alive():
            inf_thread.join()

        inf_thread = Thread(
            target=coordinator,
            kwargs={
                "input_array": input_array,
                "profile": profile,
                "output_path": output_path,
                "models": models,
                "inference_dtype": inference_dtype,
                "export_confidence": export_confidence,
                "softmax_output": softmax_output,
                "inference_device": inference_device,
                "mosaic_device": mosaic_device,
                "patch_size": patch_size,
                "patch_overlap": patch_overlap,
                "batch_size": batch_size,
                "no_data_value": no_data_value,
                "pbar": pbar,
                "apply_no_data_mask": apply_no_data_mask,
                "save_executor": save_executor,
            },
        )
        inf_thread.start()

    while inf_thread.is_alive():
        inf_thread.join()

    if inference_device.type.startswith("cuda"):
        torch.cuda.empty_cache()

    save_executor.shutdown(wait=True)
    pbar.refresh()

    return pred_paths