|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Tests for retrieval.""" |
|
|
|
from unittest import mock |
|
|
|
from big_vision.evaluators.proj.image_text import retrieval |
|
from big_vision.pp import ops_general |
|
from big_vision.pp import ops_image |
|
from big_vision.pp import registry |
|
import chex |
|
import flax.linen as nn |
|
import jax |
|
import jax.numpy as jnp |
|
import tensorflow as tf |
|
import tensorflow_datasets as tfds |
|
|
|
|
|
def _get_test_texts2labels():
  """Returns a test-only preprocessing op that derives `labels` from `texts`.

  The returned callable is what gets registered under the op name
  `test_texts2labels` in `RetrievalTest.test_prepare_datasets`.

  Returns:
    A function mapping a feature dict to the same dict with an added "labels"
    entry holding `features["texts"]` parsed as a number.
  """

  def pp(features):
    # The incoming dict is updated in place and handed back.
    features["labels"] = tf.strings.to_number(features["texts"])
    return features

  return pp
|
|
|
|
|
def _get_copy_from(**key_map):
  """Returns a test-only preprocessing op that duplicates dict entries.

  Args:
    **key_map: `destination=source` pairs; for each pair the value stored under
      `source` is also stored under `destination`.

  Returns:
    A function that takes a mapping and returns a shallow copy of it with the
    requested entries added. The input mapping itself is left untouched.
  """

  def copy_from(d):
    # Work on a shallow copy so the caller's mapping is never modified.
    d = dict(d)
    # Pairs are applied in keyword order, each one reading from the partially
    # updated copy; a missing source key raises KeyError.
    for dst_key, src_key in key_map.items():
      d[dst_key] = d[src_key]
    return d

  return copy_from
|
|
|
|
|
class _Model(nn.Module):
  """Toy two-tower model that turns scalars into 2-d unit-norm embeddings."""

  @nn.compact
  def __call__(self, image, texts):
    """Embeds `image` and `texts` independently with the same mapping.

    Args:
      image: 1-d array with one scalar per example, or None.
      texts: 1-d array with one scalar per example, or None.

    Returns:
      A tuple `(image_embeddings, text_embeddings, None)`. Each embedding is a
      `[batch, 2]` array with unit-norm rows, or None if its input was None.
    """
    # Dummy parameter; the test reads the "params" collection from `init`.
    self.param("x", lambda _: 0.)

    def z(x):
      if x is None:
        return None
      batch_size = len(x)
      # Each row becomes [100, x_i] scaled to unit length, so the dot product
      # of two embeddings is largest when they come from the same x_i.
      x = jnp.concatenate([100 * jnp.ones([batch_size, 1]), x[:, None]],
                          axis=1)
      return x / jnp.linalg.norm(x, axis=1)[:, None]

    return z(image), z(texts), None
|
|
|
|
|
def setUpModule():
  """Module-level fixture: requests 8 CPU devices from chex for these tests.

  `test_evaluate` sizes its batch from `jax.device_count()` and hands
  `jax.devices()` to the evaluator.
  """
  # NOTE(review): presumably has to run before JAX initializes its backends
  # for the device count to take effect -- confirm against the chex docs.
  chex.set_n_cpu_devices(8)
|
|
|
|
|
class RetrievalTest(tf.test.TestCase):
  """Tests for `retrieval.prepare_datasets` and `retrieval.Evaluator`."""

  def test_prepare_datasets(self):
    """Checks the image and text datasets derived from a captioned dataset."""

    def generator():
      # Two examples that differ in image size and in number of captions.
      yield {
          "image": tf.ones([5, 5, 3], tf.float32),
          "captions": {
              "text": tf.constant(["11", "12"])
          }
      }
      yield {
          "image": tf.ones([4, 4, 3], tf.float32),
          "captions": {
              "text": tf.constant(["21", "22", "23"])
          }
      }

    ds = tf.data.Dataset.from_generator(
        generator,
        output_signature={
            "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
            "captions": {
                "text": tf.TensorSpec(shape=[None], dtype=tf.string),
            },
        })

    with registry.temporary_ops(test_texts2labels=_get_test_texts2labels):
      ds_img, ds_txt = retrieval.prepare_datasets(
          ds,
          pp_img="resize(2)",
          pp_txt="test_texts2labels()",
          txt_name=("captions", "text"),
      )
      # Iterate while the temporary op is still registered.
      it_img = iter(ds_img)
      it_txt = iter(ds_txt)

      # Image dataset: one element per example, tagged with the example index
      # and resized to 2x2.
      for image_id in (0, 1):
        batch = next(it_img)
        self.assertAllEqual(batch["id"], image_id)
        self.assertAllEqual(batch["image"], tf.ones([2, 2, 3]))

      # Text dataset: one element per caption, as
      # (example index, caption index within the example, parsed label).
      expected_captions = (
          (0, 0, 11.0),
          (0, 1, 12.0),
          (1, 0, 21.0),
          (1, 1, 22.0),
          (1, 2, 23.0),
      )
      for image_id, caption_i, label in expected_captions:
        batch = next(it_txt)
        self.assertAllEqual(batch["id"], image_id)
        self.assertAllEqual(batch["caption_i"], caption_i)
        self.assertAllEqual(batch["labels"], label)

  def test_evaluate(self):
    """Runs the evaluator end to end on mocked data with the toy `_Model`."""
    per_device_batch_size = 2
    batch_size = per_device_batch_size * jax.device_count()
    # One full batch plus a single extra example, so the last batch is not
    # full.
    num_examples = 1 * batch_size + 1
    splits = {
        "test":
            tfds.core.SplitInfo(
                name="test", shard_lengths=[num_examples], num_bytes=0)
    }

    model = _Model()
    params = model.init(jax.random.PRNGKey(0), None, None)["params"]

    with tfds.testing.mock_data(num_examples=num_examples):
      # Replace the dataset-info lookup so that the reported "test" split has
      # exactly `num_examples` examples.
      info_mock = mock.Mock()
      info_mock.splits = splits
      with mock.patch.object(retrieval, "_get_dataset_info",
                             lambda _: info_mock):
        with registry.temporary_ops(copy_from=_get_copy_from):
          # Both towers are fed the example "id", copied into the keys the
          # prediction function reads ("image" and "labels").
          evaluator = retrieval.Evaluator(
              lambda p, b: model.apply({"params": p},
                                       b.get("image", None),
                                       b.get("labels", None)),
              dataset="coco_captions",
              batch_size=batch_size,
              devices=jax.devices(),
              txt_name=("captions", "text"),
              pp_img="copy_from(image='id')",
              pp_txt="copy_from(labels='id')",
          )
          results = evaluator.evaluate(params)

    # Every example must have been processed.
    self.assertLen(results["images"]["embeddings"], num_examples)
    self.assertLen(results["images"]["id"], num_examples)
    # Exactly one image may carry id 0.
    # NOTE(review): presumably guards against padding entries leaking into the
    # results -- confirm against the evaluator.
    self.assertEqual((results["images"]["id"] == 0).sum(), 1)
    # Image and text embeddings are both derived from the same id, so
    # retrieval is expected to be perfect.
    self.assertEqual(results["img2txt"]["Recall@1"], 1.0)
    self.assertEqual(results["txt2img"]["Recall@5"], 1.0)
|
|
|
|
|
if __name__ == "__main__":
  # Run the tests through TensorFlow's test runner when executed as a script.
  tf.test.main()
|
|