{ "cells": [ { "cell_type": "markdown", "id": "cffb778a-c9d3-40c8-95ff-09c6fb29dba2", "metadata": {}, "source": [ "# Image generation with Stable Cascade and OpenVINO\n", "\n", "[Stable Cascade](https://huggingface.co/stabilityai/stable-cascade) is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this important? The smaller the latent space, the faster you can run inference and the cheaper the training becomes. How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a 1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the highly compressed latent space.\n", "\n", " " ] }, { "cell_type": "markdown", "id": "4e96a42f-ed4f-40f2-8220-0536350cc763", "metadata": {}, "source": [ "#### Table of contents:\n", "- [Prerequisites](#Prerequisites)\n", "- [Load the original model](#Load-the-original-model)\n", " - [Infer the original model](#Infer-the-original-model)\n", "- [Convert the model to OpenVINO IR](#Convert-the-model-to-OpenVINO-IR)\n", " - [Prior pipeline](#Prior-pipeline)\n", " - [Decoder pipeline](#Decoder-pipeline)\n", "- [Select inference device](#Select-inference-device)\n", "- [Building the pipeline](#Building-the-pipeline)\n", "- [Inference](#Inference)\n", "- [Interactive inference](#Interactive-inference)" ] }, { "cell_type": "markdown", "id": "25180082-55c7-4c97-b739-3d0adbf8ec7b", "metadata": {}, "source": [ "## Prerequisites\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 1, "id": "17ea91e0-b723-4e87-8e9d-7aeb84bcb51e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Note: you may need to restart the kernel to use updated packages.\n" ] } ], "source": [ "%pip install -q \"diffusers>=0.27.0\" accelerate datasets gradio transformers \"nncf>=2.10.0\" \"openvino>=2024.1.0\" \"torch>=2.1\" --extra-index-url https://download.pytorch.org/whl/cpu" ] }, { "cell_type": "markdown", "id": "23217887-3bd1-4e3f-9be1-dfdca41fdb83", "metadata": {}, "source": [ "## Load and run the original pipeline\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 2, "id": "55fb7ec5cee19217", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "56198ad1c3304fc4ab7348fd18be7516", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading pipeline components...: 0%| | 0/6 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "if run_original_inference.value:\n", " prior.to(torch.device(\"cpu\"))\n", " prior_output = prior(\n", " prompt=prompt,\n", " height=1024,\n", " width=1024,\n", " negative_prompt=negative_prompt,\n", " guidance_scale=4.0,\n", " num_images_per_prompt=1,\n", " num_inference_steps=20,\n", " )\n", "\n", " decoder_output = decoder(\n", " image_embeddings=prior_output.image_embeddings,\n", " prompt=prompt,\n", " negative_prompt=negative_prompt,\n", " guidance_scale=0.0,\n", " output_type=\"pil\",\n", " num_inference_steps=10,\n", " ).images[0]\n", " display(decoder_output)" ] }, { "cell_type": "markdown", "id": "fd956238-6345-40c9-95cf-dbabf5d14480", "metadata": {}, "source": [ "## Convert the model to OpenVINO IR\n", 
"[back to top ⬆️](#Table-of-contents:)\n", "\n", "Stable Cascade has 2 components:\n", "- Prior stage `prior`: create low-dimensional latent space representation of the image using text-conditional LDM\n", "- Decoder stage `decoder`: using representation from Prior Stage, produce a latent image in latent space of higher dimensionality using LDM and using VQGAN-decoder, decode the latent image to yield a full-resolution output image." ] }, { "cell_type": "markdown", "id": "2ae75ecb-5909-4179-adf0-6a4317a0df7f", "metadata": {}, "source": [ "Let's define the conversion function for PyTorch modules. We use `ov.convert_model` function to obtain OpenVINO Intermediate Representation object and `ov.save_model` function to save it as XML file. We use `nncf.compress_weights` to [compress model weights](https://docs.openvino.ai/2024/openvino-workflow/model-optimization-guide/weight-compression.html#compress-model-weights) to 8-bit to reduce model size." ] }, { "cell_type": "code", "execution_count": 5, "id": "bd09fb42-ac00-43bc-b0fd-9cb7aea2a2a5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:nncf:NNCF initialized successfully. Supported frameworks detected: torch, openvino\n" ] } ], "source": [ "import gc\n", "from pathlib import Path\n", "\n", "import openvino as ov\n", "import nncf\n", "\n", "\n", "MODELS_DIR = Path(\"models\")\n", "\n", "\n", "def convert(model: torch.nn.Module, xml_path: str, example_input, input_shape=None):\n", " xml_path = Path(xml_path)\n", " if not xml_path.exists():\n", " model.eval()\n", " xml_path.parent.mkdir(parents=True, exist_ok=True)\n", " with torch.no_grad():\n", " if not input_shape:\n", " converted_model = ov.convert_model(model, example_input=example_input)\n", " else:\n", " converted_model = ov.convert_model(model, example_input=example_input, input=input_shape)\n", " converted_model = nncf.compress_weights(converted_model)\n", " ov.save_model(converted_model, xml_path)\n", " del converted_model\n", "\n", " # cleanup memory\n", " torch._C._jit_clear_class_registry()\n", " torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()\n", " torch.jit._state._clear_class_state()\n", "\n", " gc.collect()" ] }, { "cell_type": "markdown", "id": "4d6074b8-938c-414e-9ac5-665c78e44db0", "metadata": {}, "source": [ "### Prior pipeline\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "This pipeline consists of text encoder and prior diffusion model. From here, we always use fixed shapes in conversion by using an `input_shape` parameter to generate a less memory-demanding model." 
] }, { "cell_type": "code", "execution_count": 6, "id": "52b96d00-8f89-4e5a-96ba-f7b03fffa349", "metadata": {}, "outputs": [], "source": [ "PRIOR_TEXT_ENCODER_OV_PATH = MODELS_DIR / \"prior_text_encoder_model.xml\"\n", "\n", "prior.text_encoder.config.output_hidden_states = True\n", "\n", "\n", "class TextEncoderWrapper(torch.nn.Module):\n", " def __init__(self, text_encoder):\n", " super().__init__()\n", " self.text_encoder = text_encoder\n", "\n", " def forward(self, input_ids, attention_mask):\n", " outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True)\n", " return outputs[\"text_embeds\"], outputs[\"last_hidden_state\"], outputs[\"hidden_states\"]\n", "\n", "\n", "convert(\n", " TextEncoderWrapper(prior.text_encoder),\n", " PRIOR_TEXT_ENCODER_OV_PATH,\n", " example_input={\n", " \"input_ids\": torch.zeros(1, 77, dtype=torch.int32),\n", " \"attention_mask\": torch.zeros(1, 77),\n", " },\n", " input_shape={\"input_ids\": ((1, 77),), \"attention_mask\": ((1, 77),)},\n", ")\n", "del prior.text_encoder\n", "gc.collect();" ] }, { "cell_type": "code", "execution_count": 7, "id": "4949f675-4154-4d19-9c9a-450bea5fd3c0", "metadata": {}, "outputs": [], "source": [ "PRIOR_PRIOR_MODEL_OV_PATH = MODELS_DIR / \"prior_prior_model.xml\"\n", "\n", "convert(\n", " prior.prior,\n", " PRIOR_PRIOR_MODEL_OV_PATH,\n", " example_input={\n", " \"sample\": torch.zeros(2, 16, 24, 24),\n", " \"timestep_ratio\": torch.ones(2),\n", " \"clip_text_pooled\": torch.zeros(2, 1, 1280),\n", " \"clip_text\": torch.zeros(2, 77, 1280),\n", " \"clip_img\": torch.zeros(2, 1, 768),\n", " },\n", " input_shape=[((-1, 16, 24, 24),), ((-1),), ((-1, 1, 1280),), ((-1, 77, 1280),), (-1, 1, 768)],\n", ")\n", "del prior.prior\n", "gc.collect();" ] }, { "cell_type": "markdown", "id": "07484fbe-7973-4474-a223-319041a06b6e", "metadata": {}, "source": [ "### Decoder pipeline\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Decoder pipeline consists of 3 parts: decoder, text encoder and VQGAN." 
] }, { "cell_type": "code", "execution_count": 8, "id": "a0586c95-197b-4985-8201-1d96317843cb", "metadata": {}, "outputs": [], "source": [ "DECODER_TEXT_ENCODER_MODEL_OV_PATH = MODELS_DIR / \"decoder_text_encoder_model.xml\"\n", "\n", "convert(\n", " TextEncoderWrapper(decoder.text_encoder),\n", " DECODER_TEXT_ENCODER_MODEL_OV_PATH,\n", " example_input={\n", " \"input_ids\": torch.zeros(1, 77, dtype=torch.int32),\n", " \"attention_mask\": torch.zeros(1, 77),\n", " },\n", " input_shape={\"input_ids\": ((1, 77),), \"attention_mask\": ((1, 77),)},\n", ")\n", "\n", "del decoder.text_encoder\n", "gc.collect();" ] }, { "cell_type": "code", "execution_count": 9, "id": "d1393bd1-97e2-4f78-9602-82d473b6fabf", "metadata": {}, "outputs": [], "source": [ "DECODER_DECODER_MODEL_OV_PATH = MODELS_DIR / \"decoder_decoder_model.xml\"\n", "\n", "convert(\n", " decoder.decoder,\n", " DECODER_DECODER_MODEL_OV_PATH,\n", " example_input={\n", " \"sample\": torch.zeros(1, 4, 256, 256),\n", " \"timestep_ratio\": torch.ones(1),\n", " \"clip_text_pooled\": torch.zeros(1, 1, 1280),\n", " \"effnet\": torch.zeros(1, 16, 24, 24),\n", " },\n", " input_shape=[((-1, 4, 256, 256),), ((-1),), ((-1, 1, 1280),), ((-1, 16, 24, 24),)],\n", ")\n", "del decoder.decoder\n", "gc.collect();" ] }, { "cell_type": "code", "execution_count": 10, "id": "3b2c7a98-0731-4552-a012-39434d25c268", "metadata": {}, "outputs": [], "source": [ "VQGAN_PATH = MODELS_DIR / \"vqgan_model.xml\"\n", "\n", "\n", "class VqganDecoderWrapper(torch.nn.Module):\n", " def __init__(self, vqgan):\n", " super().__init__()\n", " self.vqgan = vqgan\n", "\n", " def forward(self, h):\n", " return self.vqgan.decode(h)\n", "\n", "\n", "convert(\n", " VqganDecoderWrapper(decoder.vqgan),\n", " VQGAN_PATH,\n", " example_input=torch.zeros(1, 4, 256, 256),\n", " input_shape=(1, 4, 256, 256),\n", ")\n", "del decoder.vqgan\n", "gc.collect();" ] }, { "cell_type": "markdown", "id": "d63a662d-8dac-4698-97d3-80d7616978bb", "metadata": {}, "source": [ "## Select inference device\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Select device from dropdown list for running inference using OpenVINO." ] }, { "cell_type": "code", "execution_count": 11, "id": "72eaf0d9-d7e6-4867-9c9f-35959d7df6d4", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "121601987e6946fd8de4732b8df36915", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Dropdown(description='Device:', index=4, options=('CPU', 'GPU.0', 'GPU.1', 'GPU.2', 'AUTO'), value='AUTO')" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "core = ov.Core()\n", "device = widgets.Dropdown(\n", " options=core.available_devices + [\"AUTO\"],\n", " value=\"AUTO\",\n", " description=\"Device:\",\n", " disabled=False,\n", ")\n", "\n", "device" ] }, { "cell_type": "markdown", "id": "a682f257-4d5e-473b-ba43-8aba1be04af8", "metadata": {}, "source": [ "## Building the pipeline\n", "[back to top ⬆️](#Table-of-contents:)\n", "\n", "Let's create callable wrapper classes for compiled models to allow interaction with original pipelines. Note that all of wrapper classes return `torch.Tensor`s instead of `np.array`s." 
] }, { "cell_type": "code", "execution_count": 12, "id": "9dfac40c-12a4-4849-89b4-56fe2e3be049", "metadata": {}, "outputs": [], "source": [ "from collections import namedtuple\n", "\n", "\n", "BaseModelOutputWithPooling = namedtuple(\"BaseModelOutputWithPooling\", [\"text_embeds\", \"last_hidden_state\", \"hidden_states\"])\n", "\n", "\n", "class TextEncoderWrapper:\n", " dtype = torch.float32 # accessed in the original workflow\n", "\n", " def __init__(self, text_encoder_path, device):\n", " self.text_encoder = core.compile_model(text_encoder_path, device.value)\n", "\n", " def __call__(self, input_ids, attention_mask, output_hidden_states=True):\n", " output = self.text_encoder({\"input_ids\": input_ids, \"attention_mask\": attention_mask})\n", " text_embeds = output[0]\n", " last_hidden_state = output[1]\n", " hidden_states = list(output.values())[1:]\n", " return BaseModelOutputWithPooling(torch.from_numpy(text_embeds), torch.from_numpy(last_hidden_state), [torch.from_numpy(hs) for hs in hidden_states])" ] }, { "cell_type": "code", "execution_count": 13, "id": "271b7d3a-6517-48d5-b04f-9104547cf9b8", "metadata": {}, "outputs": [], "source": [ "class PriorPriorWrapper:\n", " def __init__(self, prior_path, device):\n", " self.prior = core.compile_model(prior_path, device.value)\n", " self.config = namedtuple(\"PriorWrapperConfig\", [\"clip_image_in_channels\", \"in_channels\"])(768, 16) # accessed in the original workflow\n", " self.parameters = lambda: (torch.zeros(i, dtype=torch.float32) for i in range(1)) # accessed in the original workflow\n", "\n", " def __call__(self, sample, timestep_ratio, clip_text_pooled, clip_text=None, clip_img=None, **kwargs):\n", " inputs = {\n", " \"sample\": sample,\n", " \"timestep_ratio\": timestep_ratio,\n", " \"clip_text_pooled\": clip_text_pooled,\n", " \"clip_text\": clip_text,\n", " \"clip_img\": clip_img,\n", " }\n", " output = self.prior(inputs)\n", " return [torch.from_numpy(output[0])]" ] }, { "cell_type": "code", "execution_count": 14, "id": "4908718f-10f9-4c57-bce2-32c100b00d04", "metadata": {}, "outputs": [], "source": [ "class DecoderWrapper:\n", " dtype = torch.float32 # accessed in the original workflow\n", "\n", " def __init__(self, decoder_path, device):\n", " self.decoder = core.compile_model(decoder_path, device.value)\n", "\n", " def __call__(self, sample, timestep_ratio, clip_text_pooled, effnet, **kwargs):\n", " inputs = {\"sample\": sample, \"timestep_ratio\": timestep_ratio, \"clip_text_pooled\": clip_text_pooled, \"effnet\": effnet}\n", " output = self.decoder(inputs)\n", " return [torch.from_numpy(output[0])]" ] }, { "cell_type": "code", "execution_count": 15, "id": "e5bf244b-b42e-43ff-a259-6ce7188dcb62", "metadata": {}, "outputs": [], "source": [ "VqganOutput = namedtuple(\"VqganOutput\", \"sample\")\n", "\n", "\n", "class VqganWrapper:\n", " config = namedtuple(\"VqganWrapperConfig\", \"scale_factor\")(0.3764) # accessed in the original workflow\n", "\n", " def __init__(self, vqgan_path, device):\n", " self.vqgan = core.compile_model(vqgan_path, device.value)\n", "\n", " def decode(self, h):\n", " output = self.vqgan(h)[0]\n", " output = torch.tensor(output)\n", " return VqganOutput(output)" ] }, { "cell_type": "markdown", "id": "b3be2d81-c2dd-42f8-9e11-3e7de76e0633", "metadata": {}, "source": [ "And insert wrappers instances in the pipeline:" ] }, { "cell_type": "code", "execution_count": 16, "id": "5c101a11-bf2c-4ad5-8f84-5243dd27adcd", "metadata": {}, "outputs": [], "source": [ "prior.text_encoder = 
TextEncoderWrapper(PRIOR_TEXT_ENCODER_OV_PATH, device)\n", "prior.prior = PriorPriorWrapper(PRIOR_PRIOR_MODEL_OV_PATH, device)\n", "decoder.decoder = DecoderWrapper(DECODER_DECODER_MODEL_OV_PATH, device)\n", "decoder.text_encoder = TextEncoderWrapper(DECODER_TEXT_ENCODER_MODEL_OV_PATH, device)\n", "decoder.vqgan = VqganWrapper(VQGAN_PATH, device)" ] }, { "cell_type": "markdown", "id": "e20e400b-1ff9-4112-b2fc-b5294138e460", "metadata": {}, "source": [ "## Inference\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 17, "id": "cdf74e18-6a04-462a-89ec-265f9d70ed12", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d52472a68e4c4080a9a3ef9c0bfd4089", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/20 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "prior_output = prior(\n", " prompt=prompt,\n", " height=1024,\n", " width=1024,\n", " negative_prompt=negative_prompt,\n", " guidance_scale=4.0,\n", " num_images_per_prompt=1,\n", " num_inference_steps=20,\n", ")\n", "\n", "decoder_output = decoder(\n", " image_embeddings=prior_output.image_embeddings,\n", " prompt=prompt,\n", " negative_prompt=negative_prompt,\n", " guidance_scale=0.0,\n", " output_type=\"pil\",\n", " num_inference_steps=10,\n", ").images[0]\n", "display(decoder_output)" ] }, { "cell_type": "markdown", "id": "368863cf-fc61-4b8f-9ea5-cb5216916002", "metadata": {}, "source": [ "## Interactive inference\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": 20, "id": "db61d980-d388-4f8d-9fa4-ef71315cdf7c", "metadata": {}, "outputs": [], "source": [ "def generate(prompt, negative_prompt, prior_guidance_scale, decoder_guidance_scale, seed):\n", " generator = torch.Generator().manual_seed(seed)\n", " prior_output = prior(\n", " prompt=prompt,\n", " height=1024,\n", " width=1024,\n", " negative_prompt=negative_prompt,\n", " guidance_scale=prior_guidance_scale,\n", " num_images_per_prompt=1,\n", " num_inference_steps=20,\n", " generator=generator,\n", " )\n", "\n", " decoder_output = decoder(\n", " image_embeddings=prior_output.image_embeddings,\n", " prompt=prompt,\n", " negative_prompt=negative_prompt,\n", " guidance_scale=decoder_guidance_scale,\n", " output_type=\"pil\",\n", " num_inference_steps=10,\n", " generator=generator,\n", " ).images[0]\n", "\n", " return decoder_output" ] }, { "cell_type": "code", "execution_count": null, "id": "b1462649-0396-48b7-916d-950d1254f4c2", "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "import numpy as np\n", "\n", "\n", "demo = gr.Interface(\n", " generate,\n", " [\n", " gr.Textbox(label=\"Prompt\"),\n", " gr.Textbox(label=\"Negative prompt\"),\n", " gr.Slider(\n", " 0,\n", " 20,\n", " step=1,\n", " label=\"Prior guidance scale\",\n", " info=\"Higher guidance scale encourages to generate images that are closely \"\n", " \"linked to the text `prompt`, usually at the expense of lower image quality. Applies to the prior pipeline\",\n", " ),\n", " gr.Slider(\n", " 0,\n", " 20,\n", " step=1,\n", " label=\"Decoder guidance scale\",\n", " info=\"Higher guidance scale encourages to generate images that are closely \"\n", " \"linked to the text `prompt`, usually at the expense of lower image quality. 
Applies to the decoder pipeline\",\n", " ),\n", " gr.Slider(0, np.iinfo(np.int32).max, label=\"Seed\", step=1),\n", " ],\n", " \"image\",\n", " examples=[[\"An image of a shiba inu, donning a spacesuit and helmet\", \"\", 4, 0, 0], [\"An armchair in the shape of an avocado\", \"\", 4, 0, 0]],\n", " allow_flagging=\"never\",\n", ")\n", "try:\n", " demo.queue().launch(debug=True)\n", "except Exception:\n", " demo.queue().launch(debug=True, share=True)\n", "# if you are launching remotely, specify server_name and server_port\n", "# demo.launch(server_name='your server name', server_port='server port in int')\n", "# Read more in the docs: https://gradio.app/docs/" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "openvino_notebooks": { "imageUrl": "https://huggingface.co/stabilityai/stable-cascade/resolve/main/figures/collage_1.jpg", "tags": { "categories": [ "Model Demos", "AI Trends" ], "libraries": [], "other": [], "tasks": [ "Text-to-Image" ] } } }, "nbformat": 4, "nbformat_minor": 5 }