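"""Utility helpers for patch-based (tiled) segmentation inference.

Covers dtype parsing, gradient masks for blending overlapping patches,
per-band normalization with nodata handling, ensembled model inference,
and loading fastai DynamicUnet models built on timm backbones.
"""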
from functools import partial
from pathlib import Path
from typing import Optional, Union

import numpy as np
import timm
import torch
from fastai.vision.learner import create_unet_model


def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
    """Return a torch.dtype from a string or torch.dtype."""
    if isinstance(dtype, str):
        dtype_mapping = {
            "float16": torch.float16,
            "half": torch.float16,
            "fp16": torch.float16,
            "float32": torch.float32,
            "float": torch.float32,
            "bfloat16": torch.bfloat16,
            "bf16": torch.bfloat16,
        }
        try:
            return dtype_mapping[dtype.lower()]
        except KeyError:
            # Suppress the KeyError chain; the ValueError message is sufficient
            raise ValueError(
                f"Invalid dtype: {dtype}. Must be one of {list(dtype_mapping.keys())}"
            ) from None
    elif isinstance(dtype, torch.dtype):
        return dtype
    else:
        raise TypeError(
            f"Expected dtype to be a str or torch.dtype, but got {type(dtype)}"
        )


def create_gradient_mask(
    patch_size: int, patch_overlap: int, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    """Create a gradient mask for a given patch size and overlap.

    The mask down-weights pixels near the patch edges so that overlapping
    patches blend smoothly when their predictions are accumulated.
    """
    if patch_overlap > 0:
        # Clamp the overlap so the two edge ramps never cross in the middle
        if patch_overlap * 2 > patch_size:
            patch_overlap = patch_size // 2

        # gradient_strength scales how strongly the mask attenuates edges;
        # 1 means full attenuation, 0 would make the mask uniform.
        gradient_strength = 1
        gradient = (
            torch.ones((patch_size, patch_size), dtype=torch.int, device=device)
            * patch_overlap
        )
        # Ramp up from 1 at the left edge and back down to 1 at the right edge
        gradient[:, :patch_overlap] = torch.tile(
            torch.arange(1, patch_overlap + 1, device=device),
            (patch_size, 1),
        )
        gradient[:, -patch_overlap:] = torch.tile(
            torch.arange(patch_overlap, 0, -1, device=device),
            (patch_size, 1),
        )
        gradient = gradient / patch_overlap
        # Multiplying the horizontal ramp by its 90-degree rotation yields a
        # 2-D mask that falls off toward all four edges
        rotated_gradient = torch.rot90(gradient)
        combined_gradient = rotated_gradient * gradient

        combined_gradient = (combined_gradient * gradient_strength) + (
            1 - gradient_strength
        )
    else:
        combined_gradient = torch.ones(
            (patch_size, patch_size), dtype=torch.int, device=device
        )
    return combined_gradient.to(dtype)
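
# Worked example (illustrative only): for patch_size=4 and patch_overlap=2 the
# per-row ramp is [1, 2, 2, 1] / 2 = [0.5, 1.0, 1.0, 0.5], and the combined
# mask is its outer product with itself:
#   [[0.25, 0.50, 0.50, 0.25],
#    [0.50, 1.00, 1.00, 0.50],
#    [0.50, 1.00, 1.00, 0.50],
#    [0.25, 0.50, 0.50, 0.25]]
# so pixels near patch edges contribute less where neighboring patches overlap.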


def channel_norm(patch: np.ndarray, nodata_value: int = 0) -> np.ndarray:
    """Normalize each band of the input array by subtracting the mean and dividing
    by the standard deviation of its non-nodata pixels; nodata pixels are left as 0."""
    out_array = np.zeros(patch.shape, dtype=np.float32)
    for band_idx, band in enumerate(patch):
        # Mask selecting valid (non-nodata) pixels
        mask = band != nodata_value
        # Skip bands that contain no valid pixels
        if np.any(mask):
            mean = band[mask].mean()
            std = band[mask].std()
            if std == 0:
                std = 1  # Prevent division by zero on constant bands
            # Normalize only the valid pixels; nodata positions stay 0
            # because out_array was initialized with zeros
            out_array[band_idx][mask] = (band[mask] - mean) / std
    return out_array
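
# Illustration: for a band [0, 2, 4] with nodata_value=0, the valid pixels
# [2, 4] have mean 3 and (population) std 1, so the band normalizes to
# [0, -1, 1] -- the nodata pixel stays 0.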


def store_results(
    pred_batch: torch.Tensor,
    index_batch: list[tuple[int, int, int, int]],
    pred_tracker: torch.Tensor,
    gradient: torch.Tensor,
    grad_tracker: Optional[torch.Tensor] = None,
) -> None:
    """Accumulate gradient-weighted predictions into pred_tracker, and the gradient
    itself into grad_tracker, at the (row_start, row_stop, col_start, col_stop)
    window given by each entry of index_batch. Dividing pred_tracker by
    grad_tracker afterwards recovers a weighted average where patches overlap."""
    assert pred_batch.ndim == 4, "pred_batch must have 4 dimensions, (B, class, H, W)"
    assert pred_batch.shape[0] == len(index_batch), "Batch size must match index_batch"
    assert pred_batch.shape[1] == pred_tracker.shape[0], "Number of classes must match"
    assert pred_batch.shape[2] == gradient.shape[0], "Height must match gradient"
    assert pred_batch.shape[3] == gradient.shape[1], "Width must match gradient"

    # Edge-weight every prediction in place before accumulating
    pred_batch *= gradient[None, None, :, :]

    for pred, index in zip(pred_batch.to(pred_tracker.device), index_batch):
        pred_tracker[:, index[0] : index[1], index[2] : index[3]] += pred
        if grad_tracker is not None:
            grad_tracker[index[0] : index[1], index[2] : index[3]] += gradient.to(
                grad_tracker.device
            )


def inference_and_store(
    models: list[torch.nn.Module],
    patch_batch: torch.Tensor,
    index_batch: list[tuple],
    pred_tracker: torch.Tensor,
    gradient: torch.Tensor,
    grad_tracker: Optional[torch.Tensor] = None,
) -> None:
    """Perform inference on the patch_batch and store the results in the pred_tracker and grad_tracker tensors."""
    # pre-initialize the all_preds tensor to store the predictions from each model
    all_preds = torch.zeros(
        len(models),
        patch_batch.shape[0],
        pred_tracker.shape[0],
        patch_batch.shape[2],
        patch_batch.shape[3],
        device=patch_batch.device,
        dtype=patch_batch.dtype,
    )
    # Run every ensemble member under a single no_grad context
    with torch.no_grad():
        for index, model in enumerate(models):
            all_preds[index] = model(patch_batch)

    # The ensemble prediction is the mean across models
    mean_preds = all_preds.mean(dim=0)

    store_results(
        pred_batch=mean_preds,
        index_batch=index_batch,
        pred_tracker=pred_tracker,
        gradient=gradient,
        grad_tracker=grad_tracker,
    )


def default_device() -> torch.device:
    """Return the default device for model inference"""
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def load_model(
    model_path: Union[Path, str],
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> torch.nn.Module:
    """Load a PyTorch model from a file and move it to the specified device and dtype."""
    model_path = Path(model_path)
    if not model_path.is_file():
        raise FileNotFoundError(f"Model file not found at: {model_path}")

    try:
        # Full pickled model objects require weights_only=False on
        # PyTorch >= 2.6, where weights_only began defaulting to True.
        # Only load checkpoints from sources you trust.
        model = torch.load(model_path, map_location="cpu", weights_only=False)
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e

    model.eval()
    return model.to(dtype).to(device)


def load_model_from_weights(
    model_name: str,
    weights_path: Union[Path, str],
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    in_chans: int = 3,
    n_out: int = 4,
) -> torch.nn.Module:
    """Build Fastai DynamicUnet model from timm model and load weights from file"""
    timm_model = partial(
        timm.create_model,
        model_name=model_name,
        pretrained=False,
        in_chans=in_chans,
    )

    model = create_unet_model(
        arch=timm_model,
        n_out=n_out,
        img_size=(509, 509),
        act_cls=torch.nn.Mish,
        pretrained=False,
    )

    # Load onto CPU first so checkpoints saved on GPU still load anywhere
    model.load_state_dict(
        torch.load(weights_path, map_location="cpu", weights_only=True)
    )
    model.eval()

    return model.to(dtype).to(device)
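

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the library API).
    # A 1x1 convolution stands in for a real segmentation model so the
    # tiling/accumulation path can be exercised without model files on disk.
    device = default_device()
    patch_size, overlap, n_classes = 8, 2, 4
    gradient = create_gradient_mask(patch_size, overlap, device, torch.float32)

    dummy_model = torch.nn.Conv2d(3, n_classes, kernel_size=1).to(device).eval()
    patches = torch.randn(2, 3, patch_size, patch_size, device=device)

    pred_tracker = torch.zeros(n_classes, 16, 16, device=device)
    grad_tracker = torch.zeros(16, 16, device=device)
    # Two overlapping placements, each as (row_start, row_stop, col_start, col_stop)
    indices = [(0, 8, 0, 8), (0, 8, 6, 14)]

    inference_and_store(
        models=[dummy_model],
        patch_batch=patches,
        index_batch=indices,
        pred_tracker=pred_tracker,
        gradient=gradient,
        grad_tracker=grad_tracker,
    )

    # Where placements overlap, dividing the accumulated predictions by the
    # accumulated gradient weights yields a blended (weighted-mean) prediction.
    blended = pred_tracker / grad_tracker.clamp(min=1e-8)
    print("blended prediction shape:", blended.shape)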