{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "UViM panoptic task",
      "provenance": [],
      "collapsed_sections": [],
      "private_outputs": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU",
    "gpuClass": "standard"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Fetch big_vision repository and move it into the current workdir (import path).\n",
        "!git clone --depth=1 https://github.com/google-research/big_vision big_vision_repo\n",
        "!cp -R big_vision_repo/big_vision big_vision\n",
        "!pip install -qr big_vision/requirements.txt"
      ],
      "metadata": {
        "id": "sKZK6_QpVI_O"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "\n",
        "from big_vision.models.proj.uvim import vtt  # stage-II model\n",
        "from big_vision.models.proj.uvim import vit  # stage-I model\n",
        "\n",
        "from big_vision.models.proj.uvim import decode\n",
        "from big_vision.trainers.proj.uvim import panoptic_task as task\n",
        "from big_vision.configs.proj.uvim import train_coco_panoptic_pretrained as config_module\n",
        "\n",
        "# NOTE(review): these modules are not referenced by name below; presumably they\n",
        "# are imported for the side effect of registering preprocessing ops - confirm.\n",
        "import big_vision.pp.ops_image\n",
        "import big_vision.pp.ops_general\n",
        "import big_vision.pp.proj.uvim.pp_ops\n",
        "from big_vision.pp import builder as pp_builder\n",
        "\n",
        "config = config_module.get_config()\n",
        "res = 512  # Side length passed to the `resize` preprocessing op.\n",
        "seq_len = config.model.seq_len\n",
        "\n",
        "lm_model = vtt.Model(**config.model)\n",
        "oracle_model = vit.Model(**config.oracle.model)\n",
        "\n",
        "# Decode the image bytes, resize, rescale values to [-1, 1] and duplicate the\n",
        "# image under the \"image_ctx\" key.\n",
        "preprocess_fn = pp_builder.get_preprocess_fn(\n",
        "    f'decode|resize({res})|value_range(-1,1)|'\n",
        "    'copy(inkey=\"image\",outkey=\"image_ctx\")')\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def predict_code(params, x, rng, temperature):\n",
        "  \"\"\"Samples one code sequence per image from the stage-II model.\n",
        "\n",
        "  Args:\n",
        "    params: Variables of `lm_model`.\n",
        "    x: Batch dict; only `x[\"image\"]` is read here.\n",
        "    rng: PRNG key forwarded to the sampler as its seed.\n",
        "    temperature: Sampling temperature.\n",
        "\n",
        "  Returns:\n",
        "    The sampled sequences with the `num_samples` axis removed, minus one.\n",
        "  \"\"\"\n",
        "  prompts = jnp.zeros((x[\"image\"].shape[0], seq_len), dtype=jnp.int32)\n",
        "  seqs, _, _ = decode.temperature_sampling(\n",
        "      params=params, model=lm_model, seed=rng,\n",
        "      inputs=x[\"image\"],\n",
        "      prompts=prompts,\n",
        "      temperature=temperature,\n",
        "      num_samples=1, eos_token=-1, prefill=False)\n",
        "  seqs = jnp.squeeze(seqs, axis=1)  # drop num_samples axis\n",
        "  # NOTE(review): presumably undoes a +1 token offset used by the language\n",
        "  # model so that values match the oracle's code range - confirm.\n",
        "  return seqs - 1\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def labels2code(params, x, ctx):\n",
        "  \"\"\"Encodes labels `x` (with context `ctx`) using the stage-I model.\n",
        "\n",
        "  Returns:\n",
        "    The \"code\" entry of the auxiliary outputs of `oracle_model.encode`.\n",
        "  \"\"\"\n",
        "  _, aux = oracle_model.apply(params, x, ctx=ctx, train=False, method=oracle_model.encode)\n",
        "  return aux[\"code\"]\n",
        "\n",
        "\n",
        "@jax.jit\n",
        "def code2labels(params, code, ctx):\n",
        "  \"\"\"Decodes a discrete `code` (with context `ctx`) using the stage-I model.\n",
        "\n",
        "  Returns:\n",
        "    The result of `task.predict_outputs` applied to the decoder logits.\n",
        "  \"\"\"\n",
        "  logits, _ = oracle_model.apply(params, code, ctx=ctx, train=False, discrete_input=True, method=oracle_model.decode)\n",
        "  return task.predict_outputs(logits, config.oracle)"
      ],
      "metadata": {
        "id": "QzThueWDzc7I"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load checkpoints\n",
        "!gsutil cp -n gs://big_vision/uvim/panoptic_stageI_params.npz gs://big_vision/uvim/panoptic_stageII_params.npz .\n",
        "\n",
        "oracle_params, oracle_state = vit.load(None, \"panoptic_stageI_params.npz\")\n",
        "oracle_params = jax.device_put({\"params\": oracle_params, \"state\": oracle_state})\n",
        "\n",
        "lm_params = vtt.load(None, \"panoptic_stageII_params.npz\")\n",
        "lm_params = jax.device_put({\"params\": lm_params})"
      ],
      "metadata": {
        "id": "AEjRgshLa6Fp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Prepare set of images from coco/val2017:\n",
        "# - https://cocodataset.org/\n",
        "import os\n",
        "import tensorflow as tf\n",
        "\n",
        "if not os.path.exists(\"val2017/\"):\n",
        "  !wget --no-clobber http://images.cocodataset.org/zips/val2017.zip\n",
        "  !unzip -uq val2017.zip\n",
        "  !wget -c https://raw.githubusercontent.com/cocodataset/panopticapi/master/panoptic_coco_categories.json\n",
        "\n",
        "dataset = tf.data.Dataset.list_files(\"val2017/*.jpg\", shuffle=True)\n",
        "dataset = dataset.map(lambda filename: {\"image\": tf.io.read_file(filename)})\n",
        "dataset = dataset.map(preprocess_fn)"
      ],
      "metadata": {
        "id": "k2ArKPlFQVcz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Run the model on a few examples:\n",
        "from matplotlib import pyplot as plt\n",
        "from matplotlib import patches\n",
        "from big_vision.trainers.proj.uvim import coco_utils\n",
        "\n",
        "num_examples = 4\n",
        "data = dataset.batch(1).take(num_examples).as_numpy_iterator()\n",
        "key = jax.random.PRNGKey(0)\n",
        "temperature = jnp.array(1e-7)\n",
        "\n",
        "\n",
        "def render_example(image, prediction, with_legend=True):\n",
        "  \"\"\"Plots the input image next to the rendered panoptic prediction.\n",
        "\n",
        "  Args:\n",
        "    image: Preprocessed image with values in [-1, 1].\n",
        "    prediction: Model output for that image, as accepted by\n",
        "      `coco_utils.rgb_panoptic_from_twochannels`.\n",
        "    with_legend: If true, adds one legend entry per predicted segment.\n",
        "  \"\"\"\n",
        "  _, axes = plt.subplots(1, 2, figsize=(10, 10))\n",
        "  axes[0].imshow(image*0.5 + 0.5)  # map [-1, 1] back to [0, 1] for display\n",
        "  axes[0].axis(\"off\")\n",
        "\n",
        "  rgb, info = coco_utils.rgb_panoptic_from_twochannels(prediction, boundaries=True)\n",
        "  axes[1].matshow(rgb)\n",
        "  axes[1].axis(\"off\")\n",
        "\n",
        "  if with_legend:\n",
        "    handles = [\n",
        "        patches.Patch(\n",
        "            facecolor=np.array(instance[\"color\"])/255.0,\n",
        "            edgecolor='black', label=instance[\"name\"])\n",
        "        for instance in info.values()\n",
        "    ]\n",
        "    axes[1].legend(handles=handles, loc=(1.04, 0.0))\n",
        "\n",
        "\n",
        "for idx, batch in enumerate(data):\n",
        "  # Derive a distinct sampling key for every example from the base key.\n",
        "  subkey = jax.random.fold_in(key, idx)\n",
        "  code = predict_code(lm_params, batch, subkey, temperature)\n",
        "  aux_inputs = task.input_pp(batch, config.oracle)\n",
        "  prediction = code2labels(oracle_params, code, aux_inputs[\"ctx\"])\n",
        "  render_example(batch[\"image\"][0], prediction[0])"
      ],
      "metadata": {
        "id": "TuevCy33nuv3"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}