In [None]:
# Fetch big_vision repository and move it into the current workdir (import path).
# Shallow clone (--depth=1): only the most recent commit is downloaded.
# NOTE(review): looks non-idempotent — on a re-run the clone target and the
# copied `big_vision` directory already exist; confirm this cell is meant to be
# executed once per fresh runtime.
!git clone --depth=1 https://github.com/google-research/big_vision big_vision_repo
# Copy the inner `big_vision` package directory into the working directory so
# that `import big_vision...` resolves from the notebook.
!cp -R big_vision_repo/big_vision big_vision
# Install the repository's requirements (-q: quiet, -r: read a requirements file).
!pip install -qr big_vision/requirements.txt

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

from big_vision.models.proj.uvim import vtt # stage-II model
from big_vision.models.proj.uvim import vit # stage-I model

from big_vision.models.proj.uvim import decode
from big_vision.trainers.proj.uvim import panoptic_task as task
from big_vision.configs.proj.uvim import train_coco_panoptic_pretrained as config_module

import big_vision.pp.ops_image
import big_vision.pp.ops_general
import big_vision.pp.proj.uvim.pp_ops
from big_vision.pp import builder as pp_builder

# Shared configuration and model objects used by the rest of the notebook.
config = config_module.get_config()
res = 512  # Side length passed to the `resize` preprocessing op below.
seq_len = config.model.seq_len

lm_model = vtt.Model(**config.model)  # stage-II model
oracle_model = vit.Model(**config.oracle.model)  # stage-I model

# The preprocessing spec is derived from `res` instead of repeating the literal,
# so the constant above is the single place that controls the resize.
# With res = 512 the resulting spec string is identical to the previous
# hard-coded one. `image_ctx` is a copy of the preprocessed image.
preprocess_fn = pp_builder.get_preprocess_fn(
    f'decode|resize({res})|value_range(-1,1)|'
    'copy(inkey="image",outkey="image_ctx")')

@jax.jit
def predict_code(params, x, rng, temperature):
  """Samples one code sequence per image with the stage-II model.

  Args:
    params: Variables passed to `decode.temperature_sampling` for `lm_model`.
    x: Batch dict; only `x["image"]` is read.
    rng: PRNG key forwarded as the sampling seed.
    temperature: Sampling temperature forwarded to the decoder.

  Returns:
    The sampled sequences with the `num_samples` axis removed and every
    entry decremented by one.
  """
  images = x["image"]
  batch_size = images.shape[0]

  # One all-zero int32 prompt row of length `seq_len` per image.
  empty_prompts = jnp.zeros((batch_size, seq_len), dtype=jnp.int32)

  sampled, _, _ = decode.temperature_sampling(
      params=params,
      model=lm_model,
      seed=rng,
      inputs=images,
      prompts=empty_prompts,
      temperature=temperature,
      num_samples=1,
      eos_token=-1,
      prefill=False,
  )

  # num_samples == 1, so axis 1 is a singleton and can be squeezed away.
  sampled = jnp.squeeze(sampled, axis=1)

  # NOTE(review): the -1 presumably undoes a +1 token offset used by the
  # language model's vocabulary — confirm against the training code.
  return sampled - 1
 
@jax.jit
def labels2code(params, x, ctx):
  """Encodes `x` with the stage-I (oracle) model and returns its code.

  Calls `oracle_model.encode` through `apply` in eval mode (`train=False`)
  with `ctx` as context, and returns the `"code"` entry of the auxiliary
  output; the primary output of `encode` is discarded.
  """
  encoded = oracle_model.apply(
      params, x, ctx=ctx, train=False, method=oracle_model.encode)
  _, aux = encoded
  return aux["code"]

@jax.jit
def code2labels(params, code, ctx):
  """Decodes a discrete code with the stage-I (oracle) model into predictions.

  Calls `oracle_model.decode` through `apply` in eval mode with
  `discrete_input=True`, then converts the resulting logits with
  `task.predict_outputs` using the oracle config. The auxiliary output of
  `decode` is discarded.
  """
  decoded = oracle_model.apply(
      params,
      code,
      ctx=ctx,
      train=False,
      discrete_input=True,
      method=oracle_model.decode,
  )
  logits, _ = decoded
  return task.predict_outputs(logits, config.oracle)

In [None]:
# Load checkpoints
# Download both stage checkpoints into the working directory
# (-n: no-clobber, so files that already exist locally are not copied again).
!gsutil cp -n gs://big_vision/uvim/panoptic_stageI_params.npz gs://big_vision/uvim/panoptic_stageII_params.npz .

# Stage-I checkpoint: `vit.load` is unpacked here into a (params, state) pair.
oracle_params, oracle_state = vit.load(None, "panoptic_stageI_params.npz")
# `oracle_params` is rebound to the {"params", "state"} variable dict moved
# with `jax.device_put`; from here on the name refers to that dict.
oracle_params = jax.device_put({"params": oracle_params, "state": oracle_state})

# Stage-II checkpoint: the loaded params are wrapped under "params" and
# `lm_params` is rebound to the resulting dict in the same way.
lm_params = vtt.load(None, "panoptic_stageII_params.npz")
lm_params = jax.device_put({"params": lm_params})

In [None]:
# Prepare a set of images from coco/val2017:
# - https://cocodataset.org/
import os
import tensorflow as tf

if not os.path.exists("val2017/"):
 !wget --no-clobber http://images.cocodataset.org/zips/val2017.zip
 !unzip -uq val2017.zip
 !wget -c https://raw.githubusercontent.com/cocodataset/panopticapi/master/panoptic_coco_categories.json

dataset = tf.data.Dataset.list_files("val2017/*.jpg", shuffle=True)
dataset = dataset.map(lambda filename: {"image": tf.io.read_file(filename)})
dataset = dataset.map(preprocess_fn)

In [None]:
# Run the model on a few examples:
from matplotlib import pyplot as plt
from matplotlib import patches
from big_vision.trainers.proj.uvim import coco_utils

# Demo settings: number of images to show, base PRNG key and sampling
# temperature.
# NOTE(review): the tiny temperature presumably makes sampling effectively
# greedy — confirm against `decode.temperature_sampling`.
num_examples = 4
key = jax.random.PRNGKey(0)
temperature = jnp.array(1e-7)

# Iterator over `num_examples` batches of a single example each, as numpy.
data = (
    dataset
    .batch(1)
    .take(num_examples)
    .as_numpy_iterator()
)

def render_example(image, prediction, with_legend=True):
  """Shows an input image next to its rendered panoptic prediction.

  Args:
    image: Image array; displayed as `image*0.5 + 0.5`, which maps the
      [-1, 1] range produced by the notebook's preprocessing to [0, 1].
    prediction: Prediction passed to
      `coco_utils.rgb_panoptic_from_twochannels` for rendering.
    with_legend: If true, adds one legend patch per entry of the info dict
      returned by the renderer, using its "color" and "name" fields.
  """
  _, (image_ax, panoptic_ax) = plt.subplots(1, 2, figsize=(10, 10))

  # Left panel: the input image.
  image_ax.imshow(image*0.5 + 0.5)
  image_ax.axis("off")

  # Right panel: RGB rendering of the prediction, with boundaries drawn.
  rgb, info = coco_utils.rgb_panoptic_from_twochannels(
      prediction, boundaries=True)
  panoptic_ax.matshow(rgb)
  panoptic_ax.axis("off")

  if not with_legend:
    return

  # Colors are divided by 255 to get matplotlib's [0, 1] float colors.
  legend_patches = [
      patches.Patch(
          facecolor=np.array(entry["color"])/255.0,
          edgecolor='black',
          label=entry["name"])
      for entry in info.values()
  ]
  # Anchor the legend just outside the right edge of the panel.
  panoptic_ax.legend(handles=legend_patches, loc=(1.04, 0.0))


# For each single-image batch: sample a code with the stage-II model, decode it
# to panoptic labels with the stage-I model, and plot the result.
for idx, batch in enumerate(data):
  # Derive a distinct PRNG key per example from the base key. The derived key
  # (not the base `key`) must be passed to the sampler; otherwise every example
  # would be sampled with the same randomness and `subkey` would be dead code.
  subkey = jax.random.fold_in(key, idx)
  code = predict_code(lm_params, batch, subkey, temperature)

  # Build the oracle inputs for this batch; only the context ("ctx") is needed
  # for decoding the code back into labels.
  aux_inputs = task.input_pp(batch, config.oracle)
  prediction = code2labels(oracle_params, code, aux_inputs["ctx"])

  # Batches hold a single example, so index 0 selects it.
  render_example(batch["image"][0], prediction[0])