File size: 1,865 Bytes
205a7af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""Perspective fields decoder heads.

Adapted from https://github.com/jinlinyi/PerspectiveFields
"""

import logging

from siclib.models import get_model
from siclib.models.base_model import BaseModel

logger = logging.getLogger(__name__)

# flake8: noqa
# mypy: ignore-errors


class PerspectiveDecoder(BaseModel):
    """Combine an optional "up" decoder head with an optional "latitude" decoder head.

    Each head is built from its config entry via ``get_model(<entry>.name)(<entry>)``.
    Setting an entry to ``None`` disables that head; a disabled head is never
    instantiated and has no attribute on the instance.
    """

    default_conf = {
        "up_decoder": {"name": "decoders.up_decoder"},
        "latitude_decoder": {"name": "decoders.latitude_decoder"},
    }

    required_data_keys = ["features"]

    def _init(self, conf):
        """Instantiate the configured sub-decoders and record which ones are active."""
        logger.debug(f"Initializing PerspectiveDecoder with config: {conf}")
        up_conf = conf.up_decoder
        latitude_conf = conf.latitude_decoder

        self.use_up = up_conf is not None
        self.use_latitude = latitude_conf is not None

        # Heads only exist as attributes when enabled; other methods gate on the flags.
        if self.use_up:
            up_cls = get_model(up_conf.name)
            self.up_head = up_cls(up_conf)
        if self.use_latitude:
            latitude_cls = get_model(latitude_conf.name)
            self.latitude_head = latitude_cls(latitude_conf)

    def _forward(self, data):
        """Run every enabled head on ``data`` and merge their output dicts.

        The up head runs first, so on duplicate keys the latitude head's value wins.
        With no head enabled, an empty dict is returned.
        """
        predictions = {}
        if self.use_up:
            predictions.update(self.up_head(data))
        if self.use_latitude:
            predictions.update(self.latitude_head(data))
        return predictions

    def loss(self, pred, data):
        """Compute and merge the losses and metrics of every enabled head.

        Args:
            pred: Predictions, forwarded unchanged to each head's ``loss``.
            data: Batch dict. ``data["up_field"]`` is read when the up head is
                enabled; otherwise ``data["latitude_field"]`` is read, even when
                neither head is enabled. That entry is used only as the template
                for the zero-initialized total.

        Returns:
            A ``(losses, metrics)`` pair of dicts merged from the enabled heads.
            The up head is merged first, so the latitude head wins on duplicate keys.
            ``losses["perspective_total"]`` starts as a zero 1-D tensor of length
            ``template.shape[0]`` (presumably the batch size; confirm with callers).
            Each enabled head's ``"up_total"`` / ``"latitude_total"`` entry is added
            to it; a missing entry contributes 0.
        """
        template = data["latitude_field"] if not self.use_up else data["up_field"]
        total = template.new_zeros(template.shape[0])

        losses, metrics = {}, {}
        # (enabled flag, head attribute name, key of that head's summed loss)
        heads = (
            (self.use_up, "up_head", "up_total"),
            (self.use_latitude, "latitude_head", "latitude_total"),
        )
        for enabled, head_attr, total_key in heads:
            if not enabled:
                continue
            head_losses, head_metrics = getattr(self, head_attr).loss(pred, data)
            losses.update(head_losses)
            metrics.update(head_metrics)
            total = total + losses.get(total_key, 0)

        losses["perspective_total"] = total
        return losses, metrics