# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for discriminative_classifier."""

from unittest import mock

from big_vision.evaluators.proj.image_text import discriminative_classifier
from big_vision.pp import ops_general  # pylint: disable=unused-import
from big_vision.pp import ops_image  # pylint: disable=unused-import
from big_vision.pp.registry import Registry
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds


@Registry.register("preprocess_ops.test_texts2labels")
def _get_test_texts2labels():

  def pp(features):
    features["labels"] = tf.strings.to_number(features["texts"])
    return features

  return pp


@Registry.register("preprocess_ops.copy_from")
def _get_copy_from(**key_map):

  def copy_from(d):
    d = dict(d)
    for k1, k2 in key_map.items():
      d[k1] = d[k2]
    return d

  return copy_from


class _Model(nn.Module):
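  """Stand-in two-tower model for the evaluator tests.

  Both towers embed a scalar `x` as (cos(x/10), sin(x/10)), so embeddings
  derived from the same underlying value are maximally similar.
  """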
  @nn.compact
  def __call__(self, image, texts):
    # Unused dummy parameter so that `init()` returns a non-empty params tree.
    self.param("x", lambda _: 0.)

    def z(x):
      if x is None:
        return None
      # Note that the returned vector is most similar with other vectors
      # generated from the same underlying `x[:]`.
      return jnp.stack([jnp.cos(x / 10.), jnp.sin(x / 10.)]).T

    if texts is not None:
      texts %= 5  # For testing `pre_filter_fn` below.
    return z(image), z(texts), None


class DiscriminativeClassifierTest(tf.test.TestCase):
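  """Covers dataset preparation, embedding averaging, and evaluation."""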
  def test_prepare_datasets(self):

    def generator():
      yield {
          "image": tf.ones([5, 5, 3], tf.float32),
          "label": 1,
      }
      yield {
          "image": tf.ones([4, 4, 3], tf.float32),
          "label": 2,
      }

    ds = tf.data.Dataset.from_generator(
        generator,
        output_signature={
            "image": tf.TensorSpec(shape=[None, None, 3], dtype=tf.float32),
            "label": tf.TensorSpec(shape=[], dtype=tf.int64),
        })
    class_names = [
        "class1,class1a",
        "class2",
    ]
    prompt_templates = [
        "test {}",
        "test {} test",
    ]
    ds_img, ds_txt = discriminative_classifier.prepare_datasets(
        ds,
        class_names,
        prompt_templates=prompt_templates,
        pp_img="resize(2)",
        pp_txt="copy_from(labels='texts')",
    )

    # Image dataset: images are resized to 2x2 and labels are kept as-is.
    it_img = iter(ds_img)
    batch = next(it_img)
    self.assertAllEqual(1, batch["label"])
    self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])
    batch = next(it_img)
    self.assertAllEqual(2, batch["label"])
    self.assertAllEqual(tf.ones([2, 2, 3]), batch["image"])

    # Text dataset: every class name (including comma-separated aliases) is
    # combined with every prompt template, all mapped to the class index.
    it_txt = iter(ds_txt)
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1 test", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1a", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(0, batch["label"])
    self.assertAllEqual("test class1a test", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(1, batch["label"])
    self.assertAllEqual("test class2", batch["labels"])
    batch = next(it_txt)
    self.assertAllEqual(1, batch["label"])
    self.assertAllEqual("test class2 test", batch["labels"])

  def test_average_embeddings(self):
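    # `_average_embeddings` pools all embeddings that share a class id (several
    # prompts/aliases per class) and, with normalize=True, L2-normalizes the
    # per-class mean.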
    self.assertAllEqual(jnp.array([
        [2.], [4.], [8.],
    ]), discriminative_classifier._average_embeddings(
        embeddings=jnp.array([
            1., 3., 3., 1.,  # label1
            8., 0.,  # label2
            32., 0., 0., 0.,  # label3
        ])[..., None],
        labels=jnp.array([
            0, 0,  # label1
            0, 0,  # label1 (alias)
            1, 1,  # label2
            2, 2,  # label3
            2, 2,  # label3 (alias)
        ], jnp.int32),
        num_classes=3, normalize=False))
    self.assertAllEqual(
        jnp.array([
            [2**-.5, 2**-.5],
        ]),
        discriminative_classifier._average_embeddings(
            embeddings=jnp.array([[2., 2.]]),
            labels=jnp.array([0], jnp.int32),
            num_classes=1,
            normalize=True))

  @mock.patch("big_vision.evaluators.proj."
              "image_text.prompt_engineering.get_class_names")
  @mock.patch("big_vision.evaluators.proj."
              "image_text.prompt_engineering.get_prompt_templates")
  @mock.patch("big_vision.evaluators.proj."
              "image_text.discriminative_classifier._get_dataset_info")
  def test_evaluate(self, get_dataset_info_mock, get_prompt_templates_mock,
                    get_class_names_mock):
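    # TFDS metadata, prompt templates, and class names are mocked so the test
    # runs hermetically; `pp_img="copy_from(image='label')"` feeds the integer
    # label itself into both towers of the fake model.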
    per_device_batch_size = 10  # Make sure we have some unfiltered examples.
    global_batch_size = per_device_batch_size * jax.device_count()
    per_host_num_examples = int(
        np.ceil(global_batch_size / jax.process_count()))
    splits = {
        "test":
            tfds.core.SplitInfo(
                name="test", shard_lengths=[per_host_num_examples],
                num_bytes=0)
    }
    model = _Model()
    params = model.init(jax.random.PRNGKey(0), None, None)["params"]
    prompt_templates = [
        "test prompt 1 {}",
        "test prompt 2 {}",
    ]
    class_names = [
        f"test_class_{i}" for i in range(10)
    ]
    get_prompt_templates_mock.return_value = prompt_templates
    get_class_names_mock.return_value = class_names
    get_dataset_info_mock.return_value.splits = splits

    def pre_filter_fn(features):
      return features["label"] < 5  # Matches `texts %= 5` in `_Model` above.

    dataset_name = "cifar10_test"
    with tfds.testing.mock_data(num_examples=per_host_num_examples):
      evaluator = discriminative_classifier.Evaluator(
          lambda p, b: model.apply(
              {"params": p}, b.get("image", None), b.get("labels", None)),
          dataset_names=[dataset_name],
          prompt_templates="test_prompts",
          batch_size=global_batch_size,
          devices=jax.devices(),
          pp_img="copy_from(image='label')",
          pp_txt="copy_from(labels='label')",
          dataset_overrides={
              dataset_name: {
                  "dataset_name": "cifar10",
                  "class_names": "test_classes",
                  "pre_filter_fn": pre_filter_fn,
              }
          },
          first_class_name_only=True,
      )
      results = evaluator.evaluate(
          params, dataset_name, return_embeddings=True)
      metrics = dict(evaluator.run(params))

    # Assert that all examples were processed.
    self.assertLen(results["texts"]["embedding"],
                   len(class_names) * len(prompt_templates))
    self.assertLen(results["texts"]["average_embedding"], len(class_names))
    self.assertAllEqual(
        sorted(results["texts"]["label"]),
        [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9])
    # Note that the model above makes perfect predictions by design.
    self.assertEqual(1.0, results["accuracy"])
    self.assertEqual(1.0, metrics[f"{dataset_name}_accuracy"])


if __name__ == "__main__":
  tf.test.main()