|
from typing import Dict, List |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from skimage import segmentation |
|
|
|
|
|
def get_multi_view_consistency_loss(opt): |
|
loss = MultiViewConsistencyLoss( |
|
opt.mvc_soft, |
|
opt.mvc_zeros_on_au, |
|
opt.mvc_single_weight, |
|
opt.modality, |
|
opt.mvc_spixel, |
|
opt.mvc_num_spixel, |
|
) |
|
return loss |
|
|
|
|
|
class MultiViewConsistencyLoss(nn.Module): |
|
def __init__( |
|
self, |
|
soft: bool, |
|
zeros_on_au: bool, |
|
single_weight: Dict, |
|
modality: List, |
|
spixel: bool = False, |
|
num_spixel: int = 100, |
|
eps: float = 1e-4, |
|
): |
|
super().__init__() |
|
self.soft = soft |
|
self.zeros_on_au = zeros_on_au |
|
self.single_weight = single_weight |
|
self.modality = modality |
|
self.spixel = spixel |
|
self.num_spixel = num_spixel |
|
self.eps = eps |
|
|
|
self.mse_loss = nn.MSELoss(reduction="mean") |
|
|
|
def forward(self, output: Dict, label, spixel=None, image=None, mask=None): |
|
|
|
tgt_map = torch.zeros_like( |
|
output[self.modality[0]]["out_map"], requires_grad=False |
|
) |
|
with torch.no_grad(): |
|
for modality in self.modality: |
|
weight = self.single_weight[modality.lower()] |
|
tgt_map = tgt_map + weight * output[modality]["out_map"] |
|
|
|
if self.spixel: |
|
|
|
tgt_map = get_spixel_tgt_map(tgt_map, spixel) |
|
|
|
if not self.soft: |
|
for b in range(tgt_map.shape[0]): |
|
if tgt_map[b, ...].max() <= 0.5 and label[b] == 1.0: |
|
tgt_map[b, ...][ |
|
torch.where(tgt_map[b, ...] == torch.max(tgt_map[b, ...])) |
|
] = 1.0 |
|
tgt_map[torch.where(tgt_map > 0.5)] = 1 |
|
tgt_map[torch.where(tgt_map <= 0.5)] = 0 |
|
tgt_map[torch.where(label == 0.0)[0], ...] = 0.0 |
|
|
|
if self.zeros_on_au: |
|
tgt_map[torch.where(label == 0.0)[0], ...] = 0.0 |
|
|
|
total_loss = 0.0 |
|
loss_dict = {} |
|
for modality in self.modality: |
|
loss = self.mse_loss(output[modality]["out_map"], tgt_map) |
|
loss_dict[f"multi_view_consistency_loss_{modality}"] = loss |
|
total_loss = total_loss + loss |
|
|
|
return {**loss_dict, "tgt_map": tgt_map, "total_loss": total_loss} |
|
|
|
def _save( |
|
self, |
|
spixel: torch.Tensor, |
|
image: torch.Tensor, |
|
mask: torch.Tensor, |
|
tgt_map: torch.Tensor, |
|
raw_tgt_map: torch.Tensor, |
|
out_path: str = "tmp/spixel_tgt_map.png", |
|
): |
|
spixel = spixel.permute(0, 2, 3, 1).detach().cpu().numpy() |
|
image = image.permute(0, 2, 3, 1).detach().cpu().numpy() |
|
mask = mask.permute(0, 2, 3, 1).detach().cpu().numpy() * 255.0 |
|
tgt_map = tgt_map.permute(0, 2, 3, 1).squeeze(3).detach().cpu().numpy() * 255.0 |
|
raw_tgt_map = ( |
|
raw_tgt_map.permute(0, 2, 3, 1).squeeze(3).detach().cpu().numpy() * 255.0 |
|
) |
|
bn = spixel.shape[0] |
|
i = 1 |
|
for b in range(bn): |
|
plt.subplot(bn, 5, i) |
|
i += 1 |
|
plt.imshow(image[b]) |
|
plt.axis("off") |
|
plt.title("image") |
|
plt.subplot(bn, 5, i) |
|
i += 1 |
|
plt.imshow(mask[b]) |
|
plt.axis("off") |
|
plt.title("mask") |
|
plt.subplot(bn, 5, i) |
|
i += 1 |
|
plt.imshow(spixel[b]) |
|
plt.axis("off") |
|
plt.title("superpixel") |
|
plt.subplot(bn, 5, i) |
|
i += 1 |
|
plt.imshow(raw_tgt_map[b]) |
|
plt.axis("off") |
|
plt.title("raw target map") |
|
plt.subplot(bn, 5, i) |
|
i += 1 |
|
plt.imshow(tgt_map[b]) |
|
plt.axis("off") |
|
plt.title("target map") |
|
plt.tight_layout() |
|
plt.savefig(out_path, dpi=300) |
|
plt.close() |
|
|
|
|
|
def get_spixel_tgt_map(weighted_sum, spixel): |
|
b, _, h, w = weighted_sum.shape |
|
spixel_tgt_map = torch.zeros_like(weighted_sum, requires_grad=False) |
|
|
|
for bidx in range(b): |
|
spixel_indices = spixel[bidx, ...].unique() |
|
|
|
for spixel_idx in spixel_indices.tolist(): |
|
area = (spixel[bidx, ...] == spixel_idx).sum() |
|
weighted_sum_in_area = weighted_sum[bidx, ...][ |
|
torch.where(spixel[bidx, ...] == spixel_idx) |
|
].sum() |
|
avg_area = weighted_sum_in_area / area |
|
|
|
spixel_tgt_map[bidx][ |
|
torch.where(spixel[bidx, ...] == spixel_idx) |
|
] = avg_area |
|
|
|
return spixel_tgt_map |
|
|
|
|
|
if __name__ == "__main__": |
|
mvc_loss = MultiViewConsistencyLoss(True, True, [1, 1, 2]) |
|
print("a") |
|
|