"""PaliGemma demo gradio app."""
import datetime
import functools
import glob
import json
import logging
import os
import time
import gradio as gr
import PIL.Image
import gradio_helpers
import models
import paligemma_parse
INTRO_TEXT = """🤲 PaliGemma GGUF demo\n\n
| [Paper](https://arxiv.org/abs/2407.07726)
| [GitHub](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
| [HF blog post](https://huggingface.co/blog/paligemma)
| [Google blog post](https://developers.googleblog.com/en/gemma-family-and-toolkit-expansion-io-2024)
| [Vertex AI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/363)
| [Demo](https://huggingface.co/spaces/google/paligemma)
|\n\n
[PaliGemma](https://ai.google.dev/gemma/docs/paligemma) is an open vision-language model by Google,
inspired by [PaLI-3](https://arxiv.org/abs/2310.09199) and
built with open components such as the [SigLIP](https://arxiv.org/abs/2303.15343)
vision model and the [Gemma](https://arxiv.org/abs/2403.08295) language model. PaliGemma is designed as a versatile
model for transfer to a wide range of vision-language tasks such as image and short video captioning, visual question
answering, text reading, object detection, and object segmentation.
\n\n
This space includes models fine-tuned on a mix of downstream tasks.
See the [blog post](https://huggingface.co/blog/paligemma) and
[README](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
for detailed information on how to use and fine-tune PaliGemma models.
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""
make_image = lambda value, visible: gr.Image(
value, label='Image', type='filepath', visible=visible)
make_annotated_image = functools.partial(gr.AnnotatedImage, label='Image')
make_highlighted_text = functools.partial(gr.HighlightedText, label='Output')
# https://coolors.co/4285f4-db4437-f4b400-0f9d58-e48ef1
COLORS = ['#4285f4', '#db4437', '#f4b400', '#0f9d58', '#e48ef1']
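

# gradio_helpers.synced is assumed to serialize concurrent requests, so a
# single loaded model is never driven by two generations at once.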
@gradio_helpers.synced
def compute(image, prompt, model_name, sampler):
"""Runs model inference."""
if image is None:
raise gr.Error('Image required')
logging.info('prompt="%s"', prompt)
if isinstance(image, str):
image = PIL.Image.open(image)
if gradio_helpers.should_mock():
logging.warning('Mocking response')
time.sleep(2.)
output = paligemma_parse.EXAMPLE_STRING
else:
if not model_name:
raise gr.Error('Models not loaded yet')
output = models.generate(model_name, sampler, image, prompt)
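    # Raw PaliGemma output; for detection prompts this looks roughly like
    # '<loc0256><loc0512><loc0768><loc0896> cat' (an assumption based on the
    # published PaliGemma token format).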
logging.info('output="%s"', output)
width, height = image.size
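  # extract_objs() is expected to yield dicts with a 'content' text span and,
  # for detection/segmentation tokens, optional 'name', pixel-space 'xyxy'
  # box and 'mask' entries.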
objs = paligemma_parse.extract_objs(output, width, height, unique_labels=True)
  # Sort labels so each one maps to a stable color across runs.
  labels = sorted(set(obj.get('name') for obj in objs if obj.get('name')))
  color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
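  # gr.AnnotatedImage takes (base_image, [(annotation, label), ...]) where an
  # annotation is an xyxy box or a mask; masks are preferred when present.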
annotated_image = (
image,
[
(
obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
obj['name'] or '',
)
for obj in objs
if 'mask' in obj or 'xyxy' in obj
],
)
has_annotations = bool(annotated_image[1])
return (
make_highlighted_text(
highlighted_text, visible=True, color_map=color_map),
make_image(image, visible=not has_annotations),
make_annotated_image(
annotated_image, visible=has_annotations, width=width, height=height,
color_map=color_map),
)


def warmup(model_name):
  """Runs a dummy 1x1-pixel request through compute() to load the model."""
  image = PIL.Image.new('RGB', [1, 1])
  _ = compute(image, '', model_name, 'greedy')


def reset():
  """Restores the initial UI state: empty prompt, no outputs."""
return (
'', make_highlighted_text('', visible=False),
make_image(None, visible=True), make_annotated_image(None, visible=False),
)


def create_app():
"""Creates demo UI."""
make_model = lambda choices: gr.Dropdown(
value=(choices + [''])[0],
choices=choices,
label='Model',
visible=bool(choices),
)
make_prompt = lambda value, visible=True: gr.Textbox(
value, label='Prompt', visible=visible)
with gr.Blocks() as demo:
##### Main UI structure.
gr.Markdown(INTRO_TEXT)
with gr.Row():
image = make_image(None, visible=True) # input
annotated_image = make_annotated_image(None, visible=False) # output
with gr.Column():
with gr.Row():
prompt = make_prompt('', visible=True)
model_info = gr.Markdown(label='Model Info')
with gr.Row():
model = make_model([])
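          # The decoding string is passed verbatim to models.generate(), which
          # is assumed to parse it (greedy search, nucleus/top-p sampling, or
          # temperature sampling).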
samplers = [
'greedy', 'nucleus(0.1)', 'nucleus(0.3)', 'temperature(0.5)']
sampler = gr.Dropdown(
value=samplers[0], choices=samplers, label='Decoding'
)
with gr.Row():
run = gr.Button('Run', variant='primary')
clear = gr.Button('Clear')
highlighted_text = make_highlighted_text('', visible=False)
##### UI logic.
def update_ui(model, prompt):
prompt = make_prompt(prompt, visible=True)
model_info = f'Model `{model}` – {models.MODELS_INFO.get(model, "No info.")}'
return [prompt, model_info]
gr.on(
[model.change],
update_ui,
[model, prompt],
[prompt, model_info],
)
gr.on(
[run.click, prompt.submit],
compute,
[image, prompt, model, sampler],
[highlighted_text, image, annotated_image],
)
clear.click(
reset, None, [prompt, highlighted_text, image, annotated_image]
)
##### Examples.
gr.set_static_paths(['examples/'])
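    # Each examples/<name>.json is assumed to hold "name", "prompt", "model"
    # and "license" keys, with a matching examples/<name>.jpg next to it (see
    # the gr.Examples rows built below).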
    all_examples = []
    for p in sorted(glob.glob('examples/*.json')):  # sorted for stable order
      with open(p) as f:
        all_examples.append(json.load(f))
logging.info('loaded %d examples', len(all_examples))
example_image = gr.Image(
label='Image', visible=False) # proxy, never visible
example_model = gr.Text(
label='Model', visible=False) # proxy, never visible
example_prompt = gr.Text(
label='Prompt', visible=False) # proxy, never visible
example_license = gr.Markdown(
label='Image License', visible=False) # placeholder, never visible
gr.Examples(
examples=[
[
f'examples/{ex["name"]}.jpg',
ex['prompt'],
ex['model'],
ex['license'],
]
for ex in all_examples
if ex['model'] in models.MODELS
],
inputs=[example_image, example_prompt, example_model, example_license],
)
##### Examples UI logic.
example_image.change(
lambda image_path: (
make_image(image_path, visible=True),
make_annotated_image(None, visible=False),
make_highlighted_text('', visible=False),
),
example_image,
[image, annotated_image, highlighted_text],
)
def example_model_changed(model):
if model not in gradio_helpers.get_paths():
raise gr.Error(f'Model "{model}" not loaded!')
return model
example_model.change(example_model_changed, example_model, model)
example_prompt.change(make_prompt, example_prompt, prompt)
##### Status.
status = gr.Markdown(f'Startup: {datetime.datetime.now()}')
return demo


if __name__ == '__main__':
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s')
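  # Note: this logs every environment variable, including any tokens or
  # secrets set on the host.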
for k, v in os.environ.items():
logging.info('environ["%s"] = %r', k, v)
gradio_helpers.set_warmup_function(warmup)
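  # models.MODELS is assumed to map model name -> (repo_id, filenames); the
  # registered downloads presumably fetch the GGUF weights in the background.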
for name, (repo, filenames) in models.MODELS.items():
gradio_helpers.register_download(name, repo, filenames)
create_app().queue().launch()