Spaces:
Running
Running
File size: 2,683 Bytes
4187c6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import torch
from .metrics import PixelAccuracy, MeanObservableIOU, MeanUnobservableIOU, ObservableIOU, UnobservableIOU, mAP
from .loss import EnhancedLoss
from .segmentation_head import SegmentationHead
from . import get_model
from .base import BaseModel
from .bev_projection import CartesianProjection, PolarProjectionDepth
from .schema import ModelConfiguration
class MapPerceptionNet(BaseModel):
    """Image-to-BEV semantic map model.

    Encodes the input with a configurable image encoder, scores every encoder
    feature over ``num_scale_bins`` scale bins, passes the features through a
    ``PolarProjectionDepth`` and then a ``CartesianProjection`` to obtain
    bird's-eye-view (BEV) features, and decodes those with a
    ``SegmentationHead`` into ``num_classes`` output maps.
    """

    def _init(self, conf: ModelConfiguration):
        """Build all sub-modules from ``conf``.

        Args:
            conf: Model configuration. Reads ``image_encoder.name``,
                ``image_encoder.backbone``, ``latent_dim``, ``num_classes``,
                ``pixel_per_meter``, ``z_max``, ``z_min``, ``x_max``,
                ``scale_range``, ``num_scale_bins`` and ``loss``.
        """
        # Look up the encoder class by name, then instantiate it with the
        # backbone sub-config.
        self.image_encoder = get_model(conf.image_encoder.name)(
            conf.image_encoder.backbone
        )
        self.decoder = SegmentationHead(
            in_channels=conf.latent_dim, n_classes=conf.num_classes
        )
        ppm = conf.pixel_per_meter
        self.projection_polar = PolarProjectionDepth(
            conf.z_max,
            ppm,
            conf.scale_range,
            conf.z_min,
        )
        self.projection_bev = CartesianProjection(
            conf.z_max, conf.x_max, ppm, conf.z_min
        )
        # Maps each latent_dim-sized encoder feature vector to
        # num_scale_bins scores.
        self.scale_classifier = torch.nn.Linear(
            conf.latent_dim, conf.num_scale_bins
        )
        self.num_classes = conf.num_classes
        self.loss_fn = EnhancedLoss(conf.loss)

    def _forward(self, data):
        """Run the encoder -> polar -> BEV -> segmentation pipeline.

        Args:
            data: Batch passed unchanged to the image encoder.

        Returns:
            dict with keys:
                ``"output"``: element-wise sigmoid of the decoder output
                    (interpreted as per-class probabilities).
                ``"logits"``: raw decoder output.
                ``"scales"``: scale-classifier scores; the encoder's axis 1
                    is moved last and mapped to ``num_scale_bins`` values.
                ``"features_image"``: encoder features.
                ``"features_bev"``: full output features of the Cartesian
                    projection (before the last-axis slice fed to the
                    decoder).
                ``"valid_bev"``: validity output of the Cartesian projection
                    with axis 1 squeezed.
        """
        # The encoder returns its features together with a camera object
        # that both projections consume.
        f_image, camera = self.image_encoder(data)
        # nn.Linear acts on the last axis, so axis 1 (presumably the channel
        # axis of a (B, C, H, W) tensor -- TODO confirm) is moved to the end;
        # its size must equal conf.latent_dim.
        scales = self.scale_classifier(f_image.moveaxis(1, -1))
        f_polar = self.projection_polar(f_image, scales, camera)
        # Map to the BEV. Inputs are cast with .float(), presumably so the
        # projection runs in float32 under mixed precision -- confirm.
        f_bev, valid_bev, _ = self.projection_bev(
            f_polar.float(), None, camera.float()
        )
        # NOTE(review): the last entry along the final axis is dropped before
        # decoding; the reason is not visible here -- confirm against
        # CartesianProjection's output layout.
        logits = self.decoder(f_bev[..., :-1])
        # Canonical torch.sigmoid (same op as the torch.nn.functional alias).
        probs = torch.sigmoid(logits)
        return {
            "output": probs,
            "logits": logits,
            "scales": scales,
            "features_image": f_image,
            "features_bev": f_bev,
            "valid_bev": valid_bev.squeeze(1),
        }

    def loss(self, pred, data):
        """Compute the training loss with the configured ``EnhancedLoss``.

        Args:
            pred: Model predictions (presumably the dict returned by
                ``_forward``).
            data: The input batch.

        Returns:
            Whatever ``EnhancedLoss`` returns for ``(pred, data)``.
        """
        return self.loss_fn(pred, data)

    def metrics(self):
        """Build a fresh dict of metric objects keyed by name.

        Holds four aggregate metrics (pixel accuracy, mAP, and the mean
        observable / unobservable IoU metrics), followed by one
        ``ObservableIOU`` and one ``UnobservableIOU`` per class index in
        ``range(self.num_classes)``.
        """
        m = {
            "pix_acc": PixelAccuracy(),
            "map": mAP(self.num_classes),
            "miou_observable": MeanObservableIOU(self.num_classes),
            "miou_non_observable": MeanUnobservableIOU(self.num_classes),
        }
        # NOTE(review): the per-class key prefixes are not parallel
        # ("IoU_observable_class_" vs "IoU_non_observable_"); kept unchanged
        # because loggers or saved results may depend on these names.
        m.update(
            {
                f"IoU_observable_class_{i}": ObservableIOU(i, num_classes=self.num_classes)
                for i in range(self.num_classes)
            }
        )
        m.update(
            {
                f"IoU_non_observable_{i}": UnobservableIOU(i, num_classes=self.num_classes)
                for i in range(self.num_classes)
            }
        )
        return m
|