from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, Callable
import torch.distributed as dist
from torchmetrics.detection import MeanAveragePrecision
from torchmetrics.utilities.distributed import gather_all_tensors
from mmpl.registry import METRICS


@METRICS.register_module(force=True)
class PLMeanAveragePrecision(MeanAveragePrecision):
    """MeanAveragePrecision wrapper that also synchronises the list-based
    ``segm`` states (``detections`` / ``groundtruths``) across distributed ranks."""
    def __init__(
            self,
            *args,
            **kwargs,
    ) -> None:
        super().__init__(*args, **kwargs)

    def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
        super()._sync_dist(dist_sync_fn=dist_sync_fn, process_group=process_group)

        if self.iou_type == "segm":
            # For "segm", these states are plain Python lists of tuples rather than
            # tensors, so the default tensor sync above does not gather them.
            # Gather and merge them explicitly across ranks.
            self.detections = self._gather_tuple_list(self.detections, process_group)
            self.groundtruths = self._gather_tuple_list(self.groundtruths, process_group)

    @staticmethod
    def _gather_tuple_list(list_to_gather: List[Tuple], process_group: Optional[Any] = None) -> List[Any]:
        """All-gather a list of picklable tuples from every rank and merge them."""
        world_size = dist.get_world_size(group=process_group)
        list_gathered = [None] * world_size
        dist.all_gather_object(list_gathered, list_to_gather, group=process_group)

        # Every rank must contribute the same number of elements as rank 0.
        for rank in range(1, world_size):
            assert len(list_gathered[rank]) == len(list_gathered[0]), (
                f"Rank {rank} doesn't have the same number of elements as Rank 0: "
                f"{len(list_gathered[rank])} vs. {len(list_gathered[0])}"
            )

        # Merge the per-rank lists, interleaving by element index
        # (rank 0's idx-th element, rank 1's idx-th element, ...).
        list_merged = []
        for idx in range(len(list_gathered[0])):
            for rank in range(world_size):
                list_merged.append(list_gathered[rank][idx])

        return list_merged
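

# Minimal usage sketch with placeholder masks/labels; the preds/target format
# follows torchmetrics' MeanAveragePrecision for iou_type="segm" (a COCO-style
# backend such as pycocotools is assumed to be installed). Under DDP, e.g. when
# driven by PyTorch Lightning, compute() triggers _sync_dist, which also
# gathers the list-based mask states across ranks.
if __name__ == "__main__":
    import torch

    metric = PLMeanAveragePrecision(iou_type="segm")

    # A single 32x32 boolean mask with one square instance (placeholder data).
    mask = torch.zeros(1, 32, 32, dtype=torch.bool)
    mask[:, 8:16, 8:16] = True

    preds = [dict(masks=mask, scores=torch.tensor([0.9]), labels=torch.tensor([0]))]
    target = [dict(masks=mask, labels=torch.tensor([0]))]

    metric.update(preds, target)
    print(metric.compute())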