File size: 2,165 Bytes
9665c2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213aff4
 
 
9665c2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
# Copyright (c) Meta Platforms, Inc. and affiliates.

import numpy as np
import torch


def chunk_sequence(
    data,
    indices,
    *,
    names=None,
    max_length=100,
    min_length=1,
    max_delay_s=None,
    max_inter_dist=None,
    max_total_dist=None,
):
    """Sort frames and greedily group them into consecutive chunks.

    The frames selected by ``indices`` are sorted, then walked in order. A new
    chunk is started whenever appending the next frame would violate one of
    the enabled limits (``max_inter_dist``, ``max_total_dist``,
    ``max_delay_s``, ``max_length``).

    Sort key, by order of preference: ``data["capture_time"]``,
    ``data["index"]``, ``names``, and finally the index value itself.

    Args:
        data: mapping whose per-frame entries are indexed by the values in
            ``indices``. The first two components of ``data["t_c2w"][i]`` are
            used as the frame's (presumably planar) position. If present,
            ``data["capture_time"]`` is interpreted as milliseconds and is
            used both for sorting and for ``max_delay_s``.
        indices: iterable of frame indices into ``data``.
        names: optional per-frame sort keys, indexed like ``data``; used only
            when ``data`` has neither "capture_time" nor "index".
        max_length: maximum number of frames per chunk.
        min_length: chunks with fewer frames than this are dropped.
        max_delay_s: maximum time gap in seconds between consecutive frames
            of a chunk. It is ignored if None or if ``data`` has no
            "capture_time".
        max_inter_dist: maximum distance between consecutive frame positions.
        max_total_dist: maximum accumulated position distance within a chunk.

    Returns:
        List of chunks, longest first. Each chunk is a list of indices in
        sorted order. The result is empty if ``indices`` is empty.
    """

    def _sortable(value):
        # Tensors and NumPy scalars/arrays expose ``tolist`` to yield plain,
        # comparable Python values; other values are compared as-is.
        return value.tolist() if hasattr(value, "tolist") else value

    indices = list(indices)
    if not indices:
        # Nothing to chunk; also avoids torch.stack([]) / indices[0] errors.
        return []

    sort_array = data.get("capture_time", data.get("index"))
    if sort_array is None:
        sort_array = names
    if sort_array is None:
        # No external ordering available: order by the index values. The
        # previous fallback looked up indices[i], which misused index values
        # as positions into the index list.
        indices = sorted(indices, key=_sortable)
    else:
        indices = sorted(indices, key=lambda i: _sortable(sort_array[i]))

    centers = torch.stack([data["t_c2w"][i][:2] for i in indices]).numpy()
    # dists[k] is the distance between sorted frames k and k + 1.
    dists = np.linalg.norm(np.diff(centers, axis=0), axis=-1)
    if "capture_time" in data:
        times = torch.stack([data["capture_time"][i] for i in indices])
        times = times.double() / 1e3  # ms to s
        delays = np.diff(times, axis=0)
    else:
        # Without timestamps, the delay criterion never triggers.
        delays = np.zeros_like(dists)

    chunks = [[indices[0]]]
    dist_total = 0
    for dist, delay, idx in zip(dists, delays, indices[1:]):
        dist_total += dist
        if (
            (max_inter_dist is not None and dist > max_inter_dist)
            or (max_total_dist is not None and dist_total > max_total_dist)
            or (max_delay_s is not None and delay > max_delay_s)
            or len(chunks[-1]) >= max_length
        ):
            chunks.append([])
            # The new chunk starts at idx, so its accumulated length is zero.
            dist_total = 0
        chunks[-1].append(idx)
    chunks = [chunk for chunk in chunks if len(chunk) >= min_length]
    # sorted() is stable, so equal-length chunks keep their temporal order.
    return sorted(chunks, key=len, reverse=True)


def unpack_batches(batches):
    """Regroup a list of single-sample dicts into per-field outputs.

    Args:
        batches: non-empty sequence of dict-like samples. Every sample is
            read for "image", "canvas", "map", "roll_pitch_yaw" and "uv".
            "uv_gps" is read from every sample only if the first sample
            has it.

    Returns:
        ``[images, canvas, rasters, yaws, uv_gt, xy_gt]``, where:
        - ``images`` holds each image after ``permute(1, 2, 0)`` (presumably
          CHW -> HWC; confirm against the dataset).
        - ``canvas`` and ``rasters`` are the raw "canvas" / "map" values.
        - ``yaws`` stacks the last element of each "roll_pitch_yaw".
        - ``uv_gt`` stacks the "uv" values.
        - ``xy_gt`` maps each row of ``uv_gt`` through its sample's
          ``canvas.to_xy``. The conversion is done in float64 on CPU, then
          cast to ``uv_gt``'s dtype and device.

        When the first sample has "uv_gps", every sample's "uv_gps" goes
        through the same conversion and the result is appended as a
        seventh element.
    """

    def column(key):
        # Gather one field across all samples, preserving sample order.
        return [sample[key] for sample in batches]

    images = [img.permute(1, 2, 0) for img in column("image")]
    canvas = column("canvas")
    rasters = column("map")
    yaws = torch.stack([rpy[-1] for rpy in column("roll_pitch_yaw")])
    uv_gt = torch.stack(column("uv"))

    def to_metric(uv_values):
        # Pair each uv with the canvas of the same sample and convert.
        converted = [
            cnv.to_xy(uv.cpu().double()) for uv, cnv in zip(uv_values, canvas)
        ]
        return torch.stack(converted).to(uv_gt)

    outputs = [images, canvas, rasters, yaws, uv_gt, to_metric(uv_gt)]
    if "uv_gps" in batches[0]:
        outputs.append(to_metric(column("uv_gps")))
    return outputs