pranavSIT's picture
added pali inference
74e8f2f
raw
history blame
13.9 kB
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for utils."""
from functools import partial
import os
from absl.testing import parameterized
from big_vision import utils
import chex
import flax
import jax
from jax.experimental.array_serialization import serialization as array_serial
import jax.numpy as jnp
import numpy as np
import tensorflow as tf
from tensorflow.io import gfile
NDEV = 4
def setUpModule():
chex.set_n_cpu_devices(NDEV)
class PadShardUnpadTest(chex.TestCase, tf.test.TestCase):
BATCH_SIZES = [NDEV, NDEV + 1, NDEV - 1, 5 * NDEV, 5 * NDEV + 1, 5 * NDEV - 1]
DTYPES = [np.float32, np.uint8, jax.numpy.bfloat16, np.int32]
def tearDown(self):
chex.clear_trace_counter()
super().tearDown()
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
def test_basics(self, dtype, bs):
# Just tests that basic calling works without exploring caveats.
@partial(utils.pad_shard_unpad, static_argnums=())
def add(a, b):
return a + b
x = np.arange(bs, dtype=dtype)
y = add(x, 10*x)
chex.assert_type(y.dtype, x.dtype)
np.testing.assert_allclose(np.float64(y), np.float64(x + 10*x))
@parameterized.parameters(DTYPES)
def test_min_device_batch_avoids_recompile(self, dtype):
@partial(utils.pad_shard_unpad, static_argnums=())
@jax.jit
@chex.assert_max_traces(n=1)
def add(a, b):
return a + b
chex.clear_trace_counter()
for bs in self.BATCH_SIZES:
x = np.arange(bs, dtype=dtype)
y = add(x, 10*x, min_device_batch=9) # pylint: disable=unexpected-keyword-arg
chex.assert_type(y.dtype, x.dtype)
np.testing.assert_allclose(np.float64(y), np.float64(x + 10*x))
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
def test_static_argnum(self, dtype, bs):
@partial(utils.pad_shard_unpad, static_argnums=(1,))
def add(a, b):
return a + b
x = np.arange(bs, dtype=dtype)
y = add(x, 10)
chex.assert_type(y.dtype, x.dtype)
np.testing.assert_allclose(np.float64(y), np.float64(x + 10))
@parameterized.product(dtype=DTYPES, bs=BATCH_SIZES)
def test_static_argnames(self, dtype, bs):
# In this test, leave static_argnums at the default value too, in order to
# test the default/most canonical path where `params` are the first arg.
@partial(utils.pad_shard_unpad, static_argnames=('b',))
def add(params, a, *, b):
return params * a + b
x = np.arange(bs, dtype=dtype)
y = add(5, x, b=10)
chex.assert_type(y.dtype, x.dtype)
np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10))
class TreeTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.d1 = {'w1': 1, 'w2': 2, 'w34': (3, 4)}
self.d1_flat = [1, 2]
self.d1_flat_jax = jax.tree.flatten(self.d1)[0]
self.d1_named_flat = [('w1', 1), ('w2', 2), ('w34/0', 3), ('w34/1', 4)]
self.d1_named_flat_jax = [('w1', 1), ('w2', 2), ('w34/0', 3), ('w34/1', 4)]
self.d2 = {'conv1': {'kernel': 0, 'bias': 1},
'conv2': {'kernel': 2, 'bias': 3}}
self.d2_flat = [1, 0, 3, 2]
self.d2_flat_jax = jax.tree.flatten(self.d2)[0]
self.d2_named_flat = [('conv1/bias', 1), ('conv1/kernel', 0),
('conv2/bias', 3), ('conv2/kernel', 2)]
self.d2_named_flat_jax = [('conv1/bias', 1), ('conv1/kernel', 0),
('conv2/bias', 3), ('conv2/kernel', 2)]
self.d2_named_flat_inner = [
('conv1/bias', 1), ('conv1/kernel', 0), ('conv1', self.d2['conv1']),
('conv2/bias', 3), ('conv2/kernel', 2), ('conv2', self.d2['conv2']),
('', self.d2),
]
# This is a very important testcase that checks whether we correctly
# recover jax' traversal order, even though our custom traversal may not
# be consistent with jax' traversal order. In particular, jax traverses
# FlaxStruct in the order of attribute definition, while our custom
# traversal is alphabetical.
@flax.struct.dataclass
class FlaxStruct():
v3: float
v2: int
v1: str
self.d3 = {'a': 0, 'flax': FlaxStruct(2.0, 1, 's')}
self.d3_flat = [0, 1, 2.0, 's']
self.d3_flat_jax = jax.tree.flatten(self.d3)[0]
self.d3_named_flat = [
('a', 0), ('flax/v1', 's'), ('flax/v2', 1), ('flax/v3', 2.0)]
self.d3_named_flat_jax = [
('a', 0), ('flax/v3', 2.0), ('flax/v2', 1), ('flax/v1', 's')]
def test_traverse_with_names(self):
names_and_vals = list(utils._traverse_with_names(self.d1))
self.assertEqual(names_and_vals, self.d1_named_flat)
names_and_vals = list(utils._traverse_with_names(self.d2))
self.assertEqual(names_and_vals, self.d2_named_flat)
names_and_vals = list(utils._traverse_with_names(
self.d2, with_inner_nodes=True))
self.assertEqual(names_and_vals, self.d2_named_flat_inner)
names_and_vals = list(utils._traverse_with_names(self.d3))
self.assertEqual(names_and_vals, self.d3_named_flat)
def test_tree_flatten_with_names(self):
names_and_vals = utils.tree_flatten_with_names(self.d1)[0]
self.assertEqual(names_and_vals, self.d1_named_flat_jax)
self.assertEqual([x for _, x in names_and_vals], self.d1_flat_jax)
names_and_vals = utils.tree_flatten_with_names(self.d2)[0]
self.assertEqual(names_and_vals, self.d2_named_flat_jax)
self.assertEqual([x for _, x in names_and_vals], self.d2_flat_jax)
names_and_vals = utils.tree_flatten_with_names(self.d3)[0]
self.assertEqual(names_and_vals, self.d3_named_flat_jax)
self.assertEqual([x for _, x in names_and_vals], self.d3_flat_jax)
def test_tree_map_with_names(self):
d1 = utils.tree_map_with_names(
lambda name, x: -x if 'w2' in name else x, self.d1)
self.assertEqual(d1, {'w1': 1, 'w2': -2, 'w34': (3, 4)})
d1 = utils.tree_map_with_names(
lambda name, x1, x2: x1 + x2 if 'w2' in name else x1, self.d1, self.d1)
self.assertEqual(d1, {'w1': 1, 'w2': 4, 'w34': (3, 4)})
def test_recover_tree(self):
keys = ['a/b', 'a/c/x', 'a/c/y', 'd']
values = [0, 1, 2, 3]
self.assertEqual(utils.recover_tree(keys, values),
{'a': {'b': 0, 'c': {'x': 1, 'y': 2}}, 'd': 3})
def test_make_mask_trees(self):
F, T = False, True # pylint: disable=invalid-name
tree = {'a': {'b': 0, 'x': 1}, 'b': {'x': 2, 'y': 3}}
msk1 = {'a': {'b': F, 'x': T}, 'b': {'x': T, 'y': F}}
msk2 = {'a': {'b': F, 'x': F}, 'b': {'x': F, 'y': T}}
# Note that 'b' matches '^b' only and not '.*/b'.
# Also note that "b/x" is matched by rule 1 only (because it comes first).
self.assertEqual(
utils.make_mask_trees(tree, ('.*/x', 'b/.*')), [msk1, msk2])
def test_tree_get(self):
tree = {'a': {'b': 0, 'x': 1}, 'b': {'x': 2, 'y': 3}}
self.assertEqual(utils.tree_get(tree, 'a/b'), 0)
self.assertEqual(utils.tree_get(tree, 'a/x'), 1)
self.assertEqual(utils.tree_get(tree, 'b/x'), 2)
self.assertEqual(utils.tree_get(tree, 'b/y'), 3)
self.assertEqual(utils.tree_get(tree, 'a'), tree['a'])
self.assertEqual(utils.tree_get(tree, 'b'), tree['b'])
def test_tree_replace(self):
tree = {'a': {'b': 2, 'c': 3}, 'c': 4}
replacements = {
'a/b': 'a/b/x', # replaces 'a/b' with 'a/b/x'
'.*c': 'C', # replaces 'c' with 'C' ('a/c' is removed)
'C': 'D', # replaces 'C' (which was 'c') with 'D'
'.*/c': None, # removes 'a/c'
}
tree2 = utils.tree_replace(tree, replacements)
self.assertEqual(tree2, {'D': 4, 'a': {'b': {'x': 2}}})
def test_tree_compare(self):
tree1_only, tree2_only, dtype_shape_mismatch = utils.tree_compare(
{'a': {'b': jnp.array(2), 'c': jnp.array(3)}},
{'a': {'B': jnp.array(2), 'c': jnp.array(3.)}},
)
self.assertEqual(tree1_only, {'a/b'})
self.assertEqual(tree2_only, {'a/B'})
self.assertEqual(
dtype_shape_mismatch,
{'a/c': [(jnp.dtype('int32'), ()), (jnp.dtype('float32'), ())]})
class StepConversionTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.named_parameters(
('nice_steps', 1000, None, None, dict(foo_steps=3), 3),
('nice_epochs', 1000, 100, None, dict(foo_epochs=3), 30),
('nice_examples', None, 100, None, dict(foo_examples=300), 3),
('nice_percent', None, None, 10, dict(foo_percent=0.30), 3),
('offbyone_steps', 1001, None, None, dict(foo_steps=3), 3),
('offbyone_epochs', 1001, 100, None, dict(foo_epochs=3), 30),
('offbyone_examples', None, 101, None, dict(foo_examples=300), 3),
('offbyone_percent', None, None, 11, dict(foo_percent=0.30), 3),
)
def test_steps(self, data_size, batch_size, total, cfg, expected):
# Correct default usage:
step = utils.steps('foo', cfg, data_size=data_size, batch_size=batch_size,
total_steps=total)
self.assertEqual(step, expected)
# Inexitent entry:
with self.assertRaises(ValueError):
step = utils.steps('bar', cfg, data_size=data_size, batch_size=batch_size,
total_steps=total)
step = utils.steps('bar', cfg, data_size=data_size, batch_size=batch_size,
total_steps=total, default=1234)
self.assertEqual(step, 1234)
class CreateLearningRateScheduleTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.named_parameters(
('linear', 'linear', {}, 13, .5),
('polynomial', 'polynomial', {'end': .1, 'power': 2}, 13, .325),
('cosine', 'cosine', {}, 13, .5),
('rsqrt', 'rsqrt', {'timescale': 1}, 13, 0.3333333),
('stair_5', 'stair', {'steps': [10], 'mults': [.5]}, 5, 1.),
('stair_10', 'stair', {'steps': [10], 'mults': [.5]}, 10, .5),
('warmup_before', 'rsqrt', {'timescale': 1}, 3, .6),
('cooldown_after', 'rsqrt', {'timescale': 1}, 20, .05),
)
def test_schedule(self, decay_type, extra_kwargs, step, expected_lr):
lr_fn = utils.create_learning_rate_schedule(
total_steps=21,
batch_size=512,
base=.5,
decay_type=decay_type,
scale_with_batchsize=True,
warmup_steps=5,
cooldown_steps=5,
**extra_kwargs)
lr = lr_fn(step)
self.assertAlmostEqual(lr, expected_lr)
class CheckpointTest(tf.test.TestCase):
def setup(self):
gacm = array_serial.GlobalAsyncCheckpointManager()
save_path = os.path.join(self.create_tempdir('workdir'), 'checkpoint.bv')
x = utils.put_cpu(np.array([1, 2, 3, 4]))
y = utils.put_cpu(np.array([5, 6, 7, 8]))
ckpt = {'x': x, 'y': {'z': y}}
sharding = jax.sharding.SingleDeviceSharding(
jax.local_devices(backend='cpu')[0]
)
shardings = jax.tree.map(lambda _: sharding, ckpt)
return gacm, save_path, ckpt, shardings
def test_save_and_load(self):
gacm, save_path, ckpt, shardings = self.setup()
step = 100
utils.save_checkpoint_ts(gacm, ckpt, save_path, step, keep=True)
gacm.wait_until_finished()
ckpt_loaded = utils.load_checkpoint_ts(save_path,
tree=ckpt, shardings=shardings)
chex.assert_trees_all_equal(ckpt_loaded, ckpt)
save_path_step = f'{save_path}-{step:09d}'
ckpt_loaded_step = utils.tsload(save_path_step, shardings=shardings)
chex.assert_trees_all_equal(ckpt_loaded_step, ckpt)
def test_save_and_partial_load(self):
gacm, save_path, ckpt, shardings = self.setup()
utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100)
gacm.wait_until_finished()
_ = shardings.pop('x'), ckpt.pop('x')
ckpt_loaded = utils.load_checkpoint_ts(save_path,
tree=ckpt, shardings=shardings)
chex.assert_trees_all_equal(ckpt_loaded, ckpt)
def test_save_and_cpu_load(self):
gacm, save_path, ckpt, _ = self.setup()
utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100)
gacm.wait_until_finished()
ckpt_loaded = utils.load_checkpoint_ts(save_path)
chex.assert_trees_all_equal(ckpt_loaded, ckpt)
def test_save_and_partial_cpu_load(self):
gacm, save_path, ckpt, _ = self.setup()
utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100)
gacm.wait_until_finished()
ckpt.pop('y')
ckpt_loaded = utils.load_checkpoint_ts(save_path, regex='x.*')
chex.assert_trees_all_equal(ckpt_loaded, ckpt)
def test_keep_deletes(self):
def x(tree, factor): # x as in "times" for multiplying.
return jax.tree.map(lambda a: a * factor, tree)
gacm, save_path, ckpt, _ = self.setup()
utils.save_checkpoint_ts(gacm, ckpt, save_path, step=100, keep=False)
utils.save_checkpoint_ts(gacm, x(ckpt, 2), save_path, step=200, keep=True)
utils.save_checkpoint_ts(gacm, x(ckpt, 3), save_path, step=300, keep=False)
gacm.wait_until_finished()
ckpt_loaded_200 = utils.tsload(f'{save_path}-{200:09d}')
chex.assert_trees_all_equal(ckpt_loaded_200, x(ckpt, 2))
ckpt_loaded_300 = utils.tsload(f'{save_path}-{300:09d}-tmp')
chex.assert_trees_all_equal(ckpt_loaded_300, x(ckpt, 3))
ckpt_loaded_last = utils.load_checkpoint_ts(save_path)
chex.assert_trees_all_equal(ckpt_loaded_last, x(ckpt, 3))
with self.assertRaises(Exception): # Can different types depending on fs.
_ = utils.tsload(f'{save_path}-{100:09d}')
# Test that ckpt@100 was deleted
self.assertFalse(gfile.exists(f'{save_path}-{100:09d}-tmp'))
if __name__ == '__main__':
tf.test.main()