diff --git a/README.md b/README.md
index d2f62d2727715ca8e7d17d08bed4288cd9b679f0..5a90284608d2a9eb0793f793ca939bb7de2fcb37 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,49 @@
 ---
-title: Paligemma
-emoji: 🌍
+title: PaliGemma Demo
+emoji: 🤲
 colorFrom: green
-colorTo: gray
+colorTo: yellow
 sdk: gradio
-sdk_version: 4.31.2
+sdk_version: 4.22.0
 app_file: app.py
 pinned: false
-license: gemma
+license: apache-2.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# PaliGemma Demo
+
+See the [blog post] and [`big_vision README.md`] for details about the model.
+
+
+[blog post]: https://huggingface.co/blog/paligemma
+
+[`big_vision README.md`]: https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md
+
+## Development
+
+Local testing (CPU, Python 3.12):
+
+```bash
+python3 -m venv env
+. env/bin/activate
+pip install -qr requirements-cpu.txt
+python app.py
+```
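+
+On a machine with an NVIDIA GPU and CUDA 12, you can instead install the GPU
+requirements (they pull in `jax[cuda12_pip]` from the JAX release index):
+
+```bash
+pip install -qr requirements.txt
+```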
+
+Environment variables:
+
+- `MOCK_MODEL=yes`: Skips model download/loading and returns canned output,
+  for quick UI testing.
+- `RAM_CACHE_GB=18`: Enables caching of three bf16 models in RAM (a single
+  bf16 model is about 5860 MB). Use with care on Spaces with little RAM. For
+  example, on an `A10G large` Space you can cache five models in RAM, so you
+  would set `RAM_CACHE_GB=30`.
+- `HOST_COLOCATION=4`: Set this if host RAM/disk is shared between four
+  processes (e.g. on Hugging Face `A10G large` Spaces). See the combined
+  example below.
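+
+A combined invocation might look like this (values are illustrative; adjust
+them to your hardware):
+
+```bash
+# Quick UI testing without loading any model:
+MOCK_MODEL=yes python app.py
+
+# Cache up to three bf16 models on a host shared by four processes:
+RAM_CACHE_GB=18 HOST_COLOCATION=4 python app.py
+```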
+
+
+Loading models:
+
+- The set of models loaded is defined in `./models.py`.
+- To access the models, you must first acknowledge their usage conditions.
+- When testing locally, you'll have to run `huggingface-cli login` (see below).
+- When running in a Hugging Face Space, you'll have to set a `HF_TOKEN` secret.
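+
+For example, for local testing (assuming your account has already accepted
+the models' usage conditions):
+
+```bash
+huggingface-cli login  # paste a read token from https://huggingface.co/settings/tokens
+```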
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..63adb4ba1e3f9f4d95781e5d836b95fb6464a8c9
--- /dev/null
+++ b/app.py
@@ -0,0 +1,251 @@
+"""PaliGemma demo gradio app."""
+
+import datetime
+import functools
+import glob
+import json
+import logging
+import os
+import time
+
+import gradio as gr
+import jax
+import PIL.Image
+import gradio_helpers
+import models
+import paligemma_parse
+
+INTRO_TEXT = """🤲 PaliGemma demo\n\n
+| [GitHub](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) 
+| [HF blog post](https://huggingface.co/blog/paligemma) 
+| [Google blog post](https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024)
+| [Vertex AI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363) 
+| [Demo](https://huggingface.co/spaces/google/paligemma) 
+|\n\n
+[PaliGemma](https://ai.google.dev/gemma/docs/paligemma) is an open vision-language model by Google, 
+inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and 
+built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343) 
+vision model and the [Gemma](https://arxiv.org/abs/2403.08295) language model. PaliGemma is designed as a versatile 
+model for transfer to a wide range of vision-language tasks such as image and short video captioning, visual question 
+answering, text reading, object detection, and object segmentation.
+\n\n
+This space includes models fine-tuned on a mix of downstream tasks. 
+See the [blog post](https://huggingface.co/blog/paligemma) and 
+[README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md) 
+for detailed information on how to use and fine-tune PaliGemma models.
+\n\n
+**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
+"""
+
+
+make_image = lambda value, visible: gr.Image(
+    value, label='Image', type='filepath', visible=visible)
+make_annotated_image = functools.partial(gr.AnnotatedImage, label='Image')
+make_highlighted_text = functools.partial(gr.HighlightedText, label='Output')
+
+
+# https://coolors.co/4285f4-db4437-f4b400-0f9d58-e48ef1
+COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
+
+
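+# `synced` serializes concurrent requests with a lock, so at most one
+# inference runs at a time.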
+@gradio_helpers.synced
+def compute(image, prompt, model_name, sampler):
+  """Runs model inference."""
+  if image is None:
+    raise gr.Error('Image required')
+
+  logging.info('prompt="%s"', prompt)
+
+  if isinstance(image, str):
+    image = PIL.Image.open(image)
+  if gradio_helpers.should_mock():
+    logging.warning('Mocking response')
+    time.sleep(2.)
+    output = paligemma_parse.EXAMPLE_STRING
+  else:
+    if not model_name:
+      raise gr.Error('Models not loaded yet')
+    output = models.generate(model_name, sampler, image, prompt)
+    logging.info('output="%s"', output)
+
+  width, height = image.size
+  objs = paligemma_parse.extract_objs(output, width, height, unique_labels=True)
+  labels = set(obj.get('name') for obj in objs if obj.get('name'))
+  color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
+  highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
+  annotated_image = (
+      image,
+      [
+          (
+              obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
+              obj['name'] or '',
+          )
+          for obj in objs
+          if 'mask' in obj or 'xyxy' in obj
+      ],
+  )
+  has_annotations = bool(annotated_image[1])
+  return (
+      make_highlighted_text(
+          highlighted_text, visible=True, color_map=color_map),
+      make_image(image, visible=not has_annotations),
+      make_annotated_image(
+          annotated_image, visible=has_annotations, width=width, height=height,
+          color_map=color_map),
+  )
+
+
+def warmup(model_name):
+  image = PIL.Image.new('RGB', [1, 1])
+  _ = compute(image, '', model_name, 'greedy')
+
+
+def reset():
+  return (
+      '', make_highlighted_text('', visible=False),
+      make_image(None, visible=True), make_annotated_image(None, visible=False),
+  )
+
+
+def create_app():
+  """Creates demo UI."""
+
+  make_model = lambda choices: gr.Dropdown(
+      value=(choices + [''])[0],
+      choices=choices,
+      label='Model',
+      visible=bool(choices),
+  )
+  make_prompt = lambda value, visible=True: gr.Textbox(
+      value, label='Prompt', visible=visible)
+
+  with gr.Blocks() as demo:
+
+    ##### Main UI structure.
+
+    gr.Markdown(INTRO_TEXT)
+    with gr.Row():
+      image = make_image(None, visible=True)  # input
+      annotated_image = make_annotated_image(None, visible=False)  # output
+      with gr.Column():
+        with gr.Row():
+          prompt = make_prompt('', visible=True)
+        model_info = gr.Markdown(label='Model Info')
+        with gr.Row():
+          model = make_model([])
+          samplers = [
+              'greedy', 'nucleus(0.1)', 'nucleus(0.3)', 'temperature(0.5)']
+          sampler = gr.Dropdown(
+              value=samplers[0], choices=samplers, label='Decoding'
+          )
+        with gr.Row():
+          run = gr.Button('Run', variant='primary')
+          clear = gr.Button('Clear')
+        highlighted_text = make_highlighted_text('', visible=False)
+
+    ##### UI logic.
+
+    def update_ui(model, prompt):
+      prompt = make_prompt(prompt, visible=True)
+      model_info = f'Model `{model}` – {models.MODELS_INFO.get(model, "No info.")}'
+      return [prompt, model_info]
+
+    gr.on(
+        [model.change],
+        update_ui,
+        [model, prompt],
+        [prompt, model_info],
+    )
+
+    gr.on(
+        [run.click, prompt.submit],
+        compute,
+        [image, prompt, model, sampler],
+        [highlighted_text, image, annotated_image],
+    )
+    clear.click(
+        reset, None, [prompt, highlighted_text, image, annotated_image]
+    )
+
+    ##### Examples.
+
+    gr.set_static_paths(['examples/'])
+    all_examples = [json.load(open(p)) for p in glob.glob('examples/*.json')]
+    logging.info('loaded %d examples', len(all_examples))
+    example_image = gr.Image(
+        label='Image', visible=False)  # proxy, never visible
+    example_model = gr.Text(
+        label='Model', visible=False)  # proxy, never visible
+    example_prompt = gr.Text(
+        label='Prompt', visible=False)  # proxy, never visible
+    example_license = gr.Markdown(
+        label='Image License', visible=False)  # placeholder, never visible
+    gr.Examples(
+        examples=[
+            [
+                f'examples/{ex["name"]}.jpg',
+                ex['prompt'],
+                ex['model'],
+                ex['license'],
+            ]
+            for ex in all_examples
+            if ex['model'] in models.MODELS
+        ],
+        inputs=[example_image, example_prompt, example_model, example_license],
+    )
+
+    ##### Examples UI logic.
+
+    example_image.change(
+        lambda image_path: (
+            make_image(image_path, visible=True),
+            make_annotated_image(None, visible=False),
+            make_highlighted_text('', visible=False),
+        ),
+        example_image,
+        [image, annotated_image, highlighted_text],
+    )
+    def example_model_changed(model):
+      if model not in gradio_helpers.get_paths():
+        raise gr.Error(f'Model "{model}" not loaded!')
+      return model
+    example_model.change(example_model_changed, example_model, model)
+    example_prompt.change(make_prompt, example_prompt, prompt)
+
+    ##### Status.
+
+    status = gr.Markdown(f'Startup: {datetime.datetime.now()}')
+    gpu_kind = gr.Markdown('GPU=?')
+    demo.load(
+        lambda: [
+            gradio_helpers.get_status(),
+            make_model(list(gradio_helpers.get_paths())),
+        ],
+        None,
+        [status, model],
+    )
+    def get_gpu_kind():
+      device = jax.devices()[0]
+      if not gradio_helpers.should_mock() and device.platform != 'gpu':
+        raise gr.Error('GPU not visible to JAX!')
+      return f'GPU={device.device_kind}'
+    demo.load(get_gpu_kind, None, gpu_kind)
+
+  return demo
+
+
+if __name__ == '__main__':
+
+  logging.basicConfig(level=logging.INFO,
+                      format='%(asctime)s - %(levelname)s - %(message)s')
+
+  logging.info('JAX devices: %s', jax.devices())
+
+  for k, v in os.environ.items():
+    logging.info('environ["%s"] = %r', k, v)
+
+  gradio_helpers.set_warmup_function(warmup)
+  for name, (repo, filename, revision) in models.MODELS.items():
+    gradio_helpers.register_download(name, repo, filename, revision)
+
+  create_app().queue().launch()
diff --git a/examples/barsik.jpg b/examples/barsik.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..55f855f13e882e57272a4eed142c919e907b84b6
Binary files /dev/null and b/examples/barsik.jpg differ
diff --git a/examples/barsik.json b/examples/barsik.json
new file mode 100644
index 0000000000000000000000000000000000000000..6d6f13e76e15985824ab27135a8b62d8b278d0dc
--- /dev/null
+++ b/examples/barsik.json
@@ -0,0 +1,7 @@
+{
+  "name": "barsik",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "segment cat",
+  "license": "CC0 by [maximneumann@](https://github.com/maximneumann)"
+}
\ No newline at end of file
diff --git a/examples/biennale.jpg b/examples/biennale.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..05ba1292a74c2842df4b4341ecaf1a1c1ecbcce0
Binary files /dev/null and b/examples/biennale.jpg differ
diff --git a/examples/biennale.json b/examples/biennale.json
new file mode 100644
index 0000000000000000000000000000000000000000..532ff527f32ad4e5fa1ebd71ebacc14d537370e5
--- /dev/null
+++ b/examples/biennale.json
@@ -0,0 +1,7 @@
+{
+  "name": "biennale",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "In which city is this?",
+  "license": "CC0 by [andsteing@](https://huggingface.co/andsteing)"
+}
\ No newline at end of file
diff --git a/examples/billard1.jpg b/examples/billard1.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2fbf3c5a9e96df8099640c4d9700fccac7063648
Binary files /dev/null and b/examples/billard1.jpg differ
diff --git a/examples/billard1.json b/examples/billard1.json
new file mode 100644
index 0000000000000000000000000000000000000000..2667d173894c20049779f493091cb00be8205d07
--- /dev/null
+++ b/examples/billard1.json
@@ -0,0 +1,7 @@
+{
+  "name": "billard1",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "How many red balls are there?",
+  "license": "CC0 by [mbosnjak@](https://github.com/mbosnjak)"
+}
\ No newline at end of file
diff --git a/examples/billard2.jpg b/examples/billard2.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a2a65c4b4c837082190bda4c1ec0d95aae757387
Binary files /dev/null and b/examples/billard2.jpg differ
diff --git a/examples/billard2.json b/examples/billard2.json
new file mode 100644
index 0000000000000000000000000000000000000000..1e66dd97b575f666c962436482fc18ee8682493e
--- /dev/null
+++ b/examples/billard2.json
@@ -0,0 +1,7 @@
+{
+  "name": "billard2",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "How many balls are there?",
+  "license": "CC0 by [mbosnjak@](https://github.com/mbosnjak)"
+}
\ No newline at end of file
diff --git a/examples/bowie.jpg b/examples/bowie.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a470c3fbcd2af4e81af9de46f6ba26f17db81631
Binary files /dev/null and b/examples/bowie.jpg differ
diff --git a/examples/bowie.json b/examples/bowie.json
new file mode 100644
index 0000000000000000000000000000000000000000..deb4dfd631631946765c9e90fa4555822a453e03
--- /dev/null
+++ b/examples/bowie.json
@@ -0,0 +1,7 @@
+{
+  "name": "bowie",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "Who is this?",
+  "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)"
+}
\ No newline at end of file
diff --git a/examples/branch.jpg b/examples/branch.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d95595728b845c0ec2b7ee508473541a48d84290
Binary files /dev/null and b/examples/branch.jpg differ
diff --git a/examples/branch.json b/examples/branch.json
new file mode 100644
index 0000000000000000000000000000000000000000..a86c14f5d3fe2f2d0512ce49fc2ab3b9b6012c61
--- /dev/null
+++ b/examples/branch.json
@@ -0,0 +1,7 @@
+{
+  "name": "branch",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "What caused this?",
+  "license": "CC0 by [andsteing@](https://huggingface.co/andsteing)"
+}
\ No newline at end of file
diff --git a/examples/cc_fox.jpg b/examples/cc_fox.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..47c95d0a91241833574ccb19b8a355417a87bc7a
Binary files /dev/null and b/examples/cc_fox.jpg differ
diff --git a/examples/cc_fox.json b/examples/cc_fox.json
new file mode 100644
index 0000000000000000000000000000000000000000..69ee0678e50e701e0167097f4c41ed360f449aed
--- /dev/null
+++ b/examples/cc_fox.json
@@ -0,0 +1,7 @@
+{
+  "name": "cc_fox",
+  "comment": "",
+  "model": "paligemma-3b-mix-448",
+  "prompt": "Which breed is this fox?",
+  "license": "CC0 by [XiaohuaZhai@](https://sites.google.com/view/xzhai)"
+}
diff --git a/examples/cc_landscape.jpg b/examples/cc_landscape.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0a4b610c4234ffa67305b7952584430f4d953dde
Binary files /dev/null and b/examples/cc_landscape.jpg differ
diff --git a/examples/cc_landscape.json b/examples/cc_landscape.json
new file mode 100644
index 0000000000000000000000000000000000000000..c1a66ec2901cd5108c71a690f23cf6ef51a9fbee
--- /dev/null
+++ b/examples/cc_landscape.json
@@ -0,0 +1,7 @@
+{
+  "name": "cc_landscape",
+  "comment": "",
+  "model": "paligemma-3b-mix-448",
+  "prompt": "What does the image show?",
+  "license": "CC0 by [XiaohuaZhai@](https://sites.google.com/view/xzhai)"
+}
diff --git a/examples/cc_puffin.jpg b/examples/cc_puffin.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ae6bb3dc676dc40b2854c58d8238ff7770f3072d
Binary files /dev/null and b/examples/cc_puffin.jpg differ
diff --git a/examples/cc_puffin.json b/examples/cc_puffin.json
new file mode 100644
index 0000000000000000000000000000000000000000..3ca086360d6c168cd3587a189aebe7a7bae2ca41
--- /dev/null
+++ b/examples/cc_puffin.json
@@ -0,0 +1,7 @@
+{
+  "name": "cc_puffin",
+  "comment": "",
+  "model": "paligemma-3b-mix-448",
+  "prompt": "detect puffin in the back ; puffin in front",
+  "license": "CC0 by [XiaohuaZhai@](https://sites.google.com/view/xzhai)"
+}
diff --git a/examples/couch.jpg b/examples/couch.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..81800961f0498e46fc06ef525311f0a5d88eb4cb
Binary files /dev/null and b/examples/couch.jpg differ
diff --git a/examples/couch.json b/examples/couch.json
new file mode 100644
index 0000000000000000000000000000000000000000..32f4cba01ded6e629661c4f81ec9125f9af8409e
--- /dev/null
+++ b/examples/couch.json
@@ -0,0 +1,7 @@
+{
+  "name": "couch",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "How many yellow cushions are on the couch?",
+  "license": "CC0"
+}
\ No newline at end of file
diff --git a/examples/couch_.json b/examples/couch_.json
new file mode 100644
index 0000000000000000000000000000000000000000..22a288af099703296a1208279484354f88ed5c20
--- /dev/null
+++ b/examples/couch_.json
@@ -0,0 +1,7 @@
+{
+  "name": "couch",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "How many painting do you see in the image?",
+  "license": "CC0"
+}
\ No newline at end of file
diff --git a/examples/cups.jpg b/examples/cups.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..29fb745612887e7a0d4137a503831fe4dd0841d1
Binary files /dev/null and b/examples/cups.jpg differ
diff --git a/examples/cups.json b/examples/cups.json
new file mode 100644
index 0000000000000000000000000000000000000000..078e3df2986f38350c30eaf2e1e1522a842b7664
--- /dev/null
+++ b/examples/cups.json
@@ -0,0 +1,7 @@
+{
+  "name": "cups",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "how many cups?",
+  "license": "CC0 by [mbosnjak@](https://github.com/mbosnjak)"
+}
\ No newline at end of file
diff --git a/examples/dice.jpg b/examples/dice.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..76d0fbabee3a9aa31d3335a850f25b0c40952d70
Binary files /dev/null and b/examples/dice.jpg differ
diff --git a/examples/dice.json b/examples/dice.json
new file mode 100644
index 0000000000000000000000000000000000000000..a3fb3f9703dd6ea055569fba49b4a96b76df8235
--- /dev/null
+++ b/examples/dice.json
@@ -0,0 +1,7 @@
+{
+  "name": "dice",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "segment dice ; dice",
+  "license": "CC0 by [andresusanopinto@](https://github.com/andresusanopinto)"
+}
\ No newline at end of file
diff --git a/examples/emu.jpg b/examples/emu.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d298e271a108a89cb89473af26bd202902f8b901
Binary files /dev/null and b/examples/emu.jpg differ
diff --git a/examples/emu.json b/examples/emu.json
new file mode 100644
index 0000000000000000000000000000000000000000..23532eac207641e3d138ceb67f9a051d6d231539
--- /dev/null
+++ b/examples/emu.json
@@ -0,0 +1,7 @@
+{
+  "name": "emu",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "What animal is this?",
+  "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)"
+}
\ No newline at end of file
diff --git a/examples/fridge.jpg b/examples/fridge.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..dd6af3f32f8b3bad650b2162f1b4628d8e5a26db
Binary files /dev/null and b/examples/fridge.jpg differ
diff --git a/examples/fridge.json b/examples/fridge.json
new file mode 100644
index 0000000000000000000000000000000000000000..c6628d78020b331530c4c6a3726c38c454c4da2f
--- /dev/null
+++ b/examples/fridge.json
@@ -0,0 +1,7 @@
+{
+  "name": "fridge",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "Describe the image.",
+  "license": "CC0 by [andresusanopinto@](https://github.com/andresusanopinto)"
+}
\ No newline at end of file
diff --git a/examples/givt.jpg b/examples/givt.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b269c132464ecbd8c7ce8f5464cb6ebf142cc8a1
Binary files /dev/null and b/examples/givt.jpg differ
diff --git a/examples/givt.json b/examples/givt.json
new file mode 100644
index 0000000000000000000000000000000000000000..4e244d55bd0423fd8041accf1ba3d9bb43d494af
--- /dev/null
+++ b/examples/givt.json
@@ -0,0 +1,7 @@
+{
+  "name": "givt",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "What does the image show?",
+  "license": "CC-BY [GIVT paper](https://arxiv.org/abs/2312.02116)"
+}
\ No newline at end of file
diff --git a/examples/greenlake.jpg b/examples/greenlake.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..65401579082eebb41a70869d8785b2c84a437476
Binary files /dev/null and b/examples/greenlake.jpg differ
diff --git a/examples/greenlake.json b/examples/greenlake.json
new file mode 100644
index 0000000000000000000000000000000000000000..5de5282b9608ada567cd696e8e4846e0906088da
--- /dev/null
+++ b/examples/greenlake.json
@@ -0,0 +1,7 @@
+{
+  "name": "greenlake",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "Describe the image.",
+  "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)"
+}
\ No newline at end of file
diff --git a/examples/howto.jpg b/examples/howto.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..5f079c6751730ab16008835765e6296c8fcc2d8c
Binary files /dev/null and b/examples/howto.jpg differ
diff --git a/examples/howto.json b/examples/howto.json
new file mode 100644
index 0000000000000000000000000000000000000000..2b44aae0878af6ff9abf1628ccedbb932179d19d
--- /dev/null
+++ b/examples/howto.json
@@ -0,0 +1,7 @@
+{
+  "name": "howto",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "What does this image show?",
+  "license": "CC-BY [How to train your ViT?](https://arxiv.org/abs/2106.10270)"
+}
\ No newline at end of file
diff --git a/examples/markers.jpg b/examples/markers.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..756537b93cf074ebfd32a45a7438b46914d335c3
Binary files /dev/null and b/examples/markers.jpg differ
diff --git a/examples/markers.json b/examples/markers.json
new file mode 100644
index 0000000000000000000000000000000000000000..9093a2c9a468dd7995039cae415394f182c35e89
--- /dev/null
+++ b/examples/markers.json
@@ -0,0 +1,7 @@
+{
+  "name": "markers",
+  "comment": "answer en How many cups are there?",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "How many cups are there?",
+  "license": "CC0"
+}
\ No newline at end of file
diff --git a/examples/mcair.jpg b/examples/mcair.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e965dde07d103114690bfc086d24ab3fdb054e65
Binary files /dev/null and b/examples/mcair.jpg differ
diff --git a/examples/mcair.json b/examples/mcair.json
new file mode 100644
index 0000000000000000000000000000000000000000..0f50b7f96253821cb70eb5aac40760d140252ffa
--- /dev/null
+++ b/examples/mcair.json
@@ -0,0 +1,7 @@
+{
+  "name": "mcair",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "Can you board this airplane?",
+  "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)"
+}
\ No newline at end of file
diff --git a/examples/mcair_.json b/examples/mcair_.json
new file mode 100644
index 0000000000000000000000000000000000000000..7ae3353a0ed5109a9cc9f26dc179ec7a86357b8c
--- /dev/null
+++ b/examples/mcair_.json
@@ -0,0 +1,7 @@
+{
+  "name": "mcair",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "Is this a restaurant?",
+  "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)"
+}
\ No newline at end of file
diff --git a/examples/minergie.jpg b/examples/minergie.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8372f189fa3549042e752931b2577fc7384be0fc
Binary files /dev/null and b/examples/minergie.jpg differ
diff --git a/examples/minergie.json b/examples/minergie.json
new file mode 100644
index 0000000000000000000000000000000000000000..cb292ed5e6eafc30ccecb9d7b2569ee370ab3b06
--- /dev/null
+++ b/examples/minergie.json
@@ -0,0 +1,7 @@
+{
+  "name": "minergie",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "ocr",
+  "license": "CC0 by [andsteing@](https://huggingface.co/andsteing)"
+}
\ No newline at end of file
diff --git a/examples/morel.jpg b/examples/morel.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e1498f0c98d6a2a7187ec74b809ca9f8fdc776c2
Binary files /dev/null and b/examples/morel.jpg differ
diff --git a/examples/morel.json b/examples/morel.json
new file mode 100644
index 0000000000000000000000000000000000000000..c4fb09a89a268cae5cbdff0810feea34177da7c8
--- /dev/null
+++ b/examples/morel.json
@@ -0,0 +1,7 @@
+{
+  "name": "morel",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "detect morel",
+  "license": "CC0 by [andsteing@](https://huggingface.co/andsteing)"
+}
\ No newline at end of file
diff --git a/examples/motorcyclists.jpg b/examples/motorcyclists.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..91fdffa020ea0e2ba5ef1a9be7dd68bdb7a081ce
Binary files /dev/null and b/examples/motorcyclists.jpg differ
diff --git a/examples/motorcyclists.json b/examples/motorcyclists.json
new file mode 100644
index 0000000000000000000000000000000000000000..f4a0d1e8b7207ac55ed90d09cbfc68ea065c901c
--- /dev/null
+++ b/examples/motorcyclists.json
@@ -0,0 +1,7 @@
+{
+  "name": "motorcyclists",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "What does the image show?",
+  "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)"
+}
\ No newline at end of file
diff --git a/examples/parking.jpg b/examples/parking.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3b3c6d3ebb5057b228f04a85a83184c5a1c8aaba
Binary files /dev/null and b/examples/parking.jpg differ
diff --git a/examples/parking.json b/examples/parking.json
new file mode 100644
index 0000000000000000000000000000000000000000..9964ba3acfadec3e0165377bb182ec416672b49a
--- /dev/null
+++ b/examples/parking.json
@@ -0,0 +1,7 @@
+{
+  "name": "parking",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "Describe the image.",
+  "license": "CC0 by [xiaohuazhai@](https://huggingface.co/xiaohuazhai)"
+}
\ No newline at end of file
diff --git a/examples/password.jpg b/examples/password.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c7804dfa42aa3cad23089087bffca604a8c3507e
Binary files /dev/null and b/examples/password.jpg differ
diff --git a/examples/password.json b/examples/password.json
new file mode 100644
index 0000000000000000000000000000000000000000..070f3f8c992a948b177f2cc647a7f9260c0d6c38
--- /dev/null
+++ b/examples/password.json
@@ -0,0 +1,7 @@
+{
+  "name": "password",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "What is the password?",
+  "license": "CC0 by [akolesnikoff@](https://github.com/akolesnikoff)"
+}
\ No newline at end of file
diff --git a/examples/preservationhall.jpg b/examples/preservationhall.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..adab242566923b4fa3ac97b179c358b07697f221
Binary files /dev/null and b/examples/preservationhall.jpg differ
diff --git a/examples/preservationhall.json b/examples/preservationhall.json
new file mode 100644
index 0000000000000000000000000000000000000000..6f9be7e1169ea9269ee0e9ebe58d32b424b4c21f
--- /dev/null
+++ b/examples/preservationhall.json
@@ -0,0 +1,7 @@
+{
+  "name": "preservationhall",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "Describe the image.",
+  "license": "CC0 by [mitscha@](https://github.com/mitscha)"
+}
\ No newline at end of file
diff --git a/examples/preservationhall_.json b/examples/preservationhall_.json
new file mode 100644
index 0000000000000000000000000000000000000000..5571c5272f91d36f2d67a55aac2d61715e3e5f26
--- /dev/null
+++ b/examples/preservationhall_.json
@@ -0,0 +1,7 @@
+{
+  "name": "preservationhall",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "What's the name of the place?",
+  "license": "CC0 by [mitscha@](https://github.com/mitscha)"
+}
\ No newline at end of file
diff --git a/examples/ulges.jpg b/examples/ulges.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e91d86083e3dbfecc2688735d380a441c0dee227
Binary files /dev/null and b/examples/ulges.jpg differ
diff --git a/examples/ulges.json b/examples/ulges.json
new file mode 100644
index 0000000000000000000000000000000000000000..d22ee5806c716238cd83fcbad2597dcf31dc6e04
--- /dev/null
+++ b/examples/ulges.json
@@ -0,0 +1,7 @@
+{
+  "name": "ulges",
+  "comment": "",
+  "model": "paligemma-3b-mix-224",
+  "prompt": "Who is the author of this book?",
+  "license": "CC0"
+}
\ No newline at end of file
diff --git a/gradio_helpers.py b/gradio_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1de6dc7e34bf56bdb0ba4b451208638a11868879
--- /dev/null
+++ b/gradio_helpers.py
@@ -0,0 +1,280 @@
+"""Gradio helpers for caching, downloading etc."""
+
+import concurrent.futures
+import contextlib
+import datetime
+import functools
+import logging
+import os
+import shutil
+import subprocess
+import sys
+import tempfile
+import threading
+import time
+import unittest.mock
+
+import huggingface_hub
+import jax
+import numpy as np
+import psutil
+
+
+def _clone_git(url, destination_folder, commit_hash=None):
+  subprocess.run([
+      'git', 'clone', '--depth=1',
+      url, destination_folder
+  ], check=True)
+  if commit_hash:
+    subprocess.run(
+        ['git', '-C', destination_folder, 'checkout', commit_hash], check=True
+    )
+
+
+def setup():
+  """Installs big_vision repo and mocks tensorflow_text."""
+  for url, dst_name, commit_hash in (
+      (
+          'https://github.com/google-research/big_vision',
+          'big_vision_repo',
+          None,
+      ),
+  ):
+    dst_path = os.path.join(tempfile.gettempdir(), dst_name)
+    if os.path.exists(dst_path):
+      print('Found existing "%s" at "%s"' % (url, dst_path))
+    else:
+      print('Cloning "%s" into "%s"' % (url, dst_path))
+      _clone_git(url, dst_path, commit_hash)
+
+    if dst_path not in sys.path:
+      sys.path.insert(0, dst_path)
+
+  # Imported in `big_vision.pp.ops_text` but we don't use it.
+  sys.modules['tensorflow_text'] = unittest.mock.MagicMock()
+
+
+# Must be run in main app before other BV imports:
+setup()
+
+
+def should_mock():
+  """Returns `True` if `MOCK_MODEL=yes` is set in environment."""
+  return os.environ.get('MOCK_MODEL') == 'yes'
+
+
+@contextlib.contextmanager
+def timed(name, start_message=False):
+  """Emits "Timed {name}: .1f secs" message to INFO logs."""
+  t0 = time.monotonic()
+  timing = dict(dt=None)
+  try:
+    if start_message:
+      logging.info('Timing %s...', name)
+    yield timing
+  finally:
+    timing['secs'] = time.monotonic() - t0
+    logging.info('Timed %s: %.1f secs', name, timing['secs'])
+
+
+def synced(f):
+  """Syncs calls to `f` with a `threading.Lock()`."""
+  lock = threading.Lock()
+  @functools.wraps(f)
+  def wrapper(*args, **kw):
+    t0 = time.monotonic()
+    with lock:
+      lock_dt = time.monotonic() - t0
+      logging.info('synced wait: %.1f secs', lock_dt)
+      return f(*args, **kw)
+  return wrapper
+
+
+_warmed_up = set()
+_warmup_function = None
+
+
+def set_warmup_function(warmup_function):
+  global _warmup_function
+  _warmup_function = warmup_function
+
+
+_lock = threading.Lock()
+_scheduled = {}
+_download_secs = 0
+_warmup_secs = 0
+_loading_secs = 0
+_done = {}
+_failed = {}
+
+
+def _do_download():
+  """Downloading files, to be started in background thread."""
+  global _download_secs
+  executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
+  while True:
+    if not _scheduled:
+      time.sleep(1)
+      continue
+
+    name, (repo, filename, revision) = next(iter(_scheduled.items()))
+    logging.info('Downloading "%s" %s/%s/%s...', name, repo, filename, revision)
+    with timed(f'downloading {name}', True) as t:
+      if should_mock():
+        logging.warning('Mocking loading')
+        time.sleep(10.)
+        _done[name] = None
+      else:
+        try:
+          _done[name] = huggingface_hub.hf_hub_download(
+              repo_id=repo, filename=filename, revision=revision)
+        except Exception as e:  # pylint: disable=broad-exception-caught
+          logging.exception('Could not download "%s" from hub!', name)
+          _failed[name] = str(e)
+          with _lock:
+            _scheduled.pop(name)
+          continue
+
+    if _warmup_function:
+      def warmup(name):
+        global _warmup_secs
+        with timed(f'warming up {name}', True) as t:
+          try:
+            _warmup_function(name)
+            _warmed_up.add(name)
+          except Exception:  # pylint: disable=broad-exception-caught
+            logging.exception('Could not warmup "%s"!', name)
+        _warmup_secs += t['secs']
+      executor.submit(warmup, name)
+
+    _download_secs += t['secs']
+    with _lock:
+      _scheduled.pop(name)
+
+
+def register_download(name, repo, filename, revision='main'):
+  """Will cause download of `filename` from HF `repo` in background thread."""
+  with _lock:
+    if name not in _scheduled:
+      _scheduled[name] = (repo, filename, revision)
+
+
+def _hms(secs):
+  """Formats `secs=3700` to `"01:01:40"`."""
+  secs = int(secs)
+  h = secs // 3600
+  m = (secs - h * 3600) // 60
+  s = secs % 60
+  return (f'{h}:' if h else '') + f'{m:02}:{s:02}'
+
+
+def downloads_status():
+  """Returns string representation of download stats."""
+  done_t = remaining_t = ''
+  if _done:
+    done_t = f' in {_hms(_download_secs)}'
+    remaining_t = f' {_hms(_download_secs/len(_done)*len(_scheduled))}'
+  status = f'Downloaded {len(_done)}{done_t}'
+  if _scheduled:
+    status += f', {len(_scheduled)}{remaining_t} remaining'
+  if _warmup_function:
+    status += f', warmed up {len(_warmed_up)} in {_hms(_warmup_secs)}'
+  if _failed:
+    status += f', {len(_failed)} failed'
+  return status
+
+
+def get_paths():
+  """Returns dictionary `name` to `path` from previous `register_download()`."""
+  return dict(_done)
+
+
+_download_thread = threading.Thread(target=_do_download)
+_download_thread.daemon = True
+_download_thread.start()
+
+
+_estimated_real = [(10, 10)]
+_memory_cache = {}
+
+
+def get_with_progress(getter, secs, progress, step=0.1):
+  """Returns result from `getter` while showing a progress bar."""
+  if progress is None:
+    return getter()
+  with concurrent.futures.ThreadPoolExecutor() as executor:
+    future = executor.submit(getter)
+    for _ in progress.tqdm(list(range(int(np.ceil(secs/step)))), desc='read'):
+      if not future.done():
+        time.sleep(step)
+  return future.result()
+
+
+def _get_array_sizes(tree):
+  return [getattr(x, 'nbytes', 0) for x in jax.tree.leaves(tree)]
+
+
+def get_memory_cache(
+    key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
+):
+  """Keeps cache below specified size by removing elements not last accessed."""
+  if key in _memory_cache:
+    _memory_cache[key] = _memory_cache.pop(key)  # Updates "last accessed" order
+    return _memory_cache[key]
+
+  est, real = zip(*_estimated_real)
+  if estimated_secs is None:
+    estimated_secs = sum(est) / len(est)
+  with timed(f'loading {key}') as t:
+    estimated_secs *= sum(real) / sum(est)
+    value = get_with_progress(getter, estimated_secs, progress)
+  _estimated_real.append((estimated_secs, t['secs']))
+
+  if not max_cache_size_bytes:
+    return value
+
+  _memory_cache[key] = value
+  sz = sum(_get_array_sizes(list(_memory_cache.values())))
+  logging.info('New memory cache size=%.1f MB', sz/1e6)
+
+  while sz > max_cache_size_bytes:
+    k, v = next(iter(_memory_cache.items()))
+    if k == key:
+      break
+    s = sum(_get_array_sizes(v))
+    logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6)
+    _memory_cache.pop(k)
+    sz -= s
+
+  return value
+
+
+def get_memory_cache_info():
+  """Returns number of items and total size in bytes."""
+  sizes = _get_array_sizes(_memory_cache)
+  return len(_memory_cache), sum(sizes)
+
+
+def get_system_info():
+  """Returns string describing system's RAM/disk status."""
+  host_colocation = int(os.environ.get('HOST_COLOCATION', '1'))
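+  # Totals are divided by `HOST_COLOCATION` because colocated Space processes
+  # share the host's physical RAM/disk (see README).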
+  vm = psutil.virtual_memory()
+  du = shutil.disk_usage('.')
+  return (
+      f'RAM {vm.used / 1e9:.1f}/{vm.total / host_colocation / 1e9:.1f}G, '
+      f'disk {du.used / 1e9:.1f}/{du.total / host_colocation / 1e9:.1f}G'
+  )
+
+
+def get_status(include_system_info=True):
+  """Returns string about download/memory/system status."""
+  mc_len, mc_sz = get_memory_cache_info()
+  mc_t = _hms(sum(real for _, real in _estimated_real[1:]))
+  return (
+      'Timestamp: '
+      + datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+      + ' – Model stats: '
+      + downloads_status()
+      + ', ' + f'memory-cached {mc_len} ({mc_sz/1e9:.1f}G) in {mc_t}' +
+      (' – System: ' + get_system_info() if include_system_info else '')
+  )
diff --git a/models.py b/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2b588c911c94878398ef6636431afa803443785
--- /dev/null
+++ b/models.py
@@ -0,0 +1,87 @@
+"""Model-related code and constants."""
+
+import dataclasses
+import os
+import re
+
+import PIL.Image
+
+# pylint: disable=g-bad-import-order
+import gradio_helpers
+import paligemma_bv
+
+
+ORGANIZATION = 'google'
+BASE_MODELS = [
+    ('paligemma-3b-mix-224-jax', 'paligemma-3b-mix-224'),
+    ('paligemma-3b-mix-448-jax', 'paligemma-3b-mix-448'),
+]
+MODELS = {
+    **{
+        model_name: (
+            f'{ORGANIZATION}/{repo}',
+            f'{model_name}.bf16.npz',
+            'bfloat16',  # Model repo revision.
+        )
+        for repo, model_name in BASE_MODELS
+    },
+}
+
+MODELS_INFO = {
+    'paligemma-3b-mix-224': (
+        'JAX/FLAX PaliGemma 3B weights, finetuned with 224x224 input images and 256 token input/output '
+        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
+        'bfloat16 and float16 format for research purposes only.'
+    ),
+    'paligemma-3b-mix-448': (
+        'JAX/FLAX PaliGemma 3B weights, finetuned with 448x448 input images and 512 token input/output '
+        'text sequences on a mixture of downstream academic datasets. The models are available in float32, '
+        'bfloat16 and float16 format for research purposes only.'
+    ),
+}
+
+MODELS_RES_SEQ = {
+    'paligemma-3b-mix-224': (224, 256),
+    'paligemma-3b-mix-448': (448, 512),
+}
+
+# "CPU basic" has 16G RAM, "T4 small" has 15 GB RAM.
+# Below value should be smaller than "available RAM - one model".
+# A single bf16 is about 5860 MB.
+MAX_RAM_CACHE = int(float(os.environ.get('RAM_CACHE_GB', '0')) * 1e9)
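+# Example: RAM_CACHE_GB=18 keeps up to three bf16 models cached
+# (3 * 5860 MB ~= 17.6 GB <= 18 GB).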
+
+config = paligemma_bv.PaligemmaConfig(
+    ckpt='',  # will be set below
+    res=224,
+    text_len=64,
+    tokenizer='gemma(tokensets=("loc", "seg"))',
+    vocab_size=256_000 + 1024 + 128,
+)
+
+
+def get_cached_model(
+    model_name: str,
+) -> tuple[paligemma_bv.PaliGemmaModel, paligemma_bv.ParamsCpu]:
+  """Returns model and params, using RAM cache."""
+  res, seq = MODELS_RES_SEQ[model_name]
+  model_path = gradio_helpers.get_paths()[model_name]
+  config_ = dataclasses.replace(config, ckpt=model_path, res=res, text_len=seq)
+  model, params_cpu = gradio_helpers.get_memory_cache(
+      config_,
+      lambda: paligemma_bv.load_model(config_),
+      max_cache_size_bytes=MAX_RAM_CACHE,
+  )
+  return model, params_cpu
+
+
+def generate(
+    model_name: str, sampler: str, image: PIL.Image.Image, prompt: str
+) -> str:
+  """Generates output with specified `model_name`, `sampler`."""
+  model, params_cpu = get_cached_model(model_name)
+  batch = model.shard_batch(model.prepare_batch([image], [prompt]))
+  with gradio_helpers.timed('sharding'):
+    params = model.shard_params(params_cpu)
+  with gradio_helpers.timed('computation', start_message=True):
+    tokens = model.predict(params, batch, sampler=sampler)
+  return model.tokenizer.to_str(tokens[0])
diff --git a/paligemma_bv.py b/paligemma_bv.py
new file mode 100644
index 0000000000000000000000000000000000000000..b70512134711c80a387f6a816cec7a64842a8dd9
--- /dev/null
+++ b/paligemma_bv.py
@@ -0,0 +1,207 @@
+"""Wraps `big_vision` PaliGemma model for easy use in demo."""
+
+from collections.abc import Callable
+import dataclasses
+from typing import Any
+
+import jax
+import jax.numpy as jnp
+import ml_collections
+import numpy as np
+import PIL.Image
+
+from big_vision import sharding
+from big_vision import utils
+from big_vision.models.proj.paligemma import paligemma
+from big_vision.pp import builder as pp_builder
+from big_vision.pp import ops_general  # pylint: disable=unused-import
+from big_vision.pp import ops_image  # pylint: disable=unused-import
+from big_vision.pp import ops_text  # pylint: disable=unused-import
+from big_vision.pp import tokenizer
+from big_vision.pp.proj.paligemma import ops as ops_paligemma  # pylint: disable=unused-import
+from big_vision.trainers.proj.paligemma import predict_fns
+
+
+mesh = jax.sharding.Mesh(jax.devices(), 'data')
+
+
+def _recover_bf16(x):
+  if x.dtype == np.dtype('V2'):
+    x = x.view('bfloat16')
+  return x
+
+
+def _load(
+    path, tokenizer_spec='gemma(tokensets=("loc", "seg"))', vocab_size=257_152
+):
+  """Loads model, params, decode functions and tokenizer."""
+  tok = tokenizer.get_tokenizer(tokenizer_spec)
+
+  config = ml_collections.FrozenConfigDict(dict(
+      llm_model='proj.paligemma.gemma_bv',
+      llm=dict(vocab_size=vocab_size, variant='gemma_2b'),
+      img=dict(variant='So400m/14', pool_type='none', scan=True),
+  ))
+  model = paligemma.Model(**config)
+  decode_fns = predict_fns.get_all(model)
+  decode = decode_fns['decode']
+  beam_decode = decode_fns['beam_decode']
+
+  params_cpu = paligemma.load(None, path, config)
+  # Some numpy versions don't load bfloat16 correctly:
+  params_cpu = jax.tree.map(_recover_bf16, params_cpu)
+
+  return model, params_cpu, decode, beam_decode, tok
+
+
+def _shard_params(params_cpu):
+  """Shards `params_cpu` with fsdp strategy on all available devices."""
+  params_sharding = sharding.infer_sharding(
+      params_cpu, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh
+  )
+  params = jax.tree.map(utils.reshard, params_cpu, params_sharding)
+  return params
+
+
+def _pil2np(img):
+  """Accepts `PIL.Image` or `np.ndarray` and returns `np.ndarray`."""
+  if isinstance(img, PIL.Image.Image):
+    img = np.array(img)
+    img = img[..., :3]
+    if img.ndim == 2:
+      img = img[..., None]
+    if img.shape[-1] == 1:
+      img = np.repeat(img, 3, axis=-1)
+  return img
+
+
+def _prepare_batch(
+    images,
+    prefixes,
+    *,
+    res=224,
+    tokenizer_spec='gemma(tokensets=("loc", "seg"))',
+    suffixes=None,
+    text_len=64,
+):
+  """Returns non-sharded batch."""
+
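+  # big_vision pp string: resize/normalize the image, tokenize prefix,
+  # separator ('\n') and suffix, concatenate them (mask_ar=[0, 0, 1] marks
+  # only the suffix as autoregressive), then pad all token fields to
+  # `text_len`.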
+  pp_fn = pp_builder.get_preprocess_fn('|'.join([
+      f'resize({res}, antialias=True)|value_range(-1, 1)',
+      f"tok(key='prefix', bos='yes', model='{tokenizer_spec}')",
+      f"tok(key='septok', text='\\n', model='{tokenizer_spec}')",
+      f"tok(key='suffix', model='{tokenizer_spec}')",
+      'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])',  # pylint: disable=line-too-long
+      f'tolen({text_len}, pad_value=0, key="text")',
+      f'tolen({text_len}, pad_value=1, key="mask_ar")',
+      f'tolen({text_len}, pad_value=0, key="mask_input")',
+      'keep("image", "text", "mask_ar", "mask_input")',
+  ]), log_data=False)
+  assert not isinstance(prefixes, str), f'expected batch: {prefixes}'
+  assert (
+      isinstance(images, (list, tuple)) or images.ndim == 4
+  ), f'expected batch: {images.shape}'
+  if suffixes is None:
+    suffixes = [''] * len(prefixes)
+  assert len(prefixes) == len(suffixes) == len(images)
+  examples = [{'_mask': True, **pp_fn({
+      'image': np.asarray(_pil2np(image)),
+      'prefix': np.array(prefix),
+      'suffix': np.array(suffix),
+  })} for image, prefix, suffix in zip(images, prefixes, suffixes)]
+  batch = jax.tree.map(lambda *xs: np.stack(xs), *examples)
+  return batch
+
+
+def _shard_batch(batch, n=None):
+  """Shards `batch` with fsdp strategy on all available devices."""
+  if n is None:
+    n = jax.local_device_count()
+  def pad(x):
+    return jnp.pad(x, [(0, -len(x) % n)] + [(0, 0)] * (x.ndim - 1))
+  batch = {k: pad(v) for k, v in batch.items()}
+  data_sharding = jax.sharding.NamedSharding(
+      mesh, jax.sharding.PartitionSpec('data')
+  )
+  batch_on_device = utils.reshard(batch, data_sharding)
+  return batch_on_device
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True, order=True)
+class PaligemmaConfig:
+  """Desribes a `big_vision` PaliGemma model."""
+
+  ckpt: str
+  res: int
+  text_len: int
+  tokenizer: str
+  vocab_size: int
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class PaliGemmaModel:
+  """Wraps a `big_vision` PaliGemma model."""
+
+  config: PaligemmaConfig
+  tokenizer: tokenizer.Tokenizer
+  decode: Callable[..., Any]
+  beam_decode: Callable[..., Any]
+
+  @classmethod
+  def shard_batch(cls, batch):
+    return _shard_batch(batch)
+
+  @classmethod
+  def shard_params(cls, params_cpu):
+    return _shard_params(params_cpu)
+
+  def prepare_batch(self, images, texts, suffixes=None):
+    return _prepare_batch(
+        images=images,
+        prefixes=texts,
+        suffixes=suffixes,
+        res=self.config.res,
+        tokenizer_spec=self.config.tokenizer,
+        text_len=self.config.text_len,
+    )
+
+  def predict(
+      self,
+      params,
+      batch,
+      devices=None,
+      max_decode_len=128,
+      sampler='greedy',
+      **kw,
+  ):
+    """Returns tokens."""
+    if devices is None:
+      devices = jax.devices()
+    if sampler == 'beam':
+      decode = self.beam_decode
+    else:
+      decode = self.decode
+      kw['sampler'] = sampler
+    return decode(
+        {'params': params},
+        batch=batch,
+        devices=devices,
+        eos_token=self.tokenizer.eos_token,
+        max_decode_len=max_decode_len,
+        **kw,
+    )
+
+
+ParamsCpu = Any
+
+
+def load_model(config: PaligemmaConfig) -> tuple[PaliGemmaModel, ParamsCpu]:
+  """Loads model from config."""
+  model, params_cpu, decode, beam_decode, tok = _load(
+      path=config.ckpt,
+      tokenizer_spec=config.tokenizer,
+      vocab_size=config.vocab_size,
+  )
+  del model
+  return PaliGemmaModel(
+      config=config, tokenizer=tok, decode=decode, beam_decode=beam_decode,
+  ), params_cpu
diff --git a/paligemma_parse.py b/paligemma_parse.py
new file mode 100644
index 0000000000000000000000000000000000000000..38356f2103edfe88852402de53948ce220c1b665
--- /dev/null
+++ b/paligemma_parse.py
@@ -0,0 +1,184 @@
+"""Parses PaliGemma output."""
+
+import functools
+import re
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import numpy as np
+import PIL.Image
+
+
+EXAMPLE_STRING = '<loc0000><loc0000><loc0930><loc1012> <seg114><seg074><seg106><seg044><seg030><seg027><seg119><seg119><seg120><seg117><seg082><seg082><seg051><seg005><seg125><seg097> wall ; <loc0722><loc0047><loc0895><loc0378> <seg068><seg114><seg014><seg037><seg029><seg063><seg048><seg104><seg010><seg056><seg021><seg056><seg019><seg017><seg102><seg121> car ; <loc0180><loc0596><loc0782><loc0961> <seg026><seg028><seg028><seg026><seg104><seg026><seg029><seg022><seg000><seg068><seg092><seg125><seg003><seg127><seg121><seg043> david bowie ; <loc0234><loc0043><loc0736><loc0289> <seg068><seg008><seg091><seg064><seg007><seg055><seg017><seg090><seg042><seg052><seg068><seg086><seg001><seg014><seg093><seg052> david bowie ; <loc0230><loc0300><loc0736><loc0499> <seg073><seg011><seg114><seg059><seg048><seg097><seg091><seg022><seg007><seg036><seg091><seg022><seg016><seg009><seg003><seg036> david bowie'  # pylint: disable=line-too-long
+
+_MODEL_PATH = 'vae-oid.npz'
+
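+# Matches one object in the model output. Groups: 1 = text before the object,
+# 2-5 = four <loc> box coordinates (y1, x1, y2, x2, each in 0..1023),
+# 6-21 = sixteen optional <seg> codebook indices, 22 = optional label.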
+_SEGMENT_DETECT_RE = re.compile(
+    r'(.*?)' +
+    r'<loc(\d{4})>' * 4 + r'\s*' +
+    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
+    r'\s*([^;<>]+)? ?(?:; )?',
+)
+
+
+def _get_params(checkpoint):
+  """Converts PyTorch checkpoint to Flax params."""
+
+  def transp(kernel):
+    return np.transpose(kernel, (2, 3, 1, 0))
+
+  def conv(name):
+    return {
+        'bias': checkpoint[name + '.bias'],
+        'kernel': transp(checkpoint[name + '.weight']),
+    }
+
+  def resblock(name):
+    return {
+        'Conv_0': conv(name + '.0'),
+        'Conv_1': conv(name + '.2'),
+        'Conv_2': conv(name + '.4'),
+    }
+
+  return {
+      '_embeddings': checkpoint['_vq_vae._embedding'],
+      'Conv_0': conv('decoder.0'),
+      'ResBlock_0': resblock('decoder.2.net'),
+      'ResBlock_1': resblock('decoder.3.net'),
+      'ConvTranspose_0': conv('decoder.4'),
+      'ConvTranspose_1': conv('decoder.6'),
+      'ConvTranspose_2': conv('decoder.8'),
+      'ConvTranspose_3': conv('decoder.10'),
+      'Conv_1': conv('decoder.12'),
+  }
+
+
+def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
+  batch_size, num_tokens = codebook_indices.shape
+  assert num_tokens == 16, codebook_indices.shape
+  unused_num_embeddings, embedding_dim = embeddings.shape
+
+  encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
+  encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
+  return encodings
+
+
+@functools.cache
+def _get_reconstruct_masks():
+  """Reconstructs masks from codebook indices.
+
+  Returns:
+    A function that expects indices shaped `[B, 16]` of dtype int32, each
+    ranging from 0 to 127 (inclusive), and that returns decoded masks of shape
+    `[B, 64, 64, 1]`, dtype float32, in range [-1, 1].
+  """
+
+  class ResBlock(nn.Module):
+    features: int
+
+    @nn.compact
+    def __call__(self, x):
+      original_x = x
+      x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+      x = nn.relu(x)
+      x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
+      x = nn.relu(x)
+      x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
+      return x + original_x
+
+  class Decoder(nn.Module):
+    """Upscales quantized vectors to mask."""
+
+    @nn.compact
+    def __call__(self, x):
+      num_res_blocks = 2
+      dim = 128
+      num_upsample_layers = 4
+
+      x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
+      x = nn.relu(x)
+
+      for _ in range(num_res_blocks):
+        x = ResBlock(features=dim)(x)
+
+      for _ in range(num_upsample_layers):
+        x = nn.ConvTranspose(
+            features=dim,
+            kernel_size=(4, 4),
+            strides=(2, 2),
+            padding=2,
+            transpose_kernel=True,
+        )(x)
+        x = nn.relu(x)
+        dim //= 2
+
+      x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
+
+      return x
+
+  def reconstruct_masks(codebook_indices):
+    quantized = _quantized_values_from_codebook_indices(
+        codebook_indices, params['_embeddings']
+    )
+    return Decoder().apply({'params': params}, quantized)
+
+  with open(_MODEL_PATH, 'rb') as f:
+    params = _get_params(dict(np.load(f)))
+
+  return jax.jit(reconstruct_masks, backend='cpu')
+
+
+def extract_objs(text, width, height, unique_labels=False):
+  """Returns objs for a string with "<loc>" and "<seg>" tokens."""
+  objs = []
+  seen = set()
+  while text:
+    m = _SEGMENT_DETECT_RE.match(text)
+    if not m:
+      break
+
+    gs = list(m.groups())
+    before = gs.pop(0)
+    name = gs.pop()
+    y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
+    y1, x1, y2, x2 = map(round, (y1*height, x1*width, y2*height, x2*width))
+
+    seg_indices = gs[4:20]
+    if seg_indices[0] is None:
+      mask = None
+    else:
+      seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
+      m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
+      m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
+      m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
+      mask = np.zeros([height, width])
+      if y2 > y1 and x2 > x1:
+        mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
+
+    content = m.group()
+    if before:
+      objs.append(dict(content=before))
+      content = content[len(before):]
+    while unique_labels and name in seen:
+      name = (name or '') + "'"
+    seen.add(name)
+    objs.append(dict(
+        content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
+    text = text[len(before) + len(content):]
+
+  if text:
+    objs.append(dict(content=text))
+
+  return objs
+
+
+if __name__ == '__main__':
+  # Simple test.
+  print([
+      {
+          k: (v.shape, v.mean()) if isinstance(v, np.ndarray) else v
+          for k, v in obj.items()
+      }
+      for obj in extract_objs(EXAMPLE_STRING, 100, 200)
+  ])
diff --git a/requirements-cpu.txt b/requirements-cpu.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e54c76d0fa705382d9cfb17745a48488830ed081
--- /dev/null
+++ b/requirements-cpu.txt
@@ -0,0 +1,13 @@
+einops
+flax
+gradio
+huggingface-hub
+jax
+jaxlib
+ml_collections
+numpy
+orbax-checkpoint
+Pillow
+psutil
+sentencepiece
+tensorflow
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8fa778537ea48a6629baab79680ab7ec388a0ab3
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,14 @@
+einops
+flax
+gradio
+huggingface-hub
+-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+jax[cuda12_pip]~=0.4.25
+jaxlib
+ml_collections
+numpy
+orbax-checkpoint
+Pillow
+psutil
+sentencepiece
+tensorflow-cpu
diff --git a/vae-oid.npz b/vae-oid.npz
new file mode 100644
index 0000000000000000000000000000000000000000..e30bd245fc0b67df063c5bd49d83c7130bba2637
--- /dev/null
+++ b/vae-oid.npz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5586010257b8536dddefab65e7755077f21d5672d5674dacf911f73ae95a4447
+size 8479556