"""Gradio clone of https://google-research.github.io/vision_transformer/lit/. Features: - Models are downloaded dynamically. - Models are cached on local disk, and in RAM. - Progress bars when downloading/reading/computing. - Dynamic update of model controls. - Dynamic generation of output sliders. - Use of `gr.State()` for better use of progress bars. """ import dataclasses import functools import json import logging import os import time import urllib.request import gradio as gr import PIL.Image # pylint: disable=g-bad-import-order import big_vision_contrastive_models as models import gradio_helpers INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json' IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg' MAX_ANSWERS = 10 MAX_DISK_CACHE = 20e9 MAX_RAM_CACHE = 10e9 # CPU basic has 16G RAM LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10} # family/variant/res -> name MODEL_MAP = { 'lit': { 'B/16': { 224: 'lit_b16b', }, 'L/16': { 224: 'lit_l16l', }, }, 'siglip': { 'B/16': { 224: 'siglip_b16b_224', 256: 'siglip_b16b_256', 384: 'siglip_b16b_384', 512: 'siglip_b16b_512', }, 'L/16': { 256: 'siglip_l16l_256', 384: 'siglip_l16l_384', }, 'So400m/14': { 224: 'siglip_so400m14so440m_224', 384: 'siglip_so400m14so440m_384', }, }, } def get_cache_status(): """Returns a string summarizing cache status.""" mem_n, mem_sz = gradio_helpers.get_memory_cache_info() disk_n, disk_sz = gradio_helpers.get_disk_cache_info() return ( f'memory cache {mem_n} items [{mem_sz/1e9:.2f}G], ' f'disk cache {disk_n} items [{disk_sz/1e9:.2f}G]' ) def compute( image_path, prompts, family, variant, res, bias, progress=gr.Progress() ): """Loads model and computes answers.""" if image_path is None: raise gr.Error('Must first select an image!') t0 = time.monotonic() model_name = MODEL_MAP[family][variant][res] config = models.MODEL_CONFIGS[model_name] local_ckpt = gradio_helpers.get_disk_cache( config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE) config = dataclasses.replace(config, ckpt=local_ckpt) params, model = gradio_helpers.get_memory_cache( config, lambda: models.load_model(config), max_cache_size_bytes=MAX_RAM_CACHE, progress=progress, estimated_secs={ ('lit', 'B/16'): 1, ('lit', 'L/16'): 2.5, ('siglip', 'B/16'): 9, ('siglip', 'L/16'): 28, ('siglip', 'So400m/14'): 36, }.get((family, variant)) ) model: models.ContrastiveModel = model it = progress.tqdm(list(range(3)), desc='compute') logging.info('Opening image "%s"', image_path) with gradio_helpers.timed(f'opening image "{image_path}"'): image = PIL.Image.open(image_path) next(it) with gradio_helpers.timed('image features'): zimg, unused_out = model.embed_images( params, model.preprocess_images([image]) ) next(it) with gradio_helpers.timed('text features'): prompts = prompts.split('\n') ztxt, out = model.embed_texts( params, model.preprocess_texts(prompts) ) next(it) t = model.get_temperature(out) text_probs = [] if family == 'lit': text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0]) elif family == 'siglip': text_probs = list(model.get_probabilities(zimg, ztxt, t, bias=bias)[0]) state = list(zip(prompts, [round(p.item(), 3) for p in text_probs])) dt = time.monotonic() - t0 status = gr.Markdown( f'Computed inference in {dt:.1f} seconds ({get_cache_status()})') if 'b' in out: logging.info('model_name=%s default bias=%f', model_name, out['b']) return status, state def update_answers(state): """Generates visible sliders for answers.""" answers = [] for prompt, 
def update_answers(state):
  """Generates visible sliders for answers."""
  answers = []
  for prompt, prob in state[:MAX_ANSWERS]:
    answers.append(
        gr.Slider(value=round(100 * prob, 2), label=prompt, visible=True))
  # Pad with hidden sliders up to MAX_ANSWERS.
  while len(answers) < MAX_ANSWERS:
    answers.append(gr.Slider(visible=False))
  return answers


def create_app():
  """Creates demo UI."""

  css = '''
  .slider input[type="number"] { width: 5em; }
  #examples td.textbox > div { white-space: pre-wrap !important; text-align: left; }
  '''
  with gr.Blocks(css=css) as demo:

    gr.Markdown(
        'Gradio clone of the original '
        '[LiT demo](https://google-research.github.io/vision_transformer/lit/).'
    )
    status = gr.Markdown(f'Ready ({get_cache_status()})')

    with gr.Row():
      image = gr.Image(label='Image', type='filepath')
      source = gr.Markdown('', visible=False)
      state = gr.State([])
      with gr.Column():
        prompts = gr.Textbox(
            label='Prompts (press Shift-ENTER to add a prompt)')

        with gr.Row():
          family = gr.Dropdown(
              value='lit', choices=list(MODEL_MAP), label='Model family')
          make_variant = functools.partial(gr.Dropdown, label='Variant')
          variant = make_variant(list(MODEL_MAP['lit']), value='B/16')
          make_res = functools.partial(gr.Dropdown, label='Resolution')
          res = make_res(list(MODEL_MAP['lit']['B/16']), value=224)

        def make_bias(family, variant, res):
          visible = family == 'siglip'
          # Default bias values for known (family, variant, res) combinations.
          value = {
              ('siglip', 'B/16', 224): -12.9,
              ('siglip', 'L/16', 256): -12.7,
              ('siglip', 'L/16', 384): -16.5,
              # ...
          }.get((family, variant, res), -10.0)
          return gr.Slider(
              value=value,
              minimum=-20,
              maximum=0,
              step=0.05,
              label='Bias',
              visible=visible,
          )

        bias = make_bias(family.value, variant.value, res.value)

        def update_inputs(family, variant, res):
          d = MODEL_MAP[family]
          variants = list(d)
          variant = variant if variant in variants else variants[0]
          d = d[variant]
          ress = list(d)
          res = res if res in ress else ress[0]
          return [
              make_variant(variants, value=variant),
              make_res(ress, value=res),
              make_bias(family, variant, res),
          ]

        gr.on(
            [family.change, variant.change, res.change],
            update_inputs,
            [family, variant, res],
            [variant, res, bias],
        )
        # (end of code for reactive UI)

        run = gr.Button('Run')

    answers = [
        # Will be set to visible in `update_answers()`.
        gr.Slider(0, 100, 0, visible=False, elem_classes='slider')
        for _ in range(MAX_ANSWERS)
    ]

    # We want to avoid showing multiple progress bars, so we only update
    # a single `status` widget here, and store the computed information in
    # `state`...
    run.click(
        fn=compute,
        inputs=[image, prompts, family, variant, res, bias],
        outputs=[status, state],
    )
    # ... then we use `state` to update UI components without showing a
    # progress bar in their place.
    status.change(fn=update_answers, inputs=state, outputs=answers)

    info = json.load(urllib.request.urlopen(INFO_URL))
    gr.Markdown('Note: the images below have 224 px resolution only:')
    gr.Examples(
        examples=[
            [
                IMG_URL_FMT.format(ex['id']),
                ex['prompts'].replace(', ', '\n'),
                '[source](%s)' % ex['source'],
            ]
            for ex in info
        ],
        inputs=[image, prompts, source],
        outputs=answers,
        elem_id='examples',
    )

  return demo


if __name__ == '__main__':
  logging.basicConfig(
      level=logging.INFO,
      format='%(asctime)s - %(levelname)s - %(message)s')

  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

  models.setup()
  create_app().queue().launch()
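
# For reference, each entry of the JSON list served at INFO_URL is expected to
# look like {'id': 'cat', 'prompts': 'a cat, a dog', 'source': 'https://...'}
# (shape inferred from the `gr.Examples` block in `create_app()` above; the
# values shown are illustrative, not taken from the actual file).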