Spaces:
Running
Running
File size: 2,986 Bytes
629144d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
import argparse
from pathlib import Path
from typing import Optional, Tuple
from omegaconf import OmegaConf, DictConfig
from .. import logger
from ..conf import data as conf_data_dir
from ..data import MapillaryDataModule
from .run import evaluate
# Per-split configuration overrides, looked up by split name in run().
# Only the "val" split is defined; it restricts evaluation to these scenes.
split_overrides = dict(
    val=dict(
        scenes=[
            "sanfrancisco_soma",
            "sanfrancisco_hayes",
            "amsterdam",
            "berlin",
            "lemans",
            "montrouge",
            "toulouse",
            "nantes",
            "vilnius",
            "avignon",
            "helsinki",
            "milan",
            "paris",
        ],
    ),
)
# Base Mapillary data config, read from "mapillary.yaml" next to the
# conf.data package.
data_cfg_train = OmegaConf.load(
    Path(conf_data_dir.__file__).parent.joinpath("mapillary.yaml")
)
# Evaluation data config: the base config with a few options overridden.
# The "val" loader uses batch_size=1 and num_workers=0.
data_cfg = OmegaConf.merge(
    data_cfg_train,
    dict(
        return_gps=True,
        add_map_mask=True,
        max_init_error=32,
        loading=dict(val=dict(batch_size=1, num_workers=0)),
    ),
)
# Default config for single-frame evaluation: only a "data" section.
default_cfg_single = OmegaConf.create(dict(data=data_cfg))
# Default config for sequential evaluation: everything in the single-frame
# config plus a "chunking" section.
default_cfg_sequential = OmegaConf.create(
    dict(default_cfg_single, chunking=dict(max_length=10))
)
def run(
    split: str,
    experiment: str,
    cfg: Optional[DictConfig] = None,
    sequential: bool = False,
    thresholds: Tuple[int, ...] = (1, 3, 5),
    **kwargs,
):
    """Evaluate an experiment on a Mapillary split and log recall values.

    The configuration is built in three layers, later layers winning:
    the default config (single-frame or sequential), the per-split
    overrides from ``split_overrides`` (nested under ``data``), and finally
    the caller-supplied ``cfg``.

    Args:
        split: Split name; must be a key of ``split_overrides``.
        experiment: Experiment identifier forwarded to ``evaluate``.
        cfg: Optional overrides, either a ``DictConfig`` or a plain ``dict``.
        sequential: Use the sequential default config and additionally
            report the sequence-level metric keys.
        thresholds: Thresholds passed to each metric's ``recall`` and
            logged in "m/°".
        **kwargs: Forwarded unchanged to ``evaluate``.

    Returns:
        The metrics mapping returned by ``evaluate``.

    Raises:
        KeyError: If ``split`` has no entry in ``split_overrides``.
    """
    cfg = cfg or {}
    if isinstance(cfg, dict):
        cfg = OmegaConf.create(cfg)
    default = default_cfg_sequential if sequential else default_cfg_single
    # The split overrides (e.g. the scene list) are data options. Nest them
    # under "data" because only cfg.data is handed to MapillaryDataModule
    # below. A top-level merge would not reach the data module.
    default = OmegaConf.merge(default, {"data": split_overrides[split]})
    cfg = OmegaConf.merge(default, cfg)
    dataset = MapillaryDataModule(cfg.get("data", {}))
    metrics = evaluate(
        experiment, cfg, dataset, split, sequential=sequential, **kwargs
    )

    keys = [
        "xy_max_error",
        "xy_gps_error",
        "yaw_max_error",
    ]
    if sequential:
        keys += [
            "xy_seq_error",
            "xy_gps_seq_error",
            "yaw_seq_error",
            "yaw_gps_seq_error",
        ]
    for k in keys:
        if k not in metrics:
            logger.warning("Key %s not in metrics.", k)
            continue
        rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
        logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
    return metrics
if __name__ == "__main__":
    # Command-line entry point. Positional "dotlist" arguments are parsed by
    # OmegaConf.from_cli as config overrides and passed to run() as cfg.
    cli = argparse.ArgumentParser()
    cli.add_argument("--experiment", type=str, required=True)
    cli.add_argument("--split", type=str, default="val", choices=["val"])
    cli.add_argument("--sequential", action="store_true")
    cli.add_argument("--output_dir", type=Path)
    cli.add_argument("--num", type=int)
    cli.add_argument("dotlist", nargs="*")
    opts = cli.parse_args()
    run(
        split=opts.split,
        experiment=opts.experiment,
        cfg=OmegaConf.from_cli(opts.dotlist),
        sequential=opts.sequential,
        output_dir=opts.output_dir,
        num=opts.num,
    )
|