# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Load and run the PaliGemma model."""
import functools
import sys

from absl import app
from absl import flags
from absl import logging

# pylint: disable=all
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec
import ml_collections
import numpy as np

import big_vision.models.proj.paligemma.gemma_bv
import big_vision.models.proj.paligemma.paligemma as model_mod
import big_vision.models.vit
import big_vision.pp.builder
import big_vision.pp.tokenizer
import big_vision.pp.ops_image
import big_vision.pp.ops_general
import big_vision.pp.ops_text
import big_vision.pp.proj.paligemma.ops
import big_vision.sharding
import big_vision.trainers.proj.paligemma.predict_fns
import big_vision.utils as u
# pylint: enable=all

# We always want to be explicit about any host-device transfers.
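# "disallow" makes JAX raise an error on implicit transfers rather than
# silently copying arrays between host and devices.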
jax.config.update("jax_transfer_guard", "disallow")

CKPT = flags.DEFINE_string(
    "ckpt", default=None, help="Path to checkpoint.")
IMAGE = flags.DEFINE_string(
    "image", default=None, help="Path to input image.")

SAMPLER = flags.DEFINE_string(
    "sampler", default="greedy", help="Decoding strategy. Try `nucleus(0.1)`")
RES = flags.DEFINE_integer(
    "res", default=224, help="Image resolution (224, 448, 896).")
MAX_DECODE_LEN = flags.DEFINE_integer(
    "max_decode_len", default=128, help="Max total generation steps.")
PREFILL_LEN = flags.DEFINE_integer(
    "prefill_len", default=32, help="Size of prefill (prompt). "
    "Shorter is faster, but too short will cut off your prompt.")

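# Gemma tokenizer, extended with the PaliGemma location ('loc') and
# segmentation ('seg') token sets.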
TOKENIZER = "gemma(tokensets=['loc', 'seg'])"


def load_model(ckpt):
  model_cfg = ml_collections.FrozenConfigDict(dict(
      img=dict(variant="So400m/14", pool_type="none", scan=True),
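      # 256k base Gemma vocab, plus 1024 location and 128 segmentation tokens.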
      llm=dict(vocab_size=256_000 + 1024 + 128),
  ))
  model = model_mod.Model(**model_cfg)
  params = model_mod.load(None, ckpt, model_cfg)
  return model, params


def info(s, *a):
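  # Prefix log lines with a yellow "NOTE" (ANSI escape codes) to stand out.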
  logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)
  logging.flush()


def main(argv):
  info(f"{argv=}")
  info("Loading model...")
  model, params = load_model(CKPT.value)

  predict_fns = big_vision.trainers.proj.paligemma.predict_fns.get_all(model)

  info("Loading tokenizer...")
  tokzr = big_vision.pp.tokenizer.get_tokenizer(TOKENIZER)

  info("Creating mesh and sharding params...")
  mesh = Mesh(jax.devices(), ("data",))  # One "data" axis spanning all devices.
  repl_sharding = NamedSharding(mesh, PartitionSpec())
  data_sharding = NamedSharding(mesh, PartitionSpec("data"))
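  # FSDP-style sharding: every parameter is split along the "data" mesh axis,
  # so each device holds only a shard of the full set of weights.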
  params_sharding = big_vision.sharding.infer_sharding(
      params, strategy=[(".*", "fsdp(axis='data')")], mesh=mesh)

  # Ship the params to device(s)
  params = jax.tree.map(lambda x, sh: u.reshard(x, sh), params, params_sharding)

  # Build the input batch with big_vision pp ops: decode and resize the image
  # into value range [-1, 1], tokenize the prompt (with BOS and a "\n"
  # separator), then pad text, mask_ar and mask_input to the fixed prefill len.
  pp_fn = big_vision.pp.builder.get_preprocess_fn("|".join([
      f"decode|resize({RES.value})|value_range(-1, 1)",
      f"tok(key='prefix', bos='yes', model={repr(TOKENIZER)})",
      f"tok(key='septok', text='\\n', model={repr(TOKENIZER)})",
      'masked_concat(["prefix", "septok"], mask_ar=[0, 0], mask_input=[1, 1])',
      f'tolen({PREFILL_LEN.value}, pad_value=0, key="text")',
      f'tolen({PREFILL_LEN.value}, pad_value=1, key="mask_ar")',
      f'tolen({PREFILL_LEN.value}, pad_value=0, key="mask_input")',
      'keep("image", "text", "mask_ar", "mask_input")',
  ]), log_data=False)

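  # Bind the fixed decoding options once, so each call below only needs the
  # params and a batch.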
  decode = functools.partial(
      predict_fns["decode"], devices=jax.devices(),
      eos_token=tokzr.eos_token, max_decode_len=MAX_DECODE_LEN.value,
      sampler=SAMPLER.value)

  def make_batch(fname, prompt):
    with open(fname, "rb") as f:
      image = f.read()

    # Create an example
    example = pp_fn({"image": image, "prefix": np.array(prompt)})
    example["_mask"] = np.array(True)  # True means valid non-pad example

    batch = jax.tree.map(lambda x: x[None], example)  # Add a leading batch dim.
    return u.reshard(batch, repl_sharding)  # Move to device(s)

  info("Precompiling inference function...")
  decode({"params": params}, batch=make_batch(IMAGE.value, "caption en"))

  info("Type a prompt and press enter, for example 'caption en': ")
  for line in map(str.strip, sys.stdin):
    tokens = decode({"params": params}, batch=make_batch(IMAGE.value, line))
    tokens = jax.device_get(tokens)[0]  # First batch entry.

    # TODO: b/lbeyer - flip around: output on stdout, logs on stderr.
    print(tokzr.to_str(tokens), file=sys.stderr, flush=True)


if __name__ == "__main__":
  flags.mark_flag_as_required("ckpt")
  flags.mark_flag_as_required("image")
  app.run(main)