import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchmetrics import classification
import wandb
from matplotlib import pyplot as plt
import numpy as np
import matplotlib.ticker as ticker
from matplotlib.colors import ListedColormap
from huggingface_hub import PyTorchModelHubMixin
from lion_pytorch import Lion

import json

from messis.prithvi import TemporalViTEncoder, ConvTransformerTokensToEmbeddingNeck, ConvTransformerTokensToEmbeddingBottleneckNeck


def safe_shape(x):
    if isinstance(x, tuple):
        shape_info = '(tuple) : '
        for i in x:
            shape_info += str(i.shape) + ', '
        return shape_info
    if isinstance(x, list):
        shape_info = '(list) : '
        for i in x:
            shape_info += str(i.shape) + ', '
        return shape_info
    return x.shape

class ConvModule(nn.Module):
    """
    A simple convolutional module consisting of Conv2d, BatchNorm2d, and ReLU layers.
    """
    def __init__(self, in_channels, out_channels, kernel_size, padding, dilation):
        super(ConvModule, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)


class HierarchicalFCNHead(nn.Module):
    """
    Hierarchical FCN head for semantic segmentation.
    """
    def __init__(self, in_channels, out_channels, num_classes, num_convs=2, kernel_size=3, dilation=1, dropout_p=0.1, debug=False):
        super(HierarchicalFCNHead, self).__init__()

        self.debug = debug

        self.convs = nn.Sequential(*[
            ConvModule(
                in_channels if i == 0 else out_channels,
                out_channels,
                kernel_size,
                padding=dilation * (kernel_size // 2),
                dilation=dilation
            ) for i in range(num_convs)
        ])

        self.conv_seg = nn.Conv2d(out_channels, num_classes, kernel_size=1)
        self.dropout = nn.Dropout2d(p=dropout_p)

    def forward(self, x):
        if self.debug:
            print('HierarchicalFCNHead forward INP: ', safe_shape(x))
        x = self.convs(x)
        features = self.dropout(x)
        output = self.conv_seg(features)
        if self.debug:
            print('HierarchicalFCNHead forward features OUT: ', safe_shape(features))
            print('HierarchicalFCNHead forward output OUT: ', safe_shape(output))
        return output, features

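# Each HierarchicalFCNHead returns (logits, features). In HierarchicalClassifier.forward below, the
# feature map produced by one head is fed as the input of the next head, which is what chains the
# tiers hierarchically.
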
class LabelRefinementHead(nn.Module):
    """
    Similar to the label refinement module introduced in the ZueriCrop paper, this module refines the predictions for tier 3.
    It takes the raw predictions from head 1, head 2, and head 3 and refines them to produce the final prediction for tier 3.
    According to ZueriCrop, this helps make the predictions more consistent across the different tiers.
    """
    def __init__(self, input_channels, num_classes):
        super(LabelRefinementHead, self).__init__()

        self.cnn_layers = nn.Sequential(
            # Indices 0-2: 1x1 conv projecting the concatenated predictions to 128 channels
            nn.Conv2d(in_channels=input_channels, out_channels=128, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # Indices 3-6: first 3x3 conv block
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),

            # Indices 7-9: second 3x3 conv block
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),

            # Indices 10-11: 1x1 classifier conv to num_classes
            nn.Conv2d(in_channels=128, out_channels=num_classes, kernel_size=1, stride=1, padding=0),
            nn.Dropout(p=0.5)
        )

    def forward(self, x):
        # Project the concatenated head outputs to 128 channels (indices 0-2)
        y = self.cnn_layers[0:3](x)

        # Keep the projection as a skip connection around the two 3x3 conv blocks
        y_skip = y

        # Run up to (and including) the second block's BatchNorm (indices 3-8)
        y = self.cnn_layers[3:9](y)

        # Residual addition before the final ReLU
        y = y + y_skip

        # Final ReLU, 1x1 classifier conv, and dropout (indices 9-11)
        y = self.cnn_layers[9:](y)
        return y

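# In HierarchicalClassifier below, the refinement head receives the channel-wise concatenation of all
# tier head logits (input_channels equals the total number of classes across the FCN heads) and outputs
# the logits of the final, refined tier.
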
class HierarchicalClassifier(nn.Module):
    def __init__(
        self,
        heads_spec,
        dropout_p=0.1,
        img_size=256,
        patch_size=16,
        num_frames=3,
        bands=[0, 1, 2, 3, 4, 5],
        backbone_weights_path=None,
        freeze_backbone=True,
        use_bottleneck_neck=False,
        bottleneck_reduction_factor=4,
        loss_ignore_background=False,
        debug=False
    ):
        super(HierarchicalClassifier, self).__init__()

        self.embed_dim = 768
        if num_frames % 3 != 0:
            raise ValueError(f"The number of frames must be a multiple of 3, it is currently: {num_frames}")
        self.num_frames = num_frames
        self.hp, self.wp = img_size // patch_size, img_size // patch_size
        self.heads_spec = heads_spec
        self.dropout_p = dropout_p
        self.loss_ignore_background = loss_ignore_background
        self.debug = debug

        if self.debug:
            print('hp and wp: ', self.hp, self.wp)

        # Prithvi backbone: a temporal ViT encoder that processes 3 frames at a time
        self.prithvi = TemporalViTEncoder(
            img_size=img_size,
            patch_size=patch_size,
            num_frames=3,
            tubelet_size=1,
            in_chans=len(bands),
            embed_dim=self.embed_dim,
            depth=12,
            num_heads=8,
            mlp_ratio=4.0,
            norm_pix_loss=False,
            pretrained=backbone_weights_path,
            debug=self.debug
        )

        # Optionally freeze the backbone weights
        for param in self.prithvi.parameters():
            param.requires_grad = not freeze_backbone

        # One neck per group of 3 frames
        number_of_necks = self.num_frames // 3
        if use_bottleneck_neck:
            self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingBottleneckNeck(
                embed_dim=self.embed_dim * 3,
                output_embed_dim=self.embed_dim * 3,
                drop_cls_token=True,
                Hp=self.hp,
                Wp=self.wp,
                bottleneck_reduction_factor=bottleneck_reduction_factor
            ) for _ in range(number_of_necks)])
        else:
            self.necks = nn.ModuleList([ConvTransformerTokensToEmbeddingNeck(
                embed_dim=self.embed_dim * 3,
                output_embed_dim=self.embed_dim * 3,
                drop_cls_token=True,
                Hp=self.hp,
                Wp=self.wp,
            ) for _ in range(number_of_necks)])

        self.heads = nn.ModuleDict()
        self.loss_weights = {}
        self.total_classes = 0

        # Build the heads in the order given by heads_spec. The LabelRefinementHead must come after the
        # HierarchicalFCNHeads, because it takes the concatenation of all their logits as input.
        head_count = 0
        for head_name, head_info in self.heads_spec.items():
            head_type = head_info['type']
            num_classes = head_info['num_classes_to_predict']
            loss_weight = head_info['loss_weight']

            if head_type == 'HierarchicalFCNHead':
                kernel_size = head_info.get('kernel_size', 3)
                num_convs = head_info.get('num_convs', 1)
                num_channels = head_info.get('num_channels', 256)
                self.total_classes += num_classes

                self.heads[head_name] = HierarchicalFCNHead(
                    in_channels=(self.embed_dim * self.num_frames) if head_count == 0 else num_channels,
                    out_channels=num_channels,
                    num_classes=num_classes,
                    num_convs=num_convs,
                    kernel_size=kernel_size,
                    dropout_p=self.dropout_p,
                    debug=self.debug
                )
                self.loss_weights[head_name] = loss_weight

            if head_type == 'LabelRefinementHead':
                self.refinement_head = LabelRefinementHead(input_channels=self.total_classes, num_classes=num_classes)
                self.refinement_head_name = head_name
                self.loss_weights[head_name] = loss_weight

            head_count += 1

        self.loss_func = nn.CrossEntropyLoss(ignore_index=-1)

    def forward(self, x):
        if self.debug:
            print(f"Input shape: {safe_shape(x)}")

        # Split the temporal dimension into groups of 3 frames (one group per neck) and run each group
        # through the shared Prithvi backbone
        if len(self.necks) == 1:
            features = [x]
        else:
            features = torch.chunk(x, len(self.necks), dim=2)
        features = [self.prithvi(frame_group) for frame_group in features]

        if self.debug:
            print(f"Features shape after base model: {', '.join([str(safe_shape(f)) for f in features])}")

        # Each group of backbone tokens goes through its own neck
        features = [neck(feat_) for feat_, neck in zip(features, self.necks)]

        if self.debug:
            print(f"Features shape after neck: {', '.join([str(safe_shape(f)) for f in features])}")

        # The necks return tuples; keep only the feature maps and concatenate them along the channel dimension
        features = [feat[0] for feat in features]
        features = torch.concatenate(features, dim=1)
        if self.debug:
            print(f"Features shape after removing tuple: {safe_shape(features)}")

        # Run the hierarchical FCN heads; each head consumes the features produced by the previous one
        outputs = {}
        for tier_name, head in self.heads.items():
            output, features = head(features)
            outputs[tier_name] = output

            if self.debug:
                print(f"Features shape after {tier_name} head: {safe_shape(features)}")
                print(f"Output shape after {tier_name} head: {safe_shape(output)}")

        # The refinement head refines the concatenated tier predictions into the final output
        output_concatenated = torch.cat(list(outputs.values()), dim=1)
        output_refinement_head = self.refinement_head(output_concatenated)
        outputs[self.refinement_head_name] = output_refinement_head

        return outputs

    def calculate_loss(self, outputs, targets):
        total_loss = 0
        loss_per_head = {}
        for head_name, output in outputs.items():
            if self.debug:
                print(f"Target index for {head_name}: {self.heads_spec[head_name]['target_idx']}")
            target = targets[self.heads_spec[head_name]['target_idx']]
            loss_target = target
            if self.loss_ignore_background:
                # Map the background class (0) to the ignore_index of CrossEntropyLoss
                loss_target = target.clone()
                loss_target[loss_target == 0] = -1
            loss = self.loss_func(output, loss_target)
            loss_per_head[head_name] = loss
            total_loss += loss * self.loss_weights[head_name]

        return total_loss, loss_per_head

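# Illustrative sketch of the heads_spec structure this class expects (not taken from the original
# source): the head names, class counts, and weights below are placeholders, only the keys are the
# ones actually read above and in the callbacks further down.
#
# heads_spec = {
#     'tier1': {'type': 'HierarchicalFCNHead', 'num_classes_to_predict': 4, 'loss_weight': 0.1,
#               'num_convs': 1, 'num_channels': 256, 'target_idx': 0, 'is_metrics_tier': True},
#     'tier2': {'type': 'HierarchicalFCNHead', 'num_classes_to_predict': 20, 'loss_weight': 0.2,
#               'target_idx': 1, 'is_metrics_tier': True},
#     'tier3': {'type': 'HierarchicalFCNHead', 'num_classes_to_predict': 50, 'loss_weight': 0.3,
#               'target_idx': 2, 'is_metrics_tier': True, 'is_last_tier': True},
#     'tier3_refined': {'type': 'LabelRefinementHead', 'num_classes_to_predict': 50, 'loss_weight': 0.4,
#                       'target_idx': 2, 'is_final_head': True},
# }
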
class Messis(pl.LightningModule, PyTorchModelHubMixin):
    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)

        self.model = HierarchicalClassifier(
            heads_spec=hparams['heads_spec'],
            dropout_p=hparams.get('dropout_p'),
            img_size=hparams.get('img_size'),
            patch_size=hparams.get('patch_size'),
            num_frames=hparams.get('num_frames'),
            bands=hparams.get('bands'),
            backbone_weights_path=hparams.get('backbone_weights_path'),
            freeze_backbone=hparams['freeze_backbone'],
            use_bottleneck_neck=hparams.get('use_bottleneck_neck'),
            bottleneck_reduction_factor=hparams.get('bottleneck_reduction_factor'),
            loss_ignore_background=hparams.get('loss_ignore_background'),
            debug=hparams.get('debug')
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        return self.__step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.__step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        return self.__step(batch, batch_idx, "test")

    def configure_optimizers(self):
        match self.hparams.get('optimizer', 'Adam'):
            case 'Adam':
                optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.get('lr', 1e-3))
            case 'AdamW':
                optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.get('lr', 1e-3), weight_decay=self.hparams.get('optimizer_weight_decay', 0.01))
            case 'SGD':
                optimizer = torch.optim.SGD(self.parameters(), lr=self.hparams.get('lr', 1e-3), momentum=self.hparams.get('optimizer_momentum', 0.9))
            case 'Lion':
                # Lion is typically run with a smaller learning rate and larger weight decay than Adam
                optimizer = Lion(self.parameters(), lr=self.hparams.get('lr', 1e-4), weight_decay=self.hparams.get('optimizer_weight_decay', 0.1))
            case _:
                raise ValueError(f"Optimizer {self.hparams.get('optimizer')} not supported")
        return optimizer

    def __step(self, batch, batch_idx, stage):
        inputs, targets = batch
        targets = torch.stack(targets[0])
        outputs = self(inputs)
        loss, loss_per_head = self.model.calculate_loss(outputs, targets)
        loss_per_head_named = {f'{stage}_loss_{head}': loss_per_head[head] for head in loss_per_head}
        loss_proportions = {f'{stage}_loss_{head}_proportion': round(loss_per_head[head].item() / loss.item(), 2) for head in loss_per_head}
        loss_detail_dict = {**loss_per_head_named, **loss_proportions}

        if self.hparams.get('debug'):
            print(f"Step Inputs shape: {safe_shape(inputs)}")
            print(f"Step Targets shape: {safe_shape(targets)}")
            print(f"Step Outputs dict keys: {outputs.keys()}")

        self.log_dict({f'{stage}_loss': loss, **loss_detail_dict}, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return {'loss': loss, 'outputs': outputs}

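# Minimal usage sketch (illustrative, not part of the original module). The hparams keys match the ones
# read above; the concrete values, checkpoint path, and data module are placeholders:
#
#     hparams = {
#         'heads_spec': heads_spec,          # see the example structure above
#         'freeze_backbone': True,
#         'img_size': 256, 'patch_size': 16, 'num_frames': 3,
#         'bands': [0, 1, 2, 3, 4, 5],
#         'backbone_weights_path': 'Prithvi_100M.pt',
#         'optimizer': 'AdamW', 'lr': 1e-3,
#     }
#     model = Messis(hparams)
#     trainer = pl.Trainer(max_epochs=10, callbacks=[
#         LogConfusionMatrix(hparams, 'dataset_info.json'),
#         LogMessisMetrics(hparams, 'dataset_info.json'),
#     ])
#     trainer.fit(model, datamodule=some_datamodule)
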
class LogConfusionMatrix(pl.Callback):
    def __init__(self, hparams, dataset_info_file, debug=False):
        super().__init__()

        assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams"
        self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
        self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None)
        self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None)

        assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True"
        assert self.final_head_name is not None, "No head found with 'is_final_head' set to True"

        self.tiers = list(self.tiers_dict.keys())
        self.phases = ['train', 'val', 'test']
        self.modes = ['pixelwise', 'majority']
        self.debug = debug

        if debug:
            print(f"Final head identified as: {self.final_head_name}")
            print(f"LogConfusionMatrix Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}")

        with open(dataset_info_file, 'r') as f:
            self.dataset_info = json.load(f)

        # One confusion matrix per phase, tier, and mode
        self.metrics_to_compute = ['confusion_matrix']
        self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases}

    def __init_metrics(self, tier, phase):
        num_classes = self.tiers_dict[tier]['num_classes_to_predict']
        confusion_matrix = classification.MulticlassConfusionMatrix(num_classes=num_classes)

        return {
            'confusion_matrix': confusion_matrix
        }

    def setup(self, trainer, pl_module, stage=None):
        # Move all metric objects to the same device as the model
        device = pl_module.device
        for phase_metrics in self.metrics.values():
            for tier_metrics in phase_metrics.values():
                for mode_metrics in tier_metrics.values():
                    for metric in self.metrics_to_compute:
                        mode_metrics[metric].to(device)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'train')

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'val')

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__update_confusion_matrices(trainer, pl_module, outputs, batch, batch_idx, 'test')

    def __update_confusion_matrices(self, trainer, pl_module, outputs, batch, batch_idx, phase):
        if trainer.sanity_checking:
            return

        # The second element of the batch holds the per-tier targets and the field IDs;
        # only the final (refinement) head's outputs are evaluated here
        targets = torch.stack(batch[1][0])
        outputs = outputs['outputs'][self.final_head_name]
        field_ids = batch[1][1].permute(1, 0, 2, 3)[0]

        pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info)

        for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes):

            assert len(preds) == len(targets), f"Number of predictions and targets do not match: {len(preds)} vs {len(targets)}"
            assert len(preds) == len(self.tiers), f"Number of predictions and tiers do not match: {len(preds)} vs {len(self.tiers)}"

            for pred, target, tier in zip(preds, targets, self.tiers):
                if self.debug:
                    print(f"Updating confusion matrix for {phase} {tier} {mode}")
                metrics = self.metrics[phase][tier][mode]

                # For majority predictions, ignore background pixels (target == 0)
                if mode == 'majority':
                    pred = pred[target != 0]
                    target = target[target != 0]
                metrics['confusion_matrix'].update(pred, target)

    @staticmethod
    def get_pixelwise_and_majority_outputs(refinement_head_outputs, tiers, field_ids, dataset_info):
        """
        Get the pixelwise and majority predictions from the model outputs.

        The pixelwise predictions for all tiers are derived from the refinement_head_outputs.
        The majority predictions for the last tier are also derived from the refinement_head_outputs;
        the majority predictions for the lower tiers are then derived from the majority last-tier
        predictions via the tier mappings in dataset_info.

        Also sets the background to 0 for all field majority predictions (regardless of what the model
        predicts for the background class), since this is a classification task rather than a
        segmentation task: the field boundaries are known beforehand and are not of interest.

        Args:
            refinement_head_outputs (torch.Tensor(batch, C, H, W)): The probability outputs from the model for the refined tier.
            tiers (list of str): List of tiers, e.g. ['tier1', 'tier2', 'tier3'].
            field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction.
            dataset_info (dict): The dataset information.

        Returns:
            torch.Tensor(tiers, batch, H, W): The pixelwise predictions.
            torch.Tensor(tiers, batch, H, W): The majority predictions.
        """
        highest_tier = tiers[-1]

        pixelwise_highest_tier = torch.softmax(refinement_head_outputs, dim=1).argmax(dim=1)
        majority_highest_tier = LogConfusionMatrix.get_field_majority_preds(refinement_head_outputs, field_ids)

        # Class-index mappings from the highest tier down to each lower tier, e.g. dataset_info['tier3_to_tier1']
        tier_mapping = {tier: dataset_info[f'{highest_tier}_to_{tier}'] for tier in tiers if tier != highest_tier}

        pixelwise_outputs = {highest_tier: pixelwise_highest_tier}
        majority_outputs = {highest_tier: majority_highest_tier}

        # Initialize the lower-tier predictions
        for tier in tiers:
            if tier != highest_tier:
                pixelwise_outputs[tier] = torch.zeros_like(pixelwise_highest_tier)
                majority_outputs[tier] = torch.zeros_like(majority_highest_tier)

        # Map every highest-tier class index i to its corresponding class in each lower tier
        for i, mappings in enumerate(zip(*tier_mapping.values())):
            for j, tier in enumerate(tier_mapping.keys()):
                pixelwise_outputs[tier][pixelwise_highest_tier == i] = mappings[j]
                majority_outputs[tier][majority_highest_tier == i] = mappings[j]

        pixelwise_outputs_stacked = torch.stack([pixelwise_outputs[tier] for tier in tiers])
        majority_outputs_stacked = torch.stack([majority_outputs[tier] for tier in tiers])

        assert isinstance(pixelwise_outputs_stacked, torch.Tensor), "pixelwise_outputs_stacked is not a tensor"
        assert isinstance(majority_outputs_stacked, torch.Tensor), "majority_outputs_stacked is not a tensor"

        return pixelwise_outputs_stacked, majority_outputs_stacked

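    # Illustrative example of the tier mapping above (the values are placeholders, the real mappings come
    # from dataset_info): if dataset_info['tier3_to_tier1'] == [0, 1, 1, 2, ...], then every pixel
    # predicted as tier3 class 2 is assigned tier1 class 1, i.e. entry i of the mapping list is the
    # lower-tier class for highest-tier class i.
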
    @staticmethod
    def get_field_majority_preds(output, field_ids):
        """
        Get the majority prediction for each field in the batch. The majority excludes the background class.

        Args:
            output (torch.Tensor(batch, C, H, W)): The probability outputs from the model (tier3_refined).
            field_ids (torch.Tensor(batch, H, W)): The field IDs for each prediction.

        Returns:
            torch.Tensor(batch, H, W): The majority predictions.
        """
        # Argmax over the non-background classes only (channel 0 is background), then shift the indices back by 1
        pixelwise = torch.softmax(output[:, 1:, :, :], dim=1).argmax(dim=1) + 1
        majority_preds = torch.zeros_like(pixelwise)
        for batch in range(len(pixelwise)):
            field_ids_batch = field_ids[batch]
            for field_id in np.unique(field_ids_batch.cpu().numpy()):
                if field_id == 0:
                    continue  # field ID 0 marks background, not a field
                field_mask = field_ids_batch == field_id
                flattened_pred = pixelwise[batch][field_mask].view(-1)
                flattened_pred = flattened_pred[flattened_pred != 0]
                if len(flattened_pred) == 0:
                    continue
                mode_pred, _ = torch.mode(flattened_pred)
                majority_preds[batch][field_mask] = mode_pred.item()
        return majority_preds

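    # Example of the majority vote above (illustrative values): if a field's non-background pixel
    # predictions are [3, 3, 5, 3, 7], torch.mode returns 3 and every pixel of that field is assigned
    # class 3 in the majority prediction.
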
    def on_train_epoch_end(self, trainer, pl_module):
        self.__log_and_reset_confusion_matrices(trainer, pl_module, 'train')

    def on_validation_epoch_end(self, trainer, pl_module):
        self.__log_and_reset_confusion_matrices(trainer, pl_module, 'val')

    def on_test_epoch_end(self, trainer, pl_module):
        self.__log_and_reset_confusion_matrices(trainer, pl_module, 'test')

    def __log_and_reset_confusion_matrices(self, trainer, pl_module, phase):
        if trainer.sanity_checking:
            return

        for tier in self.tiers:
            for mode in self.modes:
                metrics = self.metrics[phase][tier][mode]
                confusion_matrix = metrics['confusion_matrix']
                if self.debug:
                    print(f"Logging and resetting confusion matrix for {phase} {tier} Update count: {confusion_matrix._update_count}")
                matrix = confusion_matrix.compute()

                # Normalize each row (true class) to percentages
                matrix = matrix.float()
                row_sums = matrix.sum(dim=1, keepdim=True)
                matrix_percent = matrix / row_sums

                # Rows without samples produce NaNs; all other rows must sum to 1
                row_sum_check = matrix_percent.sum(dim=1)
                valid_rows = ~torch.isnan(row_sum_check)
                if valid_rows.any():
                    assert torch.allclose(row_sum_check[valid_rows], torch.ones_like(row_sum_check[valid_rows]), atol=1e-2), "Percentages do not sum to 1 for some valid rows"

                # Sort classes by the number of target samples (descending)
                sorted_indices = row_sums.squeeze().argsort(descending=True)
                matrix_percent = matrix_percent[sorted_indices, :]
                matrix_percent = matrix_percent[:, sorted_indices]
                class_labels = [self.dataset_info[tier][i] for i in sorted_indices]
                row_sums_sorted = row_sums[sorted_indices]

                # Rows without any samples are marked N/A in the plot
                zero_rows = (row_sums_sorted == 0).squeeze()

                fig, ax = plt.subplots(figsize=(matrix.size(0), matrix.size(0)), dpi=140)

                ax.matshow(matrix_percent.cpu().numpy(), cmap='viridis')

                ax.xaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(1) + 1)))
                ax.yaxis.set_major_locator(ticker.FixedLocator(range(matrix.size(0) + 1)))

                ax.set_xticklabels(class_labels + [''], rotation=45)
                ax.set_yticklabels(class_labels + [''])

                # Add the per-class sample count to the y-axis labels, using ' as thousands separator
                y_labels = [f'{class_labels[i]} [n={int(row_sums_sorted[i].item()):,.0f}]'.replace(',', "'") for i in range(matrix.size(0))]
                ax.set_yticklabels(y_labels + [''])

                ax.set_xlabel('Predictions')
                ax.set_ylabel('Targets')

                ax.xaxis.set_label_position('top')
                ax.xaxis.set_ticks_position('top')

                fig.tight_layout()

                for i in range(matrix.size(0)):
                    for j in range(matrix.size(1)):
                        if zero_rows[i]:
                            ax.text(j, i, 'N/A', ha='center', va='center', color='black')
                        else:
                            ax.text(j, i, f'{matrix_percent[i, j]:.2f}', ha='center', va='center', color='#F88379', weight='bold')
                trainer.logger.experiment.log({f"{phase}_{tier}_confusion_matrix_{mode}": wandb.Image(fig)})
                plt.close(fig)
                confusion_matrix.reset()

class LogMessisMetrics(pl.Callback):
    def __init__(self, hparams, dataset_info_file, debug=False):
        super().__init__()

        assert hparams.get('heads_spec') is not None, "heads_spec must be defined in the hparams"
        self.tiers_dict = {k: v for k, v in hparams.get('heads_spec').items() if v.get('is_metrics_tier', False)}
        self.last_tier_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_last_tier', False)), None)
        self.final_head_name = next((k for k, v in hparams.get('heads_spec').items() if v.get('is_final_head', False)), None)

        assert self.last_tier_name is not None, "No tier found with 'is_last_tier' set to True"
        assert self.final_head_name is not None, "No head found with 'is_final_head' set to True"

        self.tiers = list(self.tiers_dict.keys())
        self.phases = ['train', 'val', 'test']
        self.modes = ['pixelwise', 'majority']
        self.debug = debug

        if debug:
            print(f"Last tier identified as: {self.last_tier_name}")
            print(f"Final head identified as: {self.final_head_name}")
            print(f"LogMessisMetrics Metrics over | Phases: {self.phases}, Tiers: {self.tiers}, Modes: {self.modes}")

        with open(dataset_info_file, 'r') as f:
            self.dataset_info = json.load(f)

        self.metrics_to_compute = ['accuracy', 'weighted_accuracy', 'precision', 'weighted_precision', 'recall', 'weighted_recall', 'f1', 'weighted_f1', 'cohen_kappa']
        self.metrics = {phase: {tier: {mode: self.__init_metrics(tier, phase) for mode in self.modes} for tier in self.tiers} for phase in self.phases}
        self.images_to_log = {phase: {mode: None for mode in self.modes} for phase in self.phases}
        self.images_to_log_targets = {phase: None for phase in self.phases}
        self.field_ids_to_log_targets = {phase: None for phase in self.phases}
        self.inputs_to_log = {phase: None for phase in self.phases}

    def __init_metrics(self, tier, phase):
        num_classes = self.tiers_dict[tier]['num_classes_to_predict']

        accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='macro')
        weighted_accuracy = classification.MulticlassAccuracy(num_classes=num_classes, average='weighted')
        per_class_accuracies = {
            class_index: classification.BinaryAccuracy() for class_index in range(num_classes)
        }
        precision = classification.MulticlassPrecision(num_classes=num_classes, average='macro')
        weighted_precision = classification.MulticlassPrecision(num_classes=num_classes, average='weighted')
        recall = classification.MulticlassRecall(num_classes=num_classes, average='macro')
        weighted_recall = classification.MulticlassRecall(num_classes=num_classes, average='weighted')
        f1 = classification.MulticlassF1Score(num_classes=num_classes, average='macro')
        weighted_f1 = classification.MulticlassF1Score(num_classes=num_classes, average='weighted')
        cohen_kappa = classification.MulticlassCohenKappa(num_classes=num_classes)

        return {
            'accuracy': accuracy,
            'weighted_accuracy': weighted_accuracy,
            'per_class_accuracies': per_class_accuracies,
            'precision': precision,
            'weighted_precision': weighted_precision,
            'recall': recall,
            'weighted_recall': weighted_recall,
            'f1': f1,
            'weighted_f1': weighted_f1,
            'cohen_kappa': cohen_kappa
        }

    def setup(self, trainer, pl_module, stage=None):
        # Move all metric objects (including the per-class accuracies) to the model's device
        device = pl_module.device
        for phase_metrics in self.metrics.values():
            for tier_metrics in phase_metrics.values():
                for mode_metrics in tier_metrics.values():
                    for metric in self.metrics_to_compute:
                        mode_metrics[metric].to(device)
                    for class_accuracy in mode_metrics['per_class_accuracies'].values():
                        class_accuracy.to(device)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'train')

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'val')

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.__on_batch_end(trainer, pl_module, outputs, batch, batch_idx, 'test')

    def __on_batch_end(self, trainer: pl.Trainer, pl_module, outputs, batch, batch_idx, phase):
        if trainer.sanity_checking:
            return
        if self.debug:
            print(f"{phase} batch ended. Updating metrics...")

        targets = torch.stack(batch[1][0])
        outputs = outputs['outputs'][self.final_head_name]
        field_ids = batch[1][1].permute(1, 0, 2, 3)[0]

        pixelwise_outputs, majority_outputs = LogConfusionMatrix.get_pixelwise_and_majority_outputs(outputs, self.tiers, field_ids, self.dataset_info)

        for preds, mode in zip([pixelwise_outputs, majority_outputs], self.modes):

            assert preds.shape == targets.shape, f"Shapes of predictions and targets do not match: {preds.shape} vs {targets.shape}"
            assert preds.shape[0] == len(self.tiers), f"Number of tiers in predictions and tiers do not match: {preds.shape[0]} vs {len(self.tiers)}"

            # Keep the last tier's predictions of the current batch for image logging at epoch end
            self.images_to_log[phase][mode] = preds[-1]

            for pred, target, tier in zip(preds, targets, self.tiers):

                # For majority predictions, ignore background pixels (target == 0)
                if mode == 'majority':
                    pred = pred[target != 0]
                    target = target[target != 0]
                metrics = self.metrics[phase][tier][mode]
                for metric in self.metrics_to_compute:
                    metrics[metric].update(pred, target)
                    if self.debug:
                        print(f"{phase} {tier} {mode} {metric} updated. Update count: {metrics[metric]._update_count}")
                self.__update_per_class_metrics(pred, target, metrics['per_class_accuracies'])

        self.images_to_log_targets[phase] = targets[-1]
        self.field_ids_to_log_targets[phase] = field_ids
        self.inputs_to_log[phase] = batch[0]

    def __update_per_class_metrics(self, preds, targets, per_class_accuracies):
        for class_index, class_accuracy in per_class_accuracies.items():
            # Skip classes that do not occur in the targets of this batch
            if not (targets == class_index).any():
                continue

            if class_index == 0:
                # Background class: evaluate only on background pixels
                class_mask = targets != 0
            else:
                # Field classes: evaluate only on non-background (field) pixels
                class_mask = targets == 0

            preds_fields = preds[~class_mask]
            targets_fields = targets[~class_mask]

            # Binary accuracy for this class: 1 where the pixel belongs to (or is predicted as) the class, else 0
            preds_class = (preds_fields == class_index).float()
            targets_class = (targets_fields == class_index).float()

            class_accuracy.update(preds_class, targets_class)

            if self.debug:
                print(f"Shape of preds_fields: {preds_fields.shape}")
                print(f"Shape of targets_fields: {targets_fields.shape}")
                print(f"Unique values in preds_fields: {torch.unique(preds_fields)}")
                print(f"Unique values in targets_fields: {torch.unique(targets_fields)}")
                print(f"Per-class metrics for class {class_index} updated. Update count: {per_class_accuracies[class_index]._update_count}")

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.__on_epoch_end(trainer, pl_module, 'train')

    def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.__on_epoch_end(trainer, pl_module, 'val')

    def on_test_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
        self.__on_epoch_end(trainer, pl_module, 'test')

    def __on_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule, phase):
        if trainer.sanity_checking:
            return
        for tier in self.tiers:
            for mode in self.modes:
                metrics = self.metrics[phase][tier][mode]

                # Compute, log, and reset the aggregate metrics
                metrics_dict = {metric: metrics[metric].compute() for metric in self.metrics_to_compute}
                pl_module.log_dict({f"{phase}_{metric}_{tier}_{mode}": v for metric, v in metrics_dict.items()}, on_step=False, on_epoch=True)
                for metric in self.metrics_to_compute:
                    metrics[metric].reset()

                # Build a per-class table (precision, recall, F1, accuracy, support) from the binary accuracies
                class_metrics = []
                class_names_mapping = self.dataset_info[tier.split('_')[0] if '_refined' in tier else tier]
                for class_index, class_accuracy in metrics['per_class_accuracies'].items():
                    if class_accuracy._update_count == 0:
                        continue
                    tp, tn, fp, fn = class_accuracy.tp, class_accuracy.tn, class_accuracy.fp, class_accuracy.fn
                    recall = (tp / (tp + fn)).item() if tp + fn > 0 else 0
                    precision = (tp / (tp + fp)).item() if tp + fp > 0 else 0
                    f1 = (2 * (precision * recall) / (precision + recall)) if precision + recall > 0 else 0
                    n_of_class = (tp + fn).item()
                    class_metrics.append([class_index, class_names_mapping[class_index], precision, recall, f1, class_accuracy.compute().item(), n_of_class])
                    class_accuracy.reset()
                wandb_table = wandb.Table(data=class_metrics, columns=["Class Index", "Class Name", "Precision", "Recall", "F1", "Accuracy", "N"])
                trainer.logger.experiment.log({f"{phase}_per_class_metrics_{tier}_{mode}": wandb_table})

        # Log example predictions as stacked images (pixelwise, majority, right/wrong, target, fields)
        n_classes = max([
            torch.max(self.images_to_log_targets[phase]),
            torch.max(self.images_to_log[phase]["majority"]),
            torch.max(self.images_to_log[phase]["pixelwise"])
        ])
        images = [LogMessisMetrics.process_images(self.images_to_log[phase][mode], n_classes) for mode in self.modes]
        images.append(LogMessisMetrics.create_positive_negative_image(self.images_to_log[phase]["majority"], self.images_to_log_targets[phase]))
        images.append(LogMessisMetrics.process_images(self.images_to_log_targets[phase], n_classes))
        images.append(LogMessisMetrics.process_images(self.field_ids_to_log_targets[phase].cpu()))

        examples = []
        for i in range(len(images[0])):
            example = np.concatenate([img[i] for img in images], axis=0)
            examples.append(wandb.Image(example, caption=f"From Top to Bottom: {self.modes[0]}, {self.modes[1]}, right/wrong classifications, target, fields"))

        trainer.logger.experiment.log({f"{phase}_examples": examples})

        # Log wandb segmentation masks overlaid on an RGB rendering of the middle timestep
        batch_input_data = self.inputs_to_log[phase].cpu()
        ground_truth_masks = self.images_to_log_targets[phase].cpu().numpy()
        pixel_wise_masks = self.images_to_log[phase]["pixelwise"].cpu().numpy()
        field_majority_masks = self.images_to_log[phase]["majority"].cpu().numpy()
        correctness_masks = self.create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks)
        class_labels = {idx: name for idx, name in enumerate(self.dataset_info[self.last_tier_name])}

        segmentation_masks = []
        for input_data, ground_truth_mask, pixel_wise_mask, field_majority_mask, correctness_mask in zip(batch_input_data, ground_truth_masks, pixel_wise_masks, field_majority_masks, correctness_masks):
            # Build a gamma-corrected RGB image from the first three bands of the middle timestep
            middle_timestep_index = input_data.shape[1] // 2
            gamma = 2.5
            rgb_image = input_data[:3, middle_timestep_index, :, :].permute(1, 2, 0).numpy()
            rgb_image = (rgb_image - rgb_image.min()) / (rgb_image.max() - rgb_image.min())
            rgb_image = np.power(rgb_image, 1.0 / gamma)
            rgb_image = (rgb_image * 255).astype(np.uint8)

            mask_img = wandb.Image(
                rgb_image,
                masks={
                    "predictions_pixel_wise": {"mask_data": pixel_wise_mask, "class_labels": class_labels},
                    "predictions_field_majority": {"mask_data": field_majority_mask, "class_labels": class_labels},
                    "ground_truth": {"mask_data": ground_truth_mask, "class_labels": class_labels},
                    "correctness": {"mask_data": correctness_mask, "class_labels": {0: "Background", 1: "Wrong", 2: "Right"}},
                },
            )
            segmentation_masks.append(mask_img)

        trainer.logger.experiment.log({f"{phase}_segmentation_mask": segmentation_masks})

        if self.debug:
            print(f"{phase} epoch ended. Logging & resetting metrics...", trainer.sanity_checking)

    @staticmethod
    def create_positive_negative_segmentation_mask(field_majority_masks, ground_truth_masks):
        """
        Create a mask that shows the positive and negative classifications of the model.

        Args:
            field_majority_masks (np.ndarray): The field majority masks generated by the model.
            ground_truth_masks (np.ndarray): The ground truth masks.

        Returns:
            np.ndarray: An array with values:
                - 0 where the target is 0,
                - 2 where the prediction matches the target,
                - 1 where the prediction does not match the target.
        """
        correctness_mask = np.zeros_like(ground_truth_masks, dtype=int)

        matches = (field_majority_masks == ground_truth_masks) & (ground_truth_masks != 0)
        correctness_mask[matches] = 2

        mismatches = (field_majority_masks != ground_truth_masks) & (ground_truth_masks != 0)
        correctness_mask[mismatches] = 1

        return correctness_mask

    @staticmethod
    def create_positive_negative_image(generated_images, target_images):
        """
        Create an image that shows the positive and negative classifications of the model.

        Args:
            generated_images (torch.Tensor): The images generated by the model.
            target_images (torch.Tensor): The target images.

        Returns:
            list: A list of processed images (green = correct, red = wrong, black = background).
        """
        classification_masks = generated_images == target_images
        processed_imgs = []
        for mask, target in zip(classification_masks, target_images):
            colored_img = torch.zeros((mask.shape[0], mask.shape[1], 3), dtype=torch.uint8)
            mask = mask.bool()
            colored_img[mask] = torch.tensor([0, 255, 0], dtype=torch.uint8)        # correct predictions in green
            colored_img[~mask] = torch.tensor([255, 0, 0], dtype=torch.uint8)       # wrong predictions in red
            colored_img[target == 0] = torch.tensor([0, 0, 0], dtype=torch.uint8)   # background in black
            processed_imgs.append(colored_img.cpu())
        return processed_imgs

    @staticmethod
    def process_images(imgs, max=None):
        """
        Process a batch of images to be logged on wandb.

        Args:
            imgs (torch.Tensor): A batch of images with shape (B, H, W) to be processed.
            max (float, optional): The maximum value used to normalize the images. Defaults to None.
                If None, the maximum value in the batch is used.
        """
        if max is None:
            max = np.max(imgs.cpu().numpy())
        normalized_img = imgs / max
        processed_imgs = []
        for img in normalized_img.cpu().numpy():
            if max < 60:
                # Up to 60 distinct values: use a qualitative colormap built from the tab20 palettes
                cmap = ListedColormap(plt.get_cmap('tab20').colors + plt.get_cmap('tab20b').colors + plt.get_cmap('tab20c').colors)
            else:
                # More values (e.g. field IDs): fall back to a continuous colormap
                cmap = plt.get_cmap('viridis')
            colored_img = cmap(img)
            colored_img[img == 0] = [0, 0, 0, 1]  # keep the background black
            colored_img_uint8 = (colored_img[:, :, :3] * 255).astype(np.uint8)
            processed_imgs.append(colored_img_uint8)
        return processed_imgs