|
import torch |
|
import pathlib |
|
import importlib.util |
|
import safetensors.torch |
|
import matplotlib.pyplot as plt |
|
import math |
|
|
|
from typing import Literal |
|
|
|
|
|
def load_model_module(model_path: pathlib.Path):
    """Import a Python source file as a standalone module and return it.

    The module is executed under the name ``"model"``. It is not registered in
    ``sys.modules``, so every call re-executes the file and returns a fresh
    module object.

    Args:
        model_path: Path (or path string) of the ``.py`` file. It is resolved
            to an absolute path before loading.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for
            ``model_path`` (e.g. the suffix is not a recognised Python source
            extension).
    """
    model_path = pathlib.Path(model_path).resolve()

    spec = importlib.util.spec_from_file_location("model", model_path)
    # spec_from_file_location returns None when no loader matches the suffix;
    # fail with a clear ImportError rather than an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load a Python module from {model_path}")

    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
class EnsembleModel(torch.nn.Module):
    """Combine six member models into a probability map plus an uncertainty map.

    Every member is applied to the same input. The per-member outputs are
    concatenated along dim 1 and then reduced across members according to
    ``mode``.

    Args:
        model1 .. model6: Callables/modules applied to the input. Each is
            presumably expected to return a (B, 1, H, W) tensor of
            probabilities in [0, 1] — TODO confirm against the member models.
        mode: How member outputs are reduced: ``"max"``, ``"mean"``,
            ``"min"``, or ``"none"`` (return the raw concatenated outputs).

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    _MODES = ("min", "mean", "max", "none")

    def __init__(self, model1, model2, model3, model4, model5, model6, mode="max"):
        super().__init__()

        if mode not in self._MODES:
            raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")

        # Individual attributes register the members as submodules (so .to()/.eval()
        # reach them) and keep the state_dict keys (model1.*, ...) unchanged.
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3
        self.model4 = model4
        self.model5 = model5
        self.model6 = model6

        # Plain list used only for iteration in forward().
        self.models = [model1, model2, model3, model4, model5, model6]
        self.mode = mode

    def forward(self, x):
        """Run all members on ``x`` and reduce their outputs.

        Returns:
            For ``mode == "none"``: the member outputs concatenated on dim 1.
            Otherwise, a tuple ``(probability, uncertainty)``. Both are reduced
            over dim 1 and then ``squeeze()``d, so singleton dims (e.g. a batch
            of one) are dropped.
        """
        stacked = torch.cat([model(x) for model in self.models], dim=1)

        if self.mode == "none":
            return stacked

        if self.mode == "max":
            output_probs = torch.max(stacked, dim=1).values
        elif self.mode == "mean":
            # torch.mean returns a plain tensor (no (values, indices) pair);
            # indexing it with [0] would keep only the first batch sample.
            output_probs = torch.mean(stacked, dim=1)
        elif self.mode == "min":
            output_probs = torch.min(stacked, dim=1).values
        else:
            raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")
        output_probs = output_probs.squeeze()

        # Disagreement between members (unbiased std over dim 1) as uncertainty.
        # Like torch.mean, torch.std returns a plain tensor, so no [0] here either.
        std_output = torch.std(stacked, dim=1).squeeze()

        # Normalise by sqrt(0.25 * N / (N - 1)). This is the largest unbiased std
        # that N values in [0, 1] can reach (half at 0, half at 1, for even N),
        # which maps the uncertainty into [0, 1].
        n_members = len(self.models)
        std_max = math.sqrt(0.25 * n_members / (n_members - 1))
        std_output = std_output / std_max

        return output_probs, std_output
|
|
|
|
|
|
|
def example_data(path: pathlib.Path, device = "cpu", *args, **kwargs):
    """Load the bundled example image as a single-item float batch.

    Reads ``example_data.safetensor`` from ``path``, takes its ``"image"``
    entry, casts it to float32, prepends a batch dimension of size one and
    moves it to ``device``. Extra positional/keyword arguments are ignored.
    """
    tensors = safetensors.torch.load_file(path / "example_data.safetensor")
    image = tensors["image"].float()
    batch = image.unsqueeze(0)
    return batch.to(device)
|
|
|
def trainable_model(*args, **kwargs):
    """Placeholder for a trainable variant: print a notice and return ``None``.

    All arguments are accepted and ignored.
    """
    notice = "The model is not available in training mode."
    print(notice)
    return None
|
|
|
def _load_frozen(model, weights_f, device):
    """Load safetensors weights into ``model``, move it to ``device`` and freeze it for inference."""
    model.load_state_dict(safetensors.torch.load_file(weights_f))
    model = model.to(device)
    for param in model.parameters():
        param.requires_grad = False
    return model.eval()


def compiled_model(path, device: str = "cpu", mode: Literal["min", "mean", "max", "none"] = "max", *args, **kwargs):
    """Build the frozen six-member cloud-mask ensemble from files in ``path``.

    Expects ``path`` to contain:

    - the model definitions ``model.py`` and ``c2r1km.py``;
    - the weight files ``1dpwdeeplabv3``, ``1dpwseg``, ``1dpwunetpp``,
      ``unet``, ``unetpp`` and ``c2r1km``, each with the ``.safetensor``
      suffix.

    Args:
        path: Directory holding the definitions and weights (``str`` or ``Path``).
        device: Torch device the members are moved to.
        mode: Reduction used by ``EnsembleModel``. ``"none"`` returns the raw
            per-member outputs.

    Returns:
        An ``EnsembleModel`` whose members have ``requires_grad=False`` and are in eval mode.
    """
    path = pathlib.Path(path)

    # Execute each definition file once. Previously model.py was re-executed
    # for every one of the five members defined in it.
    seg_defs = load_model_module(path / "model.py")
    mlp_defs = load_model_module(path / "c2r1km.py")

    # (freshly constructed member, weight file name), in ensemble order.
    members = [
        (seg_defs.CombinedNet4(classes=1, benchmark=True, in_channels=4), "1dpwdeeplabv3.safetensor"),
        (seg_defs.CombinedNet(classes=1, benchmark=True), "1dpwseg.safetensor"),
        (seg_defs.CombinedNet3(classes=1, benchmark=True), "1dpwunetpp.safetensor"),
        (seg_defs.UNetBranch(classes=1, benchmark=True), "unet.safetensor"),
        (seg_defs.UNetPlusPlusBranch(classes=1, benchmark=True), "unetpp.safetensor"),
        (
            mlp_defs.CloudMaskOne(
                hidden_layer_sizes=(128, 112),
                activation='relu',
                last_activation='sigmoid',
                dropout_rate=0.1,
                input_dim=40,
                batch_norm=False,
            ),
            "c2r1km.safetensor",
        ),
    ]

    models = [_load_frozen(model, path / weights_name, device) for model, weights_name in members]

    return EnsembleModel(*models, mode=mode)
|
|
|
|
|
|
|
def display_results(path: pathlib.Path, device: str = "cpu", mode: Literal["min", "mean", "max"] ="max", *args, **kwargs):
    """Run the ensemble on the bundled example image and plot the results.

    Args:
        path: Directory holding the model definitions, weights and example data.
        device: Torch device used for inference.
        mode: Ensemble reduction: ``"min"``, ``"mean"`` or ``"max"``.

    Returns:
        A matplotlib figure with three panels:

        - the input, with channels 2, 1, 0 shown as RGB;
        - the cloud probability;
        - the normalised uncertainty.

    Raises:
        ValueError: If ``mode`` is ``"none"``. That mode returns raw member
            outputs instead of the (probability, uncertainty) pair plotted here.
    """
    if mode == "none":
        raise ValueError("display_results needs a reducing mode: 'min', 'mean', or 'max'.")

    model = compiled_model(path, device, mode=mode)
    probav = example_data(path).float().to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        cloudprob, uncertainty = model(probav)

    # (image, title, colormap) for each panel; cmap=None is imshow's default.
    panels = [
        (probav[0, [2, 1, 0]].cpu().numpy().transpose(1, 2, 0), "Input", None),
        (cloudprob.cpu().numpy(), "Cloud Probability", "gray"),
        (uncertainty.cpu().numpy(), "Uncertainty", "gray"),
    ]

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    for axis, (image, title, cmap) in zip(ax, panels):
        axis.imshow(image, cmap=cmap)
        axis.set_title(title)
        axis.axis("off")
    fig.tight_layout()
    return fig