|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Inputs, outputs and losses for depth prediction task.""" |
|
import big_vision.utils as u |
|
import einops |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
|
|
|
|
# Axis that holds the depth bins: the one-hot targets and the softmax
# cross-entropy are taken along it, and it also indexes the bin count out of
# `config.model.inputs.depth`.
ONE_HOT_AXIS = -2
|
|
|
|
|
def input_pp(batch, config):
    """Makes inputs for depth prediction task.

    Args:
      batch: Mapping that may hold "labels" (depth is read from index 0 of
        the last axis) and "image_ctx" / "image" (used as context).
      config: Config providing `model.patch_size`, `model.inputs.depth`,
        `min_depth` and `max_depth`.

    Returns:
      A dict with two entries:
        "ctx": `batch["image_ctx"]` if present, else `batch["image"]`, else
          None.
        "x": None when the batch has no "labels"; otherwise
          `{"depth": one_hot}`, where the discretized depth map is split
          into patches of shape (b, hn * wn, hp * wp) and the bin axis is
          inserted at ONE_HOT_AXIS.
    """
    ctx = batch.get("image_ctx", batch.get("image", None))
    if "labels" not in batch:
        # Label-free batches (e.g. at inference) only carry the context.
        return {"ctx": ctx, "x": None}

    hp, wp = config.model.patch_size
    # Number of depth bins, read from the configured input spec.
    # NOTE(review): presumably `config.model.inputs.depth` is the per-example
    # shape of the one-hot target -- confirm against the config.
    nbins = config.model.inputs.depth[ONE_HOT_AXIS]
    mind = config.min_depth
    maxd = config.max_depth

    # Discretize: map [mind, maxd) linearly onto [0, nbins) and floor.
    depth = batch["labels"][..., 0]
    depth = (depth - mind) / (maxd - mind)
    depth *= nbins
    bins = jnp.floor(depth).astype(jnp.int32)
    # Depths outside [mind, maxd) are clamped into the first / last bin.
    bins = jnp.maximum(jnp.minimum(bins, nbins - 1), 0)

    # Split the (b, H, W) bin map into flattened hp x wp patches.
    patches = einops.rearrange(
        bins, "b (hn hp) (wn wp) -> b (hn wn) (hp wp)", hp=hp, wp=wp)
    one_hot = jax.nn.one_hot(patches, num_classes=nbins, axis=ONE_HOT_AXIS)
    return {"ctx": ctx, "x": {"depth": one_hot}}
|
|
|
|
|
def loss_fn(predictions, batch, config):
    """Computes loss for depth prediction task.

    Args:
      predictions: Mapping with a "depth" entry holding the logits; the bin
        axis is ONE_HOT_AXIS.
      batch: Batch mapping; must contain "labels" so targets can be built.
      config: Config forwarded to `input_pp`.

    Returns:
      A `(total, losses)` tuple: `losses` is `{"loss_depth": masked_loss}`
      and `total` is the sum over the entries of `losses`.

    Raises:
      ValueError: If the batch has no "labels", so no targets exist.
    """
    labels = input_pp(batch, config)["x"]
    if labels is None:
        # `input_pp` yields x=None for label-free batches; fail clearly
        # instead of with "'NoneType' object is not subscriptable".
        raise ValueError(
            "loss_fn requires batch['labels'] to build depth targets.")

    targets = labels["depth"]
    # NOTE(review): with reduction=False this presumably returns the
    # unreduced cross-entropy taken over ONE_HOT_AXIS -- confirm in
    # big_vision.utils.
    xent = u.softmax_xent(
        logits=predictions["depth"], labels=targets, reduction=False,
        axis=ONE_HOT_AXIS)

    # Zero the loss wherever the target is bin 0 (the closest bin, which
    # also absorbs depths clamped from below).
    # NOTE(review): presumably bin 0 marks pixels without a valid depth
    # signal -- confirm.
    mask = jnp.argmax(targets, ONE_HOT_AXIS) != 0

    losses = {"loss_depth": xent * mask}
    return sum(losses.values()), losses
|
|
|
|
|
def predict_outputs(predictions, config):
    """Makes outputs for depth prediction tasks.

    Args:
      predictions: Mapping with a "depth" entry laid out as
        (b, hn * wn, c, hp * wp), where c indexes the depth bins.
      config: Config providing `model.patch_size`, `model.input_size`,
        `model.inputs.depth`, `min_depth` and `max_depth`.

    Returns:
      `{"depth": depth}` where `depth` is a float32 map of shape
      (b, hn * hp, wn * wp) with values between `min_depth` and `max_depth`.
    """
    # Reassemble the patches into an image grid, bin axis last.
    hp, wp = config.model.patch_size
    hn, wn = np.array(config.model.input_size) // np.array((hp, wp))
    logits = einops.rearrange(
        predictions["depth"],
        "b (hn wn) c (hp wp) -> b (hn hp) (wn wp) c",
        hn=hn, wn=wn, hp=hp, wp=wp)

    # Pick the highest-scoring bin for every pixel.
    bins = jnp.argmax(logits, axis=-1)

    # Map bin indices back to depth values.
    nbins = config.model.inputs.depth[ONE_HOT_AXIS]
    mind = config.min_depth
    maxd = config.max_depth
    # +0.5 selects the centre of each bin rather than its lower edge.
    depth = bins.astype(jnp.float32) + 0.5
    depth /= nbins
    depth = depth * (maxd - mind) + mind

    return {"depth": depth}
|
|