"""Wraps `big_vision` PaliGemma model for easy use in demo."""

from collections.abc import Callable
import dataclasses
from typing import Any

import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import PIL.Image

from big_vision import sharding
from big_vision import utils
from big_vision.models.proj.paligemma import paligemma
from big_vision.pp import builder as pp_builder
# The `ops_*` imports below are not referenced directly; importing them
# registers the preprocessing ops used by name in `_prepare_batch`.
from big_vision.pp import ops_general
from big_vision.pp import ops_image
from big_vision.pp import ops_text
from big_vision.pp import tokenizer
from big_vision.pp.proj.paligemma import ops as ops_paligemma
from big_vision.trainers.proj.paligemma import predict_fns


# Single 'data' mesh axis spanning all devices visible to this process.
mesh = jax.sharding.Mesh(jax.devices(), 'data')


def _recover_bf16(x):
  """Views raw 2-byte void arrays (as loaded from checkpoint) as bfloat16."""
  if x.dtype == np.dtype('V2'):
    x = x.view('bfloat16')
  return x


def _load(
    path, tokenizer_spec='gemma(tokensets=("loc", "seg"))', vocab_size=257_152
):
  """Loads model, params, decode functions and tokenizer."""
  tok = tokenizer.get_tokenizer(tokenizer_spec)

  config = ml_collections.FrozenConfigDict(dict(
      llm_model='proj.paligemma.gemma_bv',
      llm=dict(vocab_size=vocab_size, variant='gemma_2b'),
      img=dict(variant='So400m/14', pool_type='none', scan=True),
  ))
  model = paligemma.Model(**config)
  # Build the prediction fns once and pick out both decoders.
  decode_fns = predict_fns.get_all(model)
  decode = decode_fns['decode']
  beam_decode = decode_fns['beam_decode']

  params_cpu = paligemma.load(None, path, config)
  params_cpu = jax.tree.map(_recover_bf16, params_cpu)

  return model, params_cpu, decode, beam_decode, tok


def _shard_params(params_cpu):
  """Shards `params_cpu` with fsdp strategy on all available devices."""
  params_sharding = sharding.infer_sharding(
      params_cpu, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh
  )
  params = jax.tree.map(utils.reshard, params_cpu, params_sharding)
  return params


def _pil2np(img):
  """Accepts `PIL.Image` or `np.ndarray` and returns an RGB `np.ndarray`."""
  if isinstance(img, PIL.Image.Image):
    img = np.array(img)
  if img.ndim == 2:  # Grayscale without a channel axis.
    img = img[..., None]
  if img.shape[-1] == 1:  # Single channel: replicate to RGB.
    img = np.repeat(img, 3, axis=-1)
  # Drop the alpha channel, if any. This must happen after the grayscale
  # handling above, otherwise it would slice the width of a 2D array.
  img = img[..., :3]
  return img


def _prepare_batch(
    images,
    prefixes,
    *,
    res=224,
    tokenizer_spec='gemma(tokensets=("loc", "seg"))',
    suffixes=None,
    text_len=64,
):
  """Returns non-sharded batch."""

  # `mask_ar=0` marks prefix tokens (full attention), `mask_ar=1` marks the
  # suffix (autoregressive); `mask_input` marks real (non-padding) tokens.
  pp_fn = pp_builder.get_preprocess_fn('|'.join([
      f'resize({res}, antialias=True)|value_range(-1, 1)',
      f"tok(key='prefix', bos='yes', model='{tokenizer_spec}')",
      f"tok(key='septok', text='\\n', model='{tokenizer_spec}')",
      f"tok(key='suffix', model='{tokenizer_spec}')",
      'masked_concat(["prefix", "septok", "suffix"],'
      ' mask_ar=[0, 0, 1], mask_input=[1, 1, 1])',
      f'tolen({text_len}, pad_value=0, key="text")',
      f'tolen({text_len}, pad_value=1, key="mask_ar")',
      f'tolen({text_len}, pad_value=0, key="mask_input")',
      'keep("image", "text", "mask_ar", "mask_input")',
  ]), log_data=False)
  assert not isinstance(prefixes, str), f'expected batch: {prefixes}'
  assert (
      isinstance(images, (list, tuple)) or images.ndim == 4
  ), f'expected batch: {images.shape}'
  if suffixes is None:
    suffixes = [''] * len(prefixes)
  assert len(prefixes) == len(suffixes) == len(images)
  examples = [{'_mask': True, **pp_fn({
      'image': np.asarray(_pil2np(image)),
      'prefix': np.array(prefix),
      'suffix': np.array(suffix),
  })} for image, prefix, suffix in zip(images, prefixes, suffixes)]
  batch = jax.tree.map(lambda *xs: np.stack(xs), *examples)
  return batch
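

# Hedged illustration (not produced by this module) of one pre-processed
# example as built inside `_prepare_batch` before stacking; shapes follow
# from the pp string above with the defaults res=224, text_len=64:
#
#   {'_mask': True,                 # marks a real (non-padding) example
#    'image': (224, 224, 3),        # float values in [-1, 1]
#    'text': (64,),                 # bos + prefix + '\n' + suffix, padded
#    'mask_ar': (64,),              # 0 = prefix, 1 = suffix/padding
#    'mask_input': (64,)}           # 1 = real token, 0 = padding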


def _shard_batch(batch, n=None):
  """Pads `batch` to a multiple of `n` examples and shards along `data` axis."""
  if n is None:
    n = jax.local_device_count()

  def pad(x):
    return jnp.pad(x, [(0, -len(x) % n)] + [(0, 0)] * (x.ndim - 1))

  batch = {k: pad(v) for k, v in batch.items()}
  data_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec('data')
  )
  batch_on_device = utils.reshard(batch, data_sharding)
  return batch_on_device


@dataclasses.dataclass(frozen=True, kw_only=True, order=True)
class PaligemmaConfig:
  """Describes a `big_vision` PaliGemma model."""

  ckpt: str
  res: int
  text_len: int
  tokenizer: str
  vocab_size: int
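
# Hedged example of constructing a config; the checkpoint path is a
# placeholder, the other values mirror the defaults used in `_load` and
# `_prepare_batch` above:
#
#   config = PaligemmaConfig(
#       ckpt='/path/to/paligemma.ckpt.npz',  # placeholder, not a real asset
#       res=224,
#       text_len=64,
#       tokenizer='gemma(tokensets=("loc", "seg"))',
#       vocab_size=257_152,
#   )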


@dataclasses.dataclass(frozen=True, kw_only=True)
class PaliGemmaModel:
  """Wraps a `big_vision` PaliGemma model."""

  config: PaligemmaConfig
  tokenizer: tokenizer.Tokenizer
  decode: Callable[..., Any]
  beam_decode: Callable[..., Any]

  @classmethod
  def shard_batch(cls, batch):
    return _shard_batch(batch)

  @classmethod
  def shard_params(cls, params_cpu):
    return _shard_params(params_cpu)

  def prepare_batch(self, images, texts, suffixes=None):
    return _prepare_batch(
        images=images,
        prefixes=texts,
        suffixes=suffixes,
        res=self.config.res,
        tokenizer_spec=self.config.tokenizer,
        text_len=self.config.text_len,
    )

  def predict(
      self,
      params,
      batch,
      devices=None,
      max_decode_len=128,
      sampler='greedy',
      **kw,
  ):
    """Returns tokens."""
    if devices is None:
      devices = jax.devices()
    if sampler == 'beam':
      decode = self.beam_decode
    else:
      decode = self.decode
      kw['sampler'] = sampler
    return decode(
        {'params': params},
        batch=batch,
        devices=devices,
        eos_token=self.tokenizer.eos_token,
        max_decode_len=max_decode_len,
        **kw,
    )


ParamsCpu = Any


def load_model(config: PaligemmaConfig) -> tuple[PaliGemmaModel, ParamsCpu]:
  """Loads model from config."""
  model, params_cpu, decode, beam_decode, tok = _load(
      path=config.ckpt,
      tokenizer_spec=config.tokenizer,
      vocab_size=config.vocab_size,
  )
  del model
  return PaliGemmaModel(
      config=config, tokenizer=tok, decode=decode, beam_decode=beam_decode,
  ), params_cpu
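

# A minimal, hedged sketch of end-to-end usage, assuming a locally available
# checkpoint (the path below is a placeholder) and that the `big_vision`
# tokenizer exposes `to_str` for detokenization. It only runs when this file
# is executed directly, never on import.
if __name__ == '__main__':
  demo_config = PaligemmaConfig(
      ckpt='/path/to/paligemma.ckpt.npz',  # placeholder, point at a real ckpt
      res=224,
      text_len=64,
      tokenizer='gemma(tokensets=("loc", "seg"))',
      vocab_size=257_152,
  )
  pg_model, pg_params_cpu = load_model(demo_config)
  pg_params = pg_model.shard_params(pg_params_cpu)

  demo_image = PIL.Image.new('RGB', (448, 448))  # stand-in for a real image
  demo_batch = pg_model.prepare_batch([demo_image], ['caption en'])
  demo_batch = pg_model.shard_batch(demo_batch)

  out_tokens = pg_model.predict(pg_params, demo_batch, max_decode_len=64)
  # `shard_batch` may have padded the batch; `_mask` marks the real examples.
  for is_real, toks in zip(
      np.array(demo_batch['_mask']), jax.device_get(out_tokens)):
    if is_real:
      print(pg_model.tokenizer.to_str(toks))  # `to_str` is an assumption here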