# pranavSIT's picture
# added pali inference
# 74e8f2f
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Load and run the PaliGemma model."""
import functools
import sys
from absl import app
from absl import flags
from absl import logging
# pylint: disable=all
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import ml_collections
import numpy as np
import big_vision.models.proj.paligemma.gemma_bv
import big_vision.models.proj.paligemma.paligemma as model_mod
import big_vision.models.vit
import big_vision.pp.builder
import big_vision.pp.tokenizer
import big_vision.pp.ops_image
import big_vision.pp.ops_general
import big_vision.pp.ops_text
import big_vision.pp.proj.paligemma.ops
import big_vision.sharding
import big_vision.trainers.proj.paligemma.predict_fns
import big_vision.utils as u
# pylint: enable=all
# Every host<->device transfer has to be spelled out explicitly (e.g. via
# `jax.device_get` or an explicit reshard); implicit ones are rejected.
jax.config.update("jax_transfer_guard", "disallow")

# --- Command-line flags -------------------------------------------------------
CKPT = flags.DEFINE_string("ckpt", None, "Path to checkpoint.")
IMAGE = flags.DEFINE_string("image", None, "Path to input image.")
SAMPLER = flags.DEFINE_string(
    "sampler", "greedy", "Decoding strategy. Try `nucleus(0.1)`")
RES = flags.DEFINE_integer(
    "res", 224, "Image resolution (224, 448, 896).")
MAX_DECODE_LEN = flags.DEFINE_integer(
    "max_decode_len", 128, "Max total generation steps.")
PREFILL_LEN = flags.DEFINE_integer(
    "prefill_len", 32,
    "Size of prefill (prompt). Shorter is faster, but too short will cut off "
    "your prompt.")

# Tokenizer spec: the Gemma tokenizer extended with the 'loc' and 'seg'
# token sets.
TOKENIZER = "gemma(tokensets=['loc', 'seg'])"
def load_model(ckpt):
  """Build the PaliGemma model and load its parameters from `ckpt`.

  Args:
    ckpt: Path to the checkpoint, forwarded to `model_mod.load`.

  Returns:
    A `(model, params)` tuple.
  """
  # Base vocabulary plus 1024 + 128 extra entries; presumably these are the
  # 'loc' and 'seg' token sets named in TOKENIZER -- confirm against the
  # tokenizer definition.
  vocab_size = 256_000 + 1024 + 128
  cfg = {
      "img": {"variant": "So400m/14", "pool_type": "none", "scan": True},
      "llm": {"vocab_size": vocab_size},
  }
  model_cfg = ml_collections.FrozenConfigDict(cfg)
  model = model_mod.Model(**model_cfg)
  return model, model_mod.load(None, ckpt, model_cfg)
def info(s, *a):
  """Log `s % a` with a yellow "NOTE" prefix and flush the log right away."""
  note_prefix = "\u001b[33mNOTE\u001b[0m: "  # ANSI yellow "NOTE", then reset.
  logging.info(note_prefix + s, *a)
  logging.flush()
def main(argv):
  """Load PaliGemma, then decode one prompt per stdin line for a fixed image.

  The checkpoint and image paths come from --ckpt and --image. One warm-up
  decode on the prompt "caption en" triggers compilation. After that, each
  stripped line of stdin is used as a prompt, and the decoded text goes to
  stderr.

  Args:
    argv: Positional arguments passed through by `absl.app.run`. They are only
      logged and otherwise unused.
  """
  info(f"{argv=}")

  info("Loading model...")
  model, params = load_model(CKPT.value)
  predict_fns = big_vision.trainers.proj.paligemma.predict_fns.get_all(model)

  info("Loading tokenizer...")
  tokzr = big_vision.pp.tokenizer.get_tokenizer(TOKENIZER)

  info("Creating mesh and sharding params...")
  # The trailing comma matters: `("data")` is just the string "data", while
  # Mesh's `axis_names` argument is a tuple of axis names.
  mesh = Mesh(jax.devices(), ("data",))
  repl_sharding = NamedSharding(mesh, PartitionSpec())
  params_sharding = big_vision.sharding.infer_sharding(
      params, strategy=[(".*", "fsdp(axis='data')")], mesh=mesh)
  # Ship the params to device(s). The module sets jax_transfer_guard to
  # "disallow", so this move is done explicitly, leaf by leaf.
  params = jax.tree.map(u.reshard, params, params_sharding)

  # Mostly go through pp ops to build our batch. The pipeline does four things:
  #   1. Decodes the image, resizes it to RES and scales pixels to [-1, 1].
  #   2. Tokenizes the prompt (with BOS) and a "\n" separator.
  #   3. Concatenates the two into "text".
  #   4. Pads "text"/"mask_ar"/"mask_input" to PREFILL_LEN, with pad values
  #      0/1/0 respectively.
  pp_fn = big_vision.pp.builder.get_preprocess_fn("|".join([
      f"decode|resize({RES.value})|value_range(-1, 1)",
      f"tok(key='prefix', bos='yes', model={repr(TOKENIZER)})",
      f"tok(key='septok', text='\\n', model={repr(TOKENIZER)})",
      'masked_concat(["prefix", "septok"], mask_ar=[0, 0], mask_input=[1, 1])',
      f'tolen({PREFILL_LEN.value}, pad_value=0, key="text")',
      f'tolen({PREFILL_LEN.value}, pad_value=1, key="mask_ar")',
      f'tolen({PREFILL_LEN.value}, pad_value=0, key="mask_input")',
      'keep("image", "text", "mask_ar", "mask_input")',
  ]), log_data=False)

  decode = functools.partial(
      predict_fns["decode"], devices=jax.devices(),
      eos_token=tokzr.eos_token, max_decode_len=MAX_DECODE_LEN.value,
      sampler=SAMPLER.value)

  def make_batch(fname, prompt):
    """Build a one-example batch from image file `fname` and `prompt`.

    The batch is placed on device(s) with replicated sharding.
    """
    # A context manager closes the handle promptly. The old bare
    # open().read() leaked one file handle per prompt.
    with open(fname, "rb") as f:
      image = f.read()
    # Create an example
    example = pp_fn({"image": image, "prefix": np.array(prompt)})
    example["_mask"] = np.array(True)  # True means valid non-pad example
    # Prepend a size-1 batch axis to every field.
    batch = jax.tree.map(lambda x: x[None], example)
    return u.reshard(batch, repl_sharding)  # Move to device(s)

  info("Precompiling inference function...")
  decode({"params": params}, batch=make_batch(IMAGE.value, "caption en"))

  info("Type a prompt and press enter, for example 'caption en': ")
  for line in map(str.strip, sys.stdin):
    tokens = decode({"params": params}, batch=make_batch(IMAGE.value, line))
    tokens = jax.device_get(tokens)[0]  # First batch entry.
    # TODO: b/lbeyer - flip around: output on stdout, logs on stderr.
    print(tokzr.to_str(tokens), file=sys.stderr, flush=True)
if __name__ == "__main__":
  # Neither path has a usable default (both are None), so enforce them when
  # flags are parsed instead of failing later inside main().
  flags.mark_flags_as_required(["ckpt", "image"])
  app.run(main)