File size: 2,030 Bytes
74e8f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utils for GIVT stage I and II trainers."""

from typing import Any

import jax
import jax.numpy as jnp


def unbin_depth(
    depth: jax.Array,
    *,
    min_depth: float,
    max_depth: float,
    num_bins: int,
) -> jax.Array:
  """Converts a binned depth map back into continuous depth values.

  The bin index is taken as the argmax over the last axis, mapped to the
  center of its bin in `[0, 1]`, and then rescaled linearly to
  `[min_depth, max_depth]`.

  Args:
    depth: Binned depth map; the bin of each element is encoded along the last
      dimension (e.g. one-hot), which is reduced away by the argmax.
    min_depth: Depth value corresponding to the lower edge of the first bin.
    max_depth: Depth value corresponding to the upper edge of the last bin.
    num_bins: Number of depth bins used for the binning.

  Returns:
    A float32 depth map with the last dimension of `depth` removed.
  """
  bin_index = jnp.argmax(depth, axis=-1).astype(jnp.float32)
  # Adding 0.5 selects the bin center, which undoes the floor applied during
  # binning in expectation.
  unit_depth = (bin_index + 0.5) / num_bins
  depth_range = max_depth - min_depth
  return unit_depth * depth_range + min_depth


def get_local_rng(
    seed: int | jax.Array,
    batch: Any,
) -> jax.Array:
  """Generate a per-image seed based on the image id or the image values.
  
  Args:
    seed: Random seed from which per-image seeds should be derived.
    batch: Pytree containing a batch of images (key "image") and optionally
      image ids (key "image/id").

  Returns:
    Array containing per-image ids.
  """
  fake_id = None
  if "image" in batch:
    fake_id = (10**6 * jax.vmap(jnp.mean)(batch["image"])).astype(jnp.int32)
  return jax.lax.scan(
      lambda k, x: (jax.random.fold_in(k, x), None),
      jax.random.PRNGKey(seed),
      batch.get("image/id", fake_id),
  )[0]