Spaces:
Runtime error
Runtime error
File size: 1,182 Bytes
4d0eb62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
from mmengine.model import BaseTTAModel
from mmpretrain.registry import MODELS
from mmpretrain.structures import DataSample
@MODELS.register_module()
class AverageClsScoreTTA(BaseTTAModel):
    """TTA model that merges augmented predictions by averaging scores.

    For every group of predictions (one group per input, one prediction per
    enhanced/augmented view), the ``pred_score`` tensors are averaged.
    """

    def merge_preds(
        self,
        data_samples_list: List[List[DataSample]],
    ) -> List[DataSample]:
        """Merge predictions of enhanced data to one prediction.

        Args:
            data_samples_list (List[List[DataSample]]): List of predictions
                of all enhanced data. Each inner list is merged into a
                single prediction.

        Returns:
            List[DataSample]: Merged prediction, one per inner list of
            ``data_samples_list``, in the same order.
        """
        return [
            self._merge_single_sample(data_samples)
            for data_samples in data_samples_list
        ]

    def _merge_single_sample(
            self, data_samples: List[DataSample]) -> DataSample:
        """Average the scores of several predictions into one.

        Args:
            data_samples (List[DataSample]): Predictions to merge. Must be
                non-empty; the first element is used as the template for
                the merged result.

        Returns:
            DataSample: A new sample built from ``data_samples[0]`` whose
            ``pred_score`` is the element-wise mean of all ``pred_score``
            tensors. If the template carried a ``pred_label``, it is
            recomputed from the mean score.
        """
        # ``new()`` without arguments copies the template's metainfo AND
        # data fields, so any ``pred_label`` from the first augmentation
        # is carried over and must be refreshed below.
        merged_data_sample: DataSample = data_samples[0].new()
        merged_score = sum(data_sample.pred_score
                           for data_sample in data_samples) / len(data_samples)
        merged_data_sample.set_pred_score(merged_score)
        # Keep the label consistent with the averaged score instead of
        # leaving the first augmentation's stale prediction in place.
        # NOTE(review): argmax assumes single-label classification scores.
        if 'pred_label' in merged_data_sample:
            merged_data_sample.set_pred_label(merged_score.argmax(dim=-1))
        return merged_data_sample
|