|
from functools import partial
from pathlib import Path
from typing import Optional, Union

import numpy as np
import timm
import torch
from fastai.vision.learner import create_unet_model
|
|
def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
    """Return a torch.dtype from a string alias or a torch.dtype."""
    if isinstance(dtype, str):
        dtype_mapping = {
            "float16": torch.float16,
            "half": torch.float16,
            "fp16": torch.float16,
            "float32": torch.float32,
            "float": torch.float32,
            "bfloat16": torch.bfloat16,
            "bf16": torch.bfloat16,
        }
        try:
            return dtype_mapping[dtype.lower()]
        except KeyError:
            raise ValueError(
                f"Invalid dtype: {dtype}. Must be one of {list(dtype_mapping.keys())}"
            ) from None
    elif isinstance(dtype, torch.dtype):
        return dtype
    else:
        raise TypeError(
            f"Expected dtype to be a str or torch.dtype, but got {type(dtype)}"
        )
|
|
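# Illustrative usage (not part of the original module): string aliases are
# matched case-insensitively, and torch.dtype values pass straight through.
#
#     get_torch_dtype("fp16")         # -> torch.float16
#     get_torch_dtype("BFLOAT16")     # -> torch.bfloat16
#     get_torch_dtype(torch.float64)  # -> torch.float64 (returned unchanged)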
|
|
def create_gradient_mask(
    patch_size: int, patch_overlap: int, device: torch.device, dtype: torch.dtype
) -> torch.Tensor:
    """Create a gradient mask for a given patch size and overlap.

    The mask ramps linearly from the patch edges towards the interior so that
    overlapping patches can be blended by weighted summation."""
    if patch_overlap > 0:
        # An overlap of more than half the patch would make the two ramps
        # collide, so clamp it to half the patch size.
        if patch_overlap * 2 > patch_size:
            patch_overlap = patch_size // 2

        gradient_strength = 1
        gradient = (
            torch.ones((patch_size, patch_size), dtype=torch.int, device=device)
            * patch_overlap
        )
        # Linear ramps along the left and right edges.
        gradient[:, :patch_overlap] = torch.tile(
            torch.arange(1, patch_overlap + 1, device=device),
            (patch_size, 1),
        )
        gradient[:, -patch_overlap:] = torch.tile(
            torch.arange(patch_overlap, 0, -1, device=device),
            (patch_size, 1),
        )
        gradient = gradient / patch_overlap
        # Rotating by 90 degrees gives the top/bottom ramps; the product
        # attenuates in both directions at once.
        rotated_gradient = torch.rot90(gradient)
        combined_gradient = rotated_gradient * gradient

        # gradient_strength of 1 keeps the full ramp; smaller values would
        # flatten the mask towards a constant.
        combined_gradient = (combined_gradient * gradient_strength) + (
            1 - gradient_strength
        )
    else:
        combined_gradient = torch.ones(
            (patch_size, patch_size), dtype=torch.int, device=device
        )
    return combined_gradient.to(dtype)
|
|
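# Worked example (values computed by hand for intuition, not original code):
# with patch_size=4 and patch_overlap=2 the row weights [0.5, 1, 1, 0.5]
# multiply the identical column weights, giving
#
#     [[0.25, 0.50, 0.50, 0.25],
#      [0.50, 1.00, 1.00, 0.50],
#      [0.50, 1.00, 1.00, 0.50],
#      [0.25, 0.50, 0.50, 0.25]]
#
# Interior pixels get full weight; edges and corners are attenuated so that
# overlapping patches sum to a smooth blend.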
|
|
def channel_norm(patch: np.ndarray, nodata_value: int = 0) -> np.ndarray:
    """Normalize each band of the input array by subtracting the mean and
    dividing by the standard deviation of its valid (non-nodata) pixels, then
    fill nodata pixels with 0."""
    out_array = np.zeros(patch.shape, dtype=np.float32)
    for band_index, band in enumerate(patch):
        mask = band != nodata_value
        if np.any(mask):
            mean = band[mask].mean()
            std = band[mask].std()
            if std == 0:
                # Constant band: avoid division by zero.
                std = 1
            out_array[band_index][mask] = (band[mask] - mean) / std
        # Nodata pixels keep the initialised value of 0.
    return out_array
|
|
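# Illustrative example (not part of the original module): one band with a
# nodata pixel. Statistics come from the valid pixels {2, 4, 6} only
# (mean = 4, population std = sqrt(8/3) ~ 1.633):
#
#     patch = np.array([[[0, 2], [4, 6]]], dtype=np.float32)  # (band, H, W)
#     channel_norm(patch, nodata_value=0)
#     # -> approximately [[[0.0, -1.225], [0.0, 1.225]]]
#     # the top-left 0 is nodata; the 4 normalises to 0 because it equals the mean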
|
|
def store_results(
    pred_batch: torch.Tensor,
    index_batch: list[tuple],
    pred_tracker: torch.Tensor,
    gradient: torch.Tensor,
    grad_tracker: Optional[torch.Tensor] = None,
) -> None:
    """Store the results of the model inference in the pred_tracker and
    grad_tracker tensors."""
    assert pred_batch.ndim == 4, "pred_batch must have 4 dimensions, (B, class, H, W)"
    assert pred_batch.shape[0] == len(index_batch), "Batch size must match index_batch"
    assert pred_batch.shape[1] == pred_tracker.shape[0], "Number of classes must match"
    assert pred_batch.shape[2] == gradient.shape[0], "Height must match gradient"
    assert pred_batch.shape[3] == gradient.shape[1], "Width must match gradient"

    # Down-weight patch edges so overlapping patches blend smoothly.
    pred_batch *= gradient[None, None, :, :]

    for pred, index in zip(pred_batch.to(pred_tracker.device), index_batch):
        pred_tracker[:, index[0] : index[1], index[2] : index[3]] += pred
        if grad_tracker is not None:
            # Track the summed weights so predictions can be renormalised later.
            grad_tracker[index[0] : index[1], index[2] : index[3]] += gradient.to(
                grad_tracker.device
            )
|
|
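# Illustrative call (not part of the original module). From the slicing above,
# each index tuple is assumed to be (row_start, row_stop, col_start, col_stop)
# in the full-scene pixel grid:
#
#     store_results(
#         pred_batch=preds,           # (B, n_classes, patch, patch)
#         index_batch=[(0, 512, 0, 512), (0, 512, 448, 960)],
#         pred_tracker=pred_tracker,  # (n_classes, H, W) weighted-sum accumulator
#         gradient=gradient,          # (patch, patch) blending mask
#         grad_tracker=grad_tracker,  # (H, W) sum of blending weights
#     )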
|
|
def inference_and_store(
    models: list[torch.nn.Module],
    patch_batch: torch.Tensor,
    index_batch: list[tuple],
    pred_tracker: torch.Tensor,
    gradient: torch.Tensor,
    grad_tracker: Optional[torch.Tensor] = None,
) -> None:
    """Perform inference on the patch_batch and store the results in the
    pred_tracker and grad_tracker tensors."""
    # One prediction per ensemble member: (model, B, class, H, W).
    all_preds = torch.zeros(
        len(models),
        patch_batch.shape[0],
        pred_tracker.shape[0],
        patch_batch.shape[2],
        patch_batch.shape[3],
        device=patch_batch.device,
        dtype=patch_batch.dtype,
    )
    for index, model in enumerate(models):
        with torch.no_grad():
            all_preds[index] = model(patch_batch)

    # Average across the ensemble before accumulating.
    mean_preds = all_preds.mean(dim=0)

    store_results(
        pred_batch=mean_preds,
        index_batch=index_batch,
        pred_tracker=pred_tracker,
        gradient=gradient,
        grad_tracker=grad_tracker,
    )
|
|
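# Assumed surrounding loop (a sketch, not part of the original module): once
# every batch has been accumulated, dividing by the summed blending weights
# recovers the blended prediction.
#
#     for patch_batch, index_batch in batches:  # hypothetical patch iterator
#         inference_and_store(
#             models, patch_batch, index_batch,
#             pred_tracker, gradient, grad_tracker,
#         )
#     final_pred = pred_tracker / grad_tracker.clamp(min=1e-8)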
|
|
def default_device() -> torch.device:
    """Return the default device for model inference."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
|
|
def load_model(
    model_path: Union[Path, str],
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> torch.nn.Module:
    """Load a PyTorch model from a file and move it to the specified device and dtype."""
    model_path = Path(model_path)
    if not model_path.is_file():
        raise FileNotFoundError(f"Model file not found at: {model_path}")

    try:
        # The file holds a fully pickled model, so weights_only must be False
        # (only load checkpoints from sources you trust).
        model = torch.load(model_path, map_location="cpu", weights_only=False)
    except Exception as e:
        raise RuntimeError(f"Error loading model: {e}") from e

    model.eval()
    return model.to(dtype).to(device)
|
|
def load_model_from_weights(
    model_name: str,
    weights_path: Union[Path, str],
    device: torch.device,
    dtype: torch.dtype = torch.float32,
    in_chans: int = 3,
    n_out: int = 4,
) -> torch.nn.Module:
    """Build a fastai DynamicUnet from a timm encoder and load weights from file."""
    # Encoder factory: fastai calls this to build the U-Net backbone.
    timm_model = partial(
        timm.create_model,
        model_name=model_name,
        pretrained=False,
        in_chans=in_chans,
    )

    model = create_unet_model(
        arch=timm_model,
        n_out=n_out,
        img_size=(509, 509),
        act_cls=torch.nn.Mish,
        pretrained=False,
    )

    model.load_state_dict(torch.load(weights_path, weights_only=True))
    model.eval()

    return model.to(dtype).to(device)
|
|
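# Minimal wiring sketch (not part of the original module; the checkpoint path,
# timm encoder name, and sizes below are hypothetical placeholders).
if __name__ == "__main__":
    device = default_device()
    dtype = get_torch_dtype("float32")

    # Either load a fully pickled model...
    model = load_model("models/ensemble_0.pt", device=device, dtype=dtype)
    # ...or rebuild the architecture and load a state dict:
    # model = load_model_from_weights(
    #     "maxvit_small_tf_512", "models/ensemble_0_weights.pt", device, dtype
    # )

    gradient = create_gradient_mask(
        patch_size=512, patch_overlap=64, device=device, dtype=dtype
    )
    print(type(model).__name__, gradient.shape)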