{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "118UKH5bWCGa"
},
"source": [
"# DALL·E mini - Inference pipeline\n",
"\n",
"*Generate images from a text prompt*\n",
"\n",
"This notebook illustrates the [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.\n",
"\n",
"Just want to play? Use [the app](https://www.craiyon.com/) directly.\n",
"\n",
"To learn more about the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dS8LbaonYm3a"
},
"source": [
"## 🛠️ Installation and set-up"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uzjAM2GBYpZX",
"outputId": "9042b53c-1260-4ae6-ff54-be878c99d505"
},
"outputs": [],
"source": [
"# Install required libraries\n",
"!pip install -q dalle-mini\n",
"!pip install -q git+https://github.com/patil-suraj/vqgan-jax.git"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ozHzTkyv8cqU"
},
"source": [
"We load the required models:\n",
"* DALL·E mini to generate encoded images from text\n",
"* VQGAN to decode the encoded images\n",
"* CLIP to score the predictions"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "K6CxW2o42f-w"
},
"outputs": [],
"source": [
"# Model references\n",
"\n",
"# dalle-mega\n",
"DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\" # can be a wandb artifact, a 🤗 Hub repo, a local folder or a Google bucket\n",
"DALLE_COMMIT_ID = None\n",
"\n",
"# if the notebook crashes too often, uncomment the line below to use the smaller dalle-mini model instead\n",
"# DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"\n",
"\n",
"# VQGAN model\n",
"VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
"VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Yv-aR3t4Oe5v",
"outputId": "850b9a43-2506-432f-ae8e-b8b2598e4a98"
},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"# check how many devices are available\n",
"jax.local_device_count()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 240
},
"id": "92zYmvsQ38vL",
"outputId": "556dc277-a885-443b-8848-373696f5acc7"
},
"outputs": [],
"source": [
"# Load models & tokenizer\n",
"from dalle_mini import DalleBart, DalleBartProcessor\n",
"from vqgan_jax.modeling_flax_vqgan import VQModel\n",
"from transformers import CLIPProcessor, FlaxCLIPModel\n",
"\n",
"# Load dalle-mini\n",
"model, params = DalleBart.from_pretrained(\n",
"    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
")\n",
"\n",
"# Load VQGAN\n",
"vqgan, vqgan_params = VQModel.from_pretrained(\n",
"    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "o_vH2X1tDtzA"
},
"source": [
"Model parameters are replicated on each device for faster inference."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wtvLoM48EeVw"
},
"outputs": [],
"source": [
"from flax.jax_utils import replicate\n",
"\n",
"params = replicate(params)\n",
"vqgan_params = replicate(vqgan_params)"
]
},
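{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the replication step concrete, here is a minimal sketch on a toy parameter tree. The `toy_params` name and its shape are made up for illustration and are not part of the pipeline. `replicate` should add a leading axis whose size is the number of local devices, which is the layout `jax.pmap` expects."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"from flax.jax_utils import replicate\n",
"\n",
"# Hypothetical toy parameters, for illustration only.\n",
"toy_params = {\"w\": jnp.ones((2, 3))}\n",
"toy_replicated = replicate(toy_params)\n",
"\n",
"# One copy per local device is stacked along a new leading axis.\n",
"print(toy_params[\"w\"].shape, \"->\", toy_replicated[\"w\"].shape)"
]
},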
{
"cell_type": "markdown",
"metadata": {
"id": "0A9AHQIgZ_qw"
},
"source": [
"Model functions are compiled and parallelized to take advantage of multiple devices."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sOtoOmYsSYPz"
},
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"# model inference\n",
"@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4, 5, 6))\n",
"def p_generate(\n",
"    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale\n",
"):\n",
"    return model.generate(\n",
"        **tokenized_prompt,\n",
"        prng_key=key,\n",
"        params=params,\n",
"        top_k=top_k,\n",
"        top_p=top_p,\n",
"        temperature=temperature,\n",
"        condition_scale=condition_scale,\n",
"    )\n",
"\n",
"\n",
"# decode image\n",
"@partial(jax.pmap, axis_name=\"batch\")\n",
"def p_decode(indices, params):\n",
"    return vqgan.decode_code(indices, params=params)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HmVN6IBwapBA"
},
"source": [
"A distinct random key is passed to the model on each device so that every device generates different images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4CTXmlUkThhX"
},
"outputs": [],
"source": [
"import random\n",
"\n",
"# create a random key\n",
"seed = random.randint(0, 2**32 - 1)\n",
"key = jax.random.PRNGKey(seed)"
]
},
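{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of how a single key becomes one key per device, the cell below uses `shard_prng_key` from `flax.training.common_utils`, the same helper the generation loop relies on. The `demo_*` names are illustrative only, so the `key` defined above is left untouched."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from flax.training.common_utils import shard_prng_key\n",
"\n",
"demo_key = jax.random.PRNGKey(0)\n",
"# Split off a subkey so that the parent key can be reused for later draws.\n",
"demo_key, demo_subkey = jax.random.split(demo_key)\n",
"# One key per local device, stacked along a leading axis.\n",
"demo_device_keys = shard_prng_key(demo_subkey)\n",
"print(demo_device_keys.shape[0] == jax.local_device_count())"
]
},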
{
"cell_type": "markdown",
"metadata": {
"id": "BrnVyCo81pij"
},
"source": [
"## 🖍 Text Prompt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "rsmj0Aj5OQox"
},
"source": [
"Prompts must be processed (tokenized) before they are passed to the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YjjhUychOVxm"
},
"outputs": [],
"source": [
"from dalle_mini import DalleBartProcessor\n",
"\n",
"processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BQ7fymSPyvF_"
},
"source": [
"Let's define some text prompts."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "x_0vI9ge1oKr"
},
"outputs": [],
"source": [
"prompts = [\n",
"    \"sunset over a lake in the mountains\",\n",
"    \"the Eiffel tower landing on the moon\",\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XlZUG3SCLnGE"
},
"source": [
"Note: repeating the same prompt several times in the list generates several images for it in a single batch, which is faster than running more iterations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VKjEZGjtO49k"
},
"outputs": [],
"source": [
"tokenized_prompts = processor(prompts)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-CEJBnuJOe5z"
},
"source": [
"Finally, we replicate the tokenized prompts onto each device."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "lQePgju5Oe5z"
},
"outputs": [],
"source": [
"tokenized_prompt = replicate(tokenized_prompts)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "phQ9bhjRkgAZ"
},
"source": [
"## 🎨 Generate images\n",
"\n",
"We generate images with the DALL·E mini model and decode them with the VQGAN."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "d0wVkXpKqnHA"
},
"outputs": [],
"source": [
"# number of predictions per prompt\n",
"n_predictions = 8\n",
"\n",
"# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)\n",
"gen_top_k = None\n",
"gen_top_p = None\n",
"temperature = None\n",
"cond_scale = 10.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SDjEx9JxR3v8"
},
"outputs": [],
"source": [
"from flax.training.common_utils import shard_prng_key\n",
"import numpy as np\n",
"from PIL import Image\n",
"from tqdm.notebook import trange\n",
"\n",
"print(f\"Prompts: {prompts}\\n\")\n",
"# generate images\n",
"images = []\n",
"for i in trange(max(n_predictions // jax.device_count(), 1)):\n",
"    # get a new key\n",
"    key, subkey = jax.random.split(key)\n",
"    # generate images\n",
"    encoded_images = p_generate(\n",
"        tokenized_prompt,\n",
"        shard_prng_key(subkey),\n",
"        params,\n",
"        gen_top_k,\n",
"        gen_top_p,\n",
"        temperature,\n",
"        cond_scale,\n",
"    )\n",
"    # remove BOS\n",
"    encoded_images = encoded_images.sequences[..., 1:]\n",
"    # decode images\n",
"    decoded_images = p_decode(encoded_images, vqgan_params)\n",
"    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
"    for decoded_img in decoded_images:\n",
"        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))\n",
"        images.append(img)\n",
"        display(img)\n",
"        print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tw02wG9zGmyB"
},
"source": [
"## 🏅 Optional: Rank images by CLIP score\n",
"\n",
"We can rank the generated images by their CLIP score against the prompt.\n",
"\n",
"**Note: your session may crash if you don't have a subscription to Colab Pro.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RGjlIW_f6GA0"
},
"outputs": [],
"source": [
"# CLIP model\n",
"CLIP_REPO = \"openai/clip-vit-base-patch32\"\n",
"CLIP_COMMIT_ID = None\n",
"\n",
"# Load CLIP\n",
"clip, clip_params = FlaxCLIPModel.from_pretrained(\n",
"    CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
")\n",
"clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)\n",
"clip_params = replicate(clip_params)\n",
"\n",
"# score images\n",
"@partial(jax.pmap, axis_name=\"batch\")\n",
"def p_clip(inputs, params):\n",
"    logits = clip(params=params, **inputs).logits_per_image\n",
"    return logits"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "FoLXpjCmGpju"
},
"outputs": [],
"source": [
"from flax.training.common_utils import shard\n",
"\n",
"# get clip scores\n",
"clip_inputs = clip_processor(\n",
"    text=prompts * jax.device_count(),\n",
"    images=images,\n",
"    return_tensors=\"np\",\n",
"    padding=\"max_length\",\n",
"    max_length=77,\n",
"    truncation=True,\n",
").data\n",
"logits = p_clip(shard(clip_inputs), clip_params)\n",
"\n",
"# organize scores per prompt: one row per prompt, one column per image\n",
"# (reshape rather than squeeze, so a single prompt or several devices keep a 2D layout)\n",
"p = len(prompts)\n",
"logits = np.asarray([logits[:, i::p, i] for i in range(p)]).reshape(p, -1)"
]
},
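{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `[:, i::p, i]` slicing can be hard to read, so here is a toy NumPy sketch of it under stated assumptions: 1 device, 2 prompts, 3 images per prompt, and made-up scores. Images are stored interleaved by prompt, and each row of the per-device score matrix holds one image's scores against every prompt, so the slice picks the images generated for prompt `i` together with their scores against that same prompt. All `toy_*` names are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"toy_devices, toy_p, toy_per_prompt = 1, 2, 3\n",
"toy_n = toy_p * toy_per_prompt  # images per device, interleaved by prompt\n",
"\n",
"# Made-up scores with shape (devices, images, prompts):\n",
"# entry [d, j, k] encodes image j scored against prompt k as 10 * j + k.\n",
"toy_logits = np.fromfunction(lambda d, j, k: 10 * j + k, (toy_devices, toy_n, toy_p))\n",
"\n",
"# For prompt i, keep every toy_p-th image starting at i and its score for prompt i.\n",
"toy_scores = np.asarray([toy_logits[:, i::toy_p, i] for i in range(toy_p)])\n",
"toy_scores = toy_scores.reshape(toy_p, -1)\n",
"\n",
"# Row 0 holds images 0, 2, 4 (prompt 0); row 1 holds images 1, 3, 5 (prompt 1).\n",
"print(toy_scores)"
]
},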
{
"cell_type": "markdown",
"metadata": {
"id": "4AAWRm70LgED"
},
"source": [
"Let's now display images ranked by CLIP score."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "zsgxxubLLkIu"
},
"outputs": [],
"source": [
"for i, prompt in enumerate(prompts):\n",
"    print(f\"Prompt: {prompt}\\n\")\n",
"    for idx in logits[i].argsort()[::-1]:\n",
"        display(images[idx * p + i])\n",
"        print(f\"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\\n\")\n",
"    print()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oZT9i3jCjir0"
},
"source": [
"## 🪄 Optional: Save your Generated Images as W&B Tables\n",
"\n",
"W&B Tables are interactive 2D grids that support rich media logging. Use them to save the generated images to your W&B dashboard and share them with the world."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-pSiv6Vwjkn0"
},
"outputs": [],
"source": [
"import wandb\n",
"\n",
"# Initialize a W&B run.\n",
"project = 'dalle-mini-tables-colab'\n",
"run = wandb.init(project=project)\n",
"\n",
"# Initialize an empty W&B Tables.\n",
"columns = [\"captions\"] + [f\"image_{i+1}\" for i in range(n_predictions)]\n",
"gen_table = wandb.Table(columns=columns)\n",
"\n",
"# Add data to the table.\n",
"# CLIP scores only exist if the optional ranking section above was run.\n",
"clip_logits = globals().get(\"logits\")\n",
"for i, prompt in enumerate(prompts):\n",
"    # All images generated for this prompt.\n",
"    tmp_imgs = images[i :: len(prompts)]\n",
"    # If CLIP scores exist, sort the images from best to worst.\n",
"    if clip_logits is not None:\n",
"        idxs = clip_logits[i].argsort()[::-1]\n",
"        tmp_imgs = [tmp_imgs[idx] for idx in idxs]\n",
"\n",
"    # Add the data to the table.\n",
"    gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])\n",
"\n",
"# Log the Table to W&B dashboard.\n",
"wandb.log({\"Generated Images\": gen_table})\n",
"\n",
"# Close the W&B run.\n",
"run.finish()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ck2ZnHwVjnRd"
},
"source": [
"Click on the link above to check out your generated images."
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"machine_shape": "hm",
"name": "DALL·E mini - Inference pipeline.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3.9.13 ('base')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"vscode": {
"interpreter": {
"hash": "3e91440bae70fe36b08f2decfecf198c5281689ed89adf5e1c2c93a1bdd6e28e"
}
}
},
"nbformat": 4,
"nbformat_minor": 0
}