File size: 4,958 Bytes
241b6a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import warnings
from pathlib import Path
from typing import Optional, Union

import numpy as np
import rasterio as rio
from rasterio.profiles import Profile


def load_s2(
    input_path: Union[Path, str],
    resolution: float = 10.0,
    required_bands: list[str] = ["B04", "B03", "B8A"],
) -> tuple[np.ndarray, Profile]:
    """Load a Sentinel-2 (L1C or L2A) image from a SAFE folder containing the bands"""
    if not 10 <= resolution <= 50:
        raise ValueError("Resolution must be between 10 and 50")
    input_path = Path(input_path)
    processing_level = find_s2_processing_level(input_path)
    return open_s2_bands(input_path, processing_level, resolution, required_bands)


def find_s2_processing_level(
    input_path: Path,
) -> str:
    """Derive the processing level of a Sentinel-2 image from the folder name."""

    folder_name = Path(input_path).name
    processing_level = folder_name.split("_")[1][3:6]

    if processing_level not in ["L1C", "L2A"]:
        raise ValueError(
            f"Processing level {processing_level} not recognized, expected L1C or L2A"
        )
    return processing_level


def open_s2_bands(
    input_path: Path,
    processing_level: str,
    resolution: float,
    required_bands: list[str],
) -> tuple[np.ndarray, Profile]:
    bands = []
    for band_name in required_bands:
        if processing_level == "L1C":
            try:
                band = list(input_path.rglob(f"*IMG_DATA/*{band_name}.jp2"))[0]

            except IndexError:
                raise ValueError(f"Band {band_name} not found in {input_path}")
        else:
            band = None
            for search_resolution in [10, 20, 60]:
                band_paths = list(
                    input_path.rglob(f"*{band_name}_{search_resolution}m.jp2")
                )
                if band_paths:
                    band = band_paths[0]
                    break
            if not band:
                raise ValueError(f"Band {band_name} not found in {input_path}")

        with rio.open(band) as src:
            profile = src.profile
            native_resolution = int(src.res[0])
            scale_factor = native_resolution / resolution
            if native_resolution == resolution:
                bands.append(src.read(1))
            else:
                bands.append(
                    src.read(
                        1,
                        out_shape=(
                            int(src.height * scale_factor),
                            int(src.width * scale_factor),
                        ),
                    )
                )
    profile["transform"] = rio.transform.from_origin(  # type: ignore
        profile["transform"][2],
        profile["transform"][5],
        resolution,
        resolution,
    )
    data = np.array(bands)
    profile["height"] = data.shape[1]
    profile["width"] = data.shape[2]
    return data, profile


def load_multiband(
    input_path: Union[Path, str],
    resample_res: Optional[float] = None,
    band_order: Optional[list[int]] = None,
) -> tuple[np.ndarray, Profile]:
    """Load a multiband image and resample it to requested resolution."""
    if band_order is None:
        warnings.warn(
            "No band order provided, using default [1, 2, 3] (RGN)", UserWarning
        )
        band_order = [1, 2, 3]
    input_path = Path(input_path)

    with rio.open(input_path) as src:
        if resample_res:
            current_res = src.res
            desired_res = (resample_res, resample_res)
            scale_factor = (
                current_res[0] / desired_res[0],
                current_res[1] / desired_res[1],
            )
        else:
            scale_factor = (1, 1)

        data = src.read(
            band_order,
            out_shape=(
                len(band_order),
                int(src.height * scale_factor[0]),
                int(src.width * scale_factor[1]),
            ),
            resampling=rio.enums.Resampling.nearest,  # type: ignore
        )
        profile = src.profile

        return data, profile


def load_ls8(
    input_path: Union[Path, str],
    resolution: int = 30,
    required_bands=["B4", "B3", "B5"],
) -> tuple[np.ndarray, Profile]:
    """Load a Landsat 8 image from a folder containing the bands"""
    if resolution != 30:
        raise ValueError("Resolution must be 30")

    input_path = Path(input_path)

    band_files = {}
    for band_name in required_bands:
        try:
            band = list(input_path.rglob(f"*{band_name}.TIF"))[0]

        except IndexError:
            raise ValueError(f"Band {band_name} not found in {input_path}")
        band_files[band_name] = band

    data = []
    profile = Profile()
    for band_name in required_bands:
        with rio.open(band_files[band_name]) as src:
            if not profile:
                profile = src.profile
            data.append(src.read(1))

    data = np.array(data)
    return data, profile