File size: 5,666 Bytes
74e8f2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for retrieval."""

from unittest import mock

from big_vision.evaluators.proj.image_text import retrieval
from big_vision.pp import ops_general  # pylint: disable=unused-import
from big_vision.pp import ops_image  # pylint: disable=unused-import
from big_vision.pp import registry
import chex
import flax.linen as nn
import jax
import jax.numpy as jnp
import tensorflow as tf
import tensorflow_datasets as tfds


def _get_test_texts2labels():

  def pp(features):
    features["labels"] = tf.strings.to_number(features["texts"])
    return features

  return pp


def _get_copy_from(**key_map):

  def copy_from(d):
    d = dict(d)
    for k1, k2 in key_map.items():
      d[k1] = d[k2]
    return d

  return copy_from


class _Model(nn.Module):

  @nn.compact
  def __call__(self, image, texts):
    self.param("x", lambda _: 0.)

    def z(x):
      if x is not None:
        batch_size = len(x)
        # Note that the returned vector is most similar with other vectors
        # generated from the same underlying `x[:]`.
        x = jnp.concatenate([100 * jnp.ones([batch_size, 1]), x[:, None]],
                            axis=1)
        return x / jnp.linalg.norm(x, axis=1)[:, None]

    return z(image), z(texts), None


def setUpModule():
  chex.set_n_cpu_devices(8)


class RetrievalTest(tf.test.TestCase):

  def test_prepare_datasets(self):

    def generator():
      yield {
          "image": tf.ones([5, 5, 3], tf.float32),
          "captions": {
              "text": tf.constant(["11", "12"])
          }
      }
      yield {
          "image": tf.ones([4, 4, 3], tf.float32),
          "captions": {
              "text": tf.constant(["21", "22", "23"])
          }
      }

    ds = tf.data.Dataset.from_generator(
        generator,
        output_signature={
            "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
            "captions": {
                "text": tf.TensorSpec(shape=[None], dtype=tf.string),
            },
        })
    with registry.temporary_ops(test_texts2labels=_get_test_texts2labels):
      ds_img, ds_txt = retrieval.prepare_datasets(
          ds,
          pp_img="resize(2)",
          pp_txt="test_texts2labels()",
          txt_name=("captions", "text"),
      )
    it_img = iter(ds_img)
    it_txt = iter(ds_txt)
    batch = next(it_img)
    self.assertAllEqual(batch["id"], 0)
    self.assertAllEqual(batch["image"], tf.ones([2, 2, 3]))
    batch = next(it_img)
    self.assertAllEqual(batch["id"], 1)
    self.assertAllEqual(batch["image"], tf.ones([2, 2, 3]))
    batch = next(it_txt)
    self.assertAllEqual(batch["id"], 0)
    self.assertAllEqual(batch["caption_i"], 0)
    self.assertAllEqual(batch["labels"], 11.0)
    batch = next(it_txt)
    self.assertAllEqual(batch["id"], 0)
    self.assertAllEqual(batch["caption_i"], 1)
    self.assertAllEqual(batch["labels"], 12.0)
    batch = next(it_txt)
    self.assertAllEqual(batch["id"], 1)
    self.assertAllEqual(batch["caption_i"], 0)
    self.assertAllEqual(batch["labels"], 21.0)
    batch = next(it_txt)
    self.assertAllEqual(batch["id"], 1)
    self.assertAllEqual(batch["caption_i"], 1)
    self.assertAllEqual(batch["labels"], 22.0)
    batch = next(it_txt)
    self.assertAllEqual(batch["id"], 1)
    self.assertAllEqual(batch["caption_i"], 2)
    self.assertAllEqual(batch["labels"], 23.0)

  def test_evaluate(self):
    per_device_batch_size = 2
    batch_size = per_device_batch_size * jax.device_count()
    num_examples = 1 * batch_size + 1
    splits = {
        "test":
            tfds.core.SplitInfo(
                name="test", shard_lengths=[num_examples], num_bytes=0)
    }

    model = _Model()
    params = model.init(jax.random.PRNGKey(0), None, None)["params"]

    with tfds.testing.mock_data(num_examples=num_examples):
      info_mock = mock.Mock()
      info_mock.splits = splits
      with mock.patch.object(retrieval, "_get_dataset_info",
                             lambda _: info_mock):
        with registry.temporary_ops(copy_from=_get_copy_from):
          evaluator = retrieval.Evaluator(
              lambda p, b: model.apply({"params": p},
                                       b.get("image", None),
                                       b.get("labels", None)),
              dataset="coco_captions",
              batch_size=batch_size,
              devices=jax.devices(),
              txt_name=("captions", "text"),
              pp_img="copy_from(image='id')",
              pp_txt="copy_from(labels='id')",
          )
      results = evaluator.evaluate(params)

    # Assert all examples were processed.
    self.assertLen(results["images"]["embeddings"], num_examples)
    self.assertLen(results["images"]["id"], num_examples)
    # Assert no padding was processed (expects exactly one (=first) image.id=0
    self.assertEqual((results["images"]["id"] == 0).sum(), 1)
    # Expect perfect ITR with above _Model()...
    self.assertEqual(results["img2txt"]["Recall@1"], 1.0)
    self.assertEqual(results["txt2img"]["Recall@5"], 1.0)


if __name__ == "__main__":
  tf.test.main()