File size: 2,683 Bytes
4187c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import torch

from .metrics import PixelAccuracy, MeanObservableIOU, MeanUnobservableIOU, ObservableIOU, UnobservableIOU, mAP

from .loss import EnhancedLoss

from .segmentation_head import SegmentationHead

from . import get_model
from .base import BaseModel
from .bev_projection import CartesianProjection, PolarProjectionDepth
from .schema import ModelConfiguration

class MapPerceptionNet(BaseModel):
    """Segmentation model that lifts image features into a bird's-eye-view grid.

    The forward pass chains: image encoder -> linear scale classifier ->
    polar projection -> Cartesian (BEV) projection -> segmentation head ->
    sigmoid. Sub-modules are created in ``_init`` from a
    ``ModelConfiguration``.
    """

    def _init(self, conf: ModelConfiguration):
        """Instantiate all sub-modules described by ``conf``.

        Args:
            conf: Model configuration. The fields read here are
                ``image_encoder.name``, ``image_encoder.backbone``,
                ``latent_dim``, ``num_classes``, ``pixel_per_meter``,
                ``z_max``, ``z_min``, ``x_max``, ``scale_range``,
                ``num_scale_bins`` and ``loss``.
        """
        # `get_model` resolves the encoder by name; the returned callable is
        # then invoked with the backbone configuration.
        encoder_factory = get_model(conf.image_encoder.name)
        self.image_encoder = encoder_factory(conf.image_encoder.backbone)

        self.decoder = SegmentationHead(
            in_channels=conf.latent_dim, n_classes=conf.num_classes)

        ppm = conf.pixel_per_meter
        self.projection_polar = PolarProjectionDepth(
            conf.z_max,
            ppm,
            conf.scale_range,
            conf.z_min,
        )
        self.projection_bev = CartesianProjection(
            conf.z_max, conf.x_max, ppm, conf.z_min
        )

        # Maps a `latent_dim` feature vector to `num_scale_bins` scores.
        # NOTE(review): the original comment here read "l4 - working"; its
        # meaning is not evident from this file -- confirm or remove.
        self.scale_classifier = torch.nn.Linear(
            conf.latent_dim, conf.num_scale_bins
        )

        self.num_classes = conf.num_classes

        self.loss_fn = EnhancedLoss(conf.loss)

    def _forward(self, data):
        """Run the full image -> BEV segmentation pipeline.

        Args:
            data: Input passed unchanged to the image encoder.

        Returns:
            A dict with the keys:
                ``output``: sigmoid of the decoder logits.
                ``logits``: raw decoder output.
                ``scales``: output of the scale classifier.
                ``features_image``: features returned by the image encoder.
                ``features_bev``: features after the Cartesian projection.
                ``valid_bev``: validity mask from the Cartesian projection
                    with dimension 1 squeezed out.
        """
        f_image, camera = self.image_encoder(data)

        # `nn.Linear` acts on the last dimension, so axis 1 is moved to the
        # end first (presumably the channel axis -- TODO confirm).
        scales = self.scale_classifier(f_image.moveaxis(1, -1))
        f_polar = self.projection_polar(f_image, scales, camera)

        # Map to the BEV.
        f_bev, valid_bev, _ = self.projection_bev(
            f_polar.float(), None, camera.float()
        )

        # The last index along the final axis is dropped before decoding.
        # NOTE(review): the reason is not visible in this file -- confirm
        # against the projection/decoder implementation.
        output = self.decoder(f_bev[..., :-1])

        # `torch.sigmoid` is equivalent to `torch.nn.functional.sigmoid` but
        # does not trigger the deprecation warning older PyTorch releases
        # emit for the functional alias.
        probs = torch.sigmoid(output)

        return {
            "output": probs,
            "logits": output,
            "scales": scales,
            "features_image": f_image,
            "features_bev": f_bev,
            "valid_bev": valid_bev.squeeze(1),
        }

    def loss(self, pred, data):
        """Return the value of the configured ``EnhancedLoss``.

        Args:
            pred: Prediction dict as produced by ``_forward``.
            data: The batch; forwarded to the loss function as-is.
        """
        return self.loss_fn(pred, data)

    def metrics(self):
        """Build the name -> metric mapping for this model.

        Returns:
            A dict holding four aggregate metrics (``pix_acc``, ``map``,
            ``miou_observable``, ``miou_non_observable``) followed by one
            ``ObservableIOU`` and one ``UnobservableIOU`` entry per class
            index in ``range(self.num_classes)``.
        """
        n = self.num_classes
        m = {
            "pix_acc": PixelAccuracy(),
            "map": mAP(n),
            "miou_observable": MeanObservableIOU(n),
            "miou_non_observable": MeanUnobservableIOU(n),
        }
        # NOTE(review): the two per-class key families are named differently
        # ("..._observable_class_{i}" vs "..._non_observable_{i}"). The keys
        # are kept unchanged because downstream logging may depend on them.
        m.update(
            {
                f"IoU_observable_class_{i}": ObservableIOU(i, num_classes=n)
                for i in range(n)
            }
        )
        m.update(
            {
                f"IoU_non_observable_{i}": UnobservableIOU(i, num_classes=n)
                for i in range(n)
            }
        )
        return m