Spaces:
Runtime error
Runtime error
File size: 1,703 Bytes
2ae34e9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
from mmengine.model import BaseTTAModel
from mmengine.structures import PixelData
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList
@MODELS.register_module()
class SegTTAModel(BaseTTAModel):
    """TTA model that merges segmentation predictions of enhanced data."""

    def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
        """Merge predictions of enhanced data to one prediction.

        For each group of samples, the ``seg_logits`` of every sample are
        converted to probabilities (softmax over dim 0 when the wrapped
        module has more than one output channel, sigmoid otherwise),
        averaged over the group, and turned into a prediction map
        (thresholding for a single output channel, argmax over dim 0
        otherwise).

        Args:
            data_samples_list (List[SampleList]): List of predictions
                of all enhanced data.

        Returns:
            SampleList: Merged prediction. Each merged sample carries
            ``pred_sem_seg`` and, when the first sample of its group has
            one, ``gt_sem_seg``.
        """
        # The channel configuration does not change between samples, so
        # resolve it once instead of on every loop iteration.
        out_channels = self.module.out_channels
        use_softmax = out_channels > 1
        use_threshold = out_channels == 1

        predictions = []
        for data_samples in data_samples_list:
            first_sample = data_samples[0]
            # Accumulator with the same shape, dtype and device as the logits.
            logits = torch.zeros_like(first_sample.seg_logits.data)
            for data_sample in data_samples:
                seg_logit = data_sample.seg_logits.data
                if use_softmax:
                    logits += seg_logit.softmax(dim=0)
                else:
                    logits += seg_logit.sigmoid()
            logits /= len(data_samples)

            if use_threshold:
                # NOTE(review): logits appear to be channel-first (dim 0 is
                # the channel axis above), so ``squeeze(1)`` looks like a
                # no-op unless dim 1 has size 1 -- confirm intended axis.
                seg_pred = (logits > self.module.decode_head.threshold
                            ).to(logits).squeeze(1)
            else:
                seg_pred = logits.argmax(dim=0)

            fields = {'pred_sem_seg': PixelData(data=seg_pred)}
            # Ground truth is absent during plain inference; only forward it
            # when present instead of failing on the attribute lookup.
            if hasattr(first_sample, 'gt_sem_seg'):
                fields['gt_sem_seg'] = first_sample.gt_sem_seg
            predictions.append(SegDataSample(**fields))
        return predictions
|