JulioContrerasH's picture
Update ensemble/load.py
567a505 verified
import torch
import pathlib
import importlib.util
import safetensors.torch
import matplotlib.pyplot as plt
import math
from typing import Literal
def load_model_module(model_path: pathlib.Path):
    """Import a Python source file from an arbitrary location as a module.

    The file is executed every time this function is called; the resulting
    module is named ``"model"`` but is not registered in ``sys.modules``.

    Args:
        model_path: Path to the Python source file to execute.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for
            ``model_path`` (e.g. the file suffix is not a recognised Python
            source or extension suffix).
    """
    model_path = model_path.resolve()
    spec = importlib.util.spec_from_file_location("model", model_path)
    # spec_from_file_location returns None when no loader matches the file
    # suffix; fail with a clear message instead of an AttributeError below.
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot create a module spec for {model_path}")
    model = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(model)
    return model
class EnsembleModel(torch.nn.Module):
    """Fuse the outputs of six models into one prediction plus a spread map.

    Every sub-model is called on the same input and the results are
    concatenated along dimension 1. Depending on ``mode`` the concatenated
    tensor is reduced along that dimension with ``min``, ``mean`` or ``max``,
    or returned untouched (``"none"``).

    Args:
        model1..model6: Callables (normally ``torch.nn.Module`` instances)
            applied to the input. Assumes each returns a tensor that can be
            concatenated with the others along dim 1 -- presumably
            ``(batch, 1, H, W)``; confirm against the model definitions.
        mode: One of ``"min"``, ``"mean"``, ``"max"`` or ``"none"``.

    Raises:
        ValueError: If ``mode`` is not one of the accepted values.
    """

    def __init__(self, model1, model2, model3, model4, model5, model6, mode="max"):
        super().__init__()
        if mode not in ["min", "mean", "max", "none"]:
            raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")
        # Keep the individual attributes so that sub-modules stay registered
        # under the same names (state_dict keys, .to(), .eval()).
        self.model1 = model1
        self.model2 = model2
        self.model3 = model3
        self.model4 = model4
        self.model5 = model5
        self.model6 = model6
        # Plain list used only for iteration in forward().
        self.models = [model1, model2, model3, model4, model5, model6]
        self.mode = mode

    def forward(self, x):
        """Run every sub-model on ``x`` and combine the results.

        Returns:
            * ``mode == "none"``: the raw outputs concatenated along dim 1.
            * otherwise: a tuple ``(output_probs, std_output)`` where
              ``output_probs`` is the min/mean/max over dim 1 and
              ``std_output`` is the standard deviation over dim 1 divided by
              its theoretical maximum. Both are passed through ``squeeze()``,
              so size-1 dimensions (including a batch of one) are dropped.

        Raises:
            ValueError: If ``self.mode`` was changed to an unsupported value
                after construction.
        """
        # Concatenate once and reuse for both the reduction and the spread.
        stacked = torch.cat([model(x) for model in self.models], dim=1)

        if self.mode == "none":
            return stacked

        # Reduce across the ensemble dimension.
        if self.mode == "max":
            # torch.max with dim returns (values, indices); keep the values.
            output_probs = torch.max(stacked, dim=1)[0].squeeze()
        elif self.mode == "mean":
            # torch.mean returns a plain tensor: indexing it with [0] would
            # select the first batch element and discard the rest.
            output_probs = torch.mean(stacked, dim=1).squeeze()
        elif self.mode == "min":
            output_probs = torch.min(stacked, dim=1)[0].squeeze()
        else:
            raise ValueError("Mode must be 'none', 'min', 'mean', or 'max'.")

        # Disagreement between ensemble members, used as a kind of
        # uncertainty. torch.std also returns a plain tensor (no [0]).
        std_output = torch.std(stacked, dim=1).squeeze()

        # Normalize the standard deviation to [0, 1]: for N values in [0, 1]
        # (assumes the models emit probabilities) the largest sample standard
        # deviation is reached with half the values at 0 and half at 1.
        N = len(self.models)
        std_max = math.sqrt(0.25 * N / (N - 1))
        std_output = std_output / std_max

        return output_probs, std_output
# MLSTAC API -----------------------------------------------------------------------
def example_data(path: pathlib.Path, device = "cpu", *args, **kwargs):
    """Load the bundled example image as a float tensor on ``device``.

    Reads ``example_data.safetensor`` from ``path``, takes its ``"image"``
    entry, casts it to float and prepends one leading dimension.

    Args:
        path: Directory that contains ``example_data.safetensor``.
        device: Device the returned tensor is moved to.
        *args, **kwargs: Accepted and ignored.

    Returns:
        The ``"image"`` tensor with an extra leading dimension of size 1.
    """
    tensors = safetensors.torch.load_file(path / "example_data.safetensor")
    image = tensors["image"].float()
    return image.unsqueeze(0).to(device)
def trainable_model(*args, **kwargs):
    """Report that no trainable variant exists and return ``None``.

    All positional and keyword arguments are accepted and ignored.
    """
    notice = "The model is not available in training mode."
    print(notice)
    return None
def _load_frozen(model, weights_file, device):
    """Load weights into ``model``, move it to ``device``, freeze and eval it."""
    model.load_state_dict(safetensors.torch.load_file(weights_file))
    model = model.to(device)
    # Inference only: no parameter should accumulate gradients.
    for param in model.parameters():
        param.requires_grad = False
    return model.eval()


def compiled_model(path, device: str = "cpu", mode: Literal["min", "mean", "max", "none"] = "max", *args, **kwargs):
    """Build the frozen six-member ensemble from the files found in ``path``.

    The directory must contain ``model.py`` and ``c2r1km.py`` (architecture
    definitions) plus one ``.safetensor`` weight file per ensemble member.

    Args:
        path: Directory holding the source and weight files. A ``str`` is
            accepted and converted to ``pathlib.Path``.
        device: Device every sub-model is moved to.
        mode: Fusion mode forwarded to ``EnsembleModel``. With ``"none"`` the
            ensemble returns the raw concatenated outputs instead of a
            ``(probability, uncertainty)`` tuple.
        *args, **kwargs: Accepted and ignored.

    Returns:
        An ``EnsembleModel`` whose members are in eval mode with
        ``requires_grad`` disabled on every parameter.
    """
    path = pathlib.Path(path)

    # Execute each architecture file once and reuse the resulting module,
    # rather than re-importing model.py for every ensemble member.
    arch = load_model_module(path / "model.py")
    c2r1km = load_model_module(path / "c2r1km.py")

    # Model 1 (DeepLabV3Branch + PixelWise)
    model1 = _load_frozen(
        arch.CombinedNet4(classes=1, benchmark=True, in_channels=4),
        path / "1dpwdeeplabv3.safetensor",
        device,
    )

    # Model 2 (SegformerBranch + PixelWise)
    model2 = _load_frozen(
        arch.CombinedNet(classes=1, benchmark=True),
        path / "1dpwseg.safetensor",
        device,
    )

    # Model 3 (UNetPlusPlusBranch + PixelWise)
    model3 = _load_frozen(
        arch.CombinedNet3(classes=1, benchmark=True),
        path / "1dpwunetpp.safetensor",
        device,
    )

    # Model 4 (UNetBranch)
    model4 = _load_frozen(
        arch.UNetBranch(classes=1, benchmark=True),
        path / "unet.safetensor",
        device,
    )

    # Model 5 (UNetPlusPlusBranch)
    model5 = _load_frozen(
        arch.UNetPlusPlusBranch(classes=1, benchmark=True),
        path / "unetpp.safetensor",
        device,
    )

    # Model 6 (C2R1KM)
    model6 = _load_frozen(
        c2r1km.CloudMaskOne(
            hidden_layer_sizes=(128, 112),
            activation='relu',
            last_activation='sigmoid',
            dropout_rate=0.1,
            input_dim=40,
            batch_norm=False,
        ),
        path / "c2r1km.safetensor",
        device,
    )

    # Create ensemble model
    return EnsembleModel(model1, model2, model3, model4, model5, model6, mode=mode)
def display_results(path: pathlib.Path, device: str = "cpu", mode: Literal["min", "mean", "max"] ="max", *args, **kwargs):
    """Run the ensemble on the bundled example and plot the result.

    Draws three panels side by side: the input (channels 2, 1, 0 of the first
    sample, in that order), the fused cloud probability and the normalized
    ensemble uncertainty.

    Args:
        path: Directory containing the model files and the example data.
        device: Device used to build and run the model.
        mode: Fusion mode forwarded to ``compiled_model``.
        *args, **kwargs: Accepted and ignored.

    Returns:
        The matplotlib figure holding the three panels.
    """
    # Build the ensemble, fetch the example image and run inference.
    ensemble = compiled_model(path, device, mode=mode)
    probav = example_data(path)
    cloudprob, uncertainty = ensemble(probav.float().to(device))

    fig, axes = plt.subplots(1, 3, figsize=(12, 4))

    # (title, image, extra imshow options) for each panel, left to right.
    panels = (
        (
            "Input",
            probav[0, [2, 1, 0]].cpu().detach().numpy().transpose(1, 2, 0),
            {},
        ),
        ("Cloud Probability", cloudprob.cpu().detach().numpy(), {"cmap": "gray"}),
        ("Uncertainty", uncertainty.cpu().detach().numpy(), {"cmap": "gray"}),
    )
    for axis, (title, image, options) in zip(axes, panels):
        axis.imshow(image, **options)
        axis.set_title(title)
        axis.axis("off")

    fig.tight_layout()
    return fig