{ "cells": [ { "attachments": {}, "cell_type": "markdown", "id": "416bb194-5b2f-406d-a55c-d71dac6e93e1", "metadata": {}, "source": [ "# TripoSR feedforward 3D reconstruction from a single image and OpenVINO\n", "\n", "[TripoSR](https://huggingface.co/spaces/stabilityai/TripoSR) is a state-of-the-art open-source model for fast feedforward 3D reconstruction from a single image, developed in collaboration between [Tripo AI](https://www.tripo3d.ai/) and [Stability AI](https://stability.ai/news/triposr-3d-generation).\n", "\n", "You can find [the source code on GitHub](https://github.com/VAST-AI-Research/TripoSR) and [demo on HuggingFace](https://huggingface.co/spaces/stabilityai/TripoSR). Also, you can read the paper [TripoSR: Fast 3D Object Reconstruction from a Single Image](https://arxiv.org/abs/2403.02151).\n", "\n", "![Teaser Video](https://raw.githubusercontent.com/VAST-AI-Research/TripoSR/main/figures/teaser800.gif)\n", "\n", "#### Table of contents:\n", "\n", "- [Prerequisites](#Prerequisites)\n", "- [Get the original model](#Get-the-original-model)\n", "- [Convert the model to OpenVINO IR](#Convert-the-model-to-OpenVINO-IR)\n", "- [Compiling models and prepare pipeline](#Compiling-models-and-prepare-pipeline)\n", "- [Interactive inference](#Interactive-inference)\n" ] }, { "attachments": {}, "cell_type": "markdown", "id": "b32875d0-9935-42df-ba9b-c60cb31d9eff", "metadata": {}, "source": [ "## Prerequisites\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": null, "id": "6cb7f8ed-aa2d-4124-a77d-42ed3324ccfd", "metadata": { "collapsed": false }, "outputs": [], "source": [ "%pip install -q wheel setuptools pip --upgrade\n", "%pip install -q \"gradio>=4.19\" \"torch==2.2.2\" rembg trimesh einops \"omegaconf>=2.3.0\" \"transformers>=4.35.0\" \"openvino>=2024.0.0\" --extra-index-url https://download.pytorch.org/whl/cpu\n", "%pip install -q \"git+https://github.com/tatsy/torchmcubes.git\"" ] }, { "cell_type": "code", "execution_count": null, "id": "b78424b7-80e4-4470-9427-7286b9837566", "metadata": {}, "outputs": [], "source": [ "import sys\n", "from pathlib import Path\n", "\n", "if not Path(\"TripoSR\").exists():\n", " !git clone https://huggingface.co/spaces/stabilityai/TripoSR\n", "\n", "sys.path.append(\"TripoSR\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "92e3f21e-9a4f-4fe4-aa16-016b4a118d5f", "metadata": {}, "source": [ "## Get the original model" ] }, { "cell_type": "code", "execution_count": null, "id": "b7c990fa-0757-480c-8c42-d0abc50edc8e", "metadata": { "scrolled": true }, "outputs": [], "source": [ "import os\n", "\n", "from tsr.system import TSR\n", "\n", "\n", "model = TSR.from_pretrained(\n", " \"stabilityai/TripoSR\",\n", " config_name=\"config.yaml\",\n", " weight_name=\"model.ckpt\",\n", ")\n", "model.renderer.set_chunk_size(131072)\n", "model.to(\"cpu\")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f451a152-867c-46d0-b83f-475929ab0c92", "metadata": {}, "source": [ "### Convert the model to OpenVINO IR\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "cdaa3200-f293-47d3-a7ab-b631e66e42dc", "metadata": {}, "source": [ "Define the conversion function for PyTorch modules. We use `ov.convert_model` function to obtain OpenVINO Intermediate Representation object and `ov.save_model` function to save it as XML file." 
] }, { "cell_type": "code", "execution_count": null, "id": "d37dbcee-8c40-4cf6-a1ed-5b4c2358f4d3", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "import openvino as ov\n", "\n", "\n", "def convert(model: torch.nn.Module, xml_path: str, example_input):\n", " xml_path = Path(xml_path)\n", " if not xml_path.exists():\n", " xml_path.parent.mkdir(parents=True, exist_ok=True)\n", " with torch.no_grad():\n", " converted_model = ov.convert_model(model, example_input=example_input)\n", " ov.save_model(converted_model, xml_path, compress_to_fp16=False)\n", "\n", " # cleanup memory\n", " torch._C._jit_clear_class_registry()\n", " torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()\n", " torch.jit._state._clear_class_state()" ] }, { "attachments": {}, "cell_type": "markdown", "id": "51d3ad7f-2c16-4967-811d-d68f788faa7c", "metadata": {}, "source": [ "The original model is a pipeline of several models. There are `image_tokenizer`, `tokenizer`, `backbone` and `post_processor`. `image_tokenizer` contains `ViTModel` that consists of `ViTPatchEmbeddings`, `ViTEncoder` and `ViTPooler`. `tokenizer` is `Triplane1DTokenizer`, `backbone` is `Transformer1D`, `post_processor` is `TriplaneUpsampleNetwork`. Convert all internal models one by one. " ] }, { "cell_type": "code", "execution_count": null, "id": "d9456c3c-64dd-4109-968c-8f51d42b9876", "metadata": {}, "outputs": [], "source": [ "VIT_PATCH_EMBEDDINGS_OV_PATH = Path(\"models/vit_patch_embeddings_ir.xml\")\n", "\n", "\n", "class PatchEmbedingWrapper(torch.nn.Module):\n", " def __init__(self, patch_embeddings):\n", " super().__init__()\n", " self.patch_embeddings = patch_embeddings\n", "\n", " def forward(self, pixel_values, interpolate_pos_encoding=True):\n", " outputs = self.patch_embeddings(pixel_values=pixel_values, interpolate_pos_encoding=True)\n", " return outputs\n", "\n", "\n", "example_input = {\n", " \"pixel_values\": torch.rand([1, 3, 512, 512], dtype=torch.float32),\n", "}\n", "\n", "convert(\n", " PatchEmbedingWrapper(model.image_tokenizer.model.embeddings.patch_embeddings),\n", " VIT_PATCH_EMBEDDINGS_OV_PATH,\n", " example_input,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "f31d0b34-9690-454c-b497-f2330d18bac4", "metadata": {}, "outputs": [], "source": [ "VIT_ENCODER_OV_PATH = Path(\"models/vit_encoder_ir.xml\")\n", "\n", "\n", "class EncoderWrapper(torch.nn.Module):\n", " def __init__(self, encoder):\n", " super().__init__()\n", " self.encoder = encoder\n", "\n", " def forward(\n", " self,\n", " hidden_states=None,\n", " head_mask=None,\n", " output_attentions=False,\n", " output_hidden_states=False,\n", " return_dict=False,\n", " ):\n", " outputs = self.encoder(\n", " hidden_states=hidden_states,\n", " )\n", "\n", " return outputs.last_hidden_state\n", "\n", "\n", "example_input = {\n", " \"hidden_states\": torch.rand([1, 1025, 768], dtype=torch.float32),\n", "}\n", "\n", "convert(\n", " EncoderWrapper(model.image_tokenizer.model.encoder),\n", " VIT_ENCODER_OV_PATH,\n", " example_input,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "07815ea4-f851-4c13-99b9-1b905ef998cf", "metadata": {}, "outputs": [], "source": [ "VIT_POOLER_OV_PATH = Path(\"models/vit_pooler_ir.xml\")\n", "convert(\n", " model.image_tokenizer.model.pooler,\n", " VIT_POOLER_OV_PATH,\n", " torch.rand([1, 1025, 768], dtype=torch.float32),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "465dc1db-5ce1-428f-bc90-740aa6d0dca5", "metadata": {}, "outputs": [], "source": 
[ "TOKENIZER_OV_PATH = Path(\"models/tokenizer_ir.xml\")\n", "convert(model.tokenizer, TOKENIZER_OV_PATH, torch.tensor(1))" ] }, { "cell_type": "code", "execution_count": null, "id": "4b7cdbc5-ed39-49b4-922d-95603478209f", "metadata": {}, "outputs": [], "source": [ "example_input = {\n", " \"hidden_states\": torch.rand([1, 1024, 3072], dtype=torch.float32),\n", " \"encoder_hidden_states\": torch.rand([1, 1025, 768], dtype=torch.float32),\n", "}\n", "\n", "BACKBONE_OV_PATH = Path(\"models/backbone_ir.xml\")\n", "convert(model.backbone, BACKBONE_OV_PATH, example_input)" ] }, { "cell_type": "code", "execution_count": null, "id": "b209191b-5566-48ad-adb4-0ceee5d4e469", "metadata": {}, "outputs": [], "source": [ "POST_PROCESSOR_OV_PATH = Path(\"models/post_processor_ir.xml\")\n", "convert(\n", " model.post_processor,\n", " POST_PROCESSOR_OV_PATH,\n", " torch.rand([1, 3, 1024, 32, 32], dtype=torch.float32),\n", ")" ] }, { "attachments": {}, "cell_type": "markdown", "id": "70b16770-acb7-4268-b135-4e3c0d7b1dd5", "metadata": {}, "source": [ "## Compiling models and prepare pipeline\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "221d88fe-7a91-4c39-a7e8-9e1f6eee0bc4", "metadata": {}, "source": [ "Select device from dropdown list for running inference using OpenVINO." ] }, { "cell_type": "code", "execution_count": null, "id": "9324143f-02b5-4238-a173-4a649d9ab1d5", "metadata": {}, "outputs": [], "source": [ "import ipywidgets as widgets\n", "\n", "\n", "core = ov.Core()\n", "device = widgets.Dropdown(\n", " options=core.available_devices + [\"AUTO\"],\n", " value=\"AUTO\",\n", " description=\"Device:\",\n", " disabled=False,\n", ")\n", "\n", "device" ] }, { "cell_type": "code", "execution_count": null, "id": "379a2cb2-6086-47d9-ae83-be7f74a5c14f", "metadata": {}, "outputs": [], "source": [ "compiled_vit_patch_embeddings = core.compile_model(VIT_PATCH_EMBEDDINGS_OV_PATH, device.value)\n", "compiled_vit_model_encoder = core.compile_model(VIT_ENCODER_OV_PATH, device.value)\n", "compiled_vit_model_pooler = core.compile_model(VIT_POOLER_OV_PATH, device.value)\n", "\n", "compiled_tokenizer = core.compile_model(TOKENIZER_OV_PATH, device.value)\n", "compiled_backbone = core.compile_model(BACKBONE_OV_PATH, device.value)\n", "compiled_post_processor = core.compile_model(POST_PROCESSOR_OV_PATH, device.value)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "beacf131-aa84-4a0b-9d62-4df35d9fe5a4", "metadata": {}, "source": [ "Let's create callable wrapper classes for compiled models to allow interaction with original `TSR` class. Note that all of wrapper classes return `torch.Tensor`s instead of `np.array`s." 
] }, { "cell_type": "code", "execution_count": null, "id": "7f97d859-4e2d-4277-8816-7ca8d6c76778", "metadata": {}, "outputs": [], "source": [ "from collections import namedtuple\n", "\n", "\n", "class VitPatchEmdeddingsWrapper(torch.nn.Module):\n", " def __init__(self, vit_patch_embeddings, model):\n", " super().__init__()\n", " self.vit_patch_embeddings = vit_patch_embeddings\n", " self.projection = model.projection\n", "\n", " def forward(self, pixel_values, interpolate_pos_encoding=False):\n", " inputs = {\n", " \"pixel_values\": pixel_values,\n", " }\n", " outs = self.vit_patch_embeddings(inputs)[0]\n", "\n", " return torch.from_numpy(outs)\n", "\n", "\n", "class VitModelEncoderWrapper(torch.nn.Module):\n", " def __init__(self, vit_model_encoder):\n", " super().__init__()\n", " self.vit_model_encoder = vit_model_encoder\n", "\n", " def forward(\n", " self,\n", " hidden_states,\n", " head_mask,\n", " output_attentions=False,\n", " output_hidden_states=False,\n", " return_dict=False,\n", " ):\n", " inputs = {\n", " \"hidden_states\": hidden_states.detach().numpy(),\n", " }\n", "\n", " outs = self.vit_model_encoder(inputs)\n", " outputs = namedtuple(\"BaseModelOutput\", (\"last_hidden_state\", \"hidden_states\", \"attentions\"))\n", "\n", " return outputs(torch.from_numpy(outs[0]), None, None)\n", "\n", "\n", "class VitModelPoolerWrapper(torch.nn.Module):\n", " def __init__(self, vit_model_pooler):\n", " super().__init__()\n", " self.vit_model_pooler = vit_model_pooler\n", "\n", " def forward(self, hidden_states):\n", " outs = self.vit_model_pooler(hidden_states.detach().numpy())[0]\n", "\n", " return torch.from_numpy(outs)\n", "\n", "\n", "class TokenizerWrapper(torch.nn.Module):\n", " def __init__(self, tokenizer, model):\n", " super().__init__()\n", " self.tokenizer = tokenizer\n", " self.detokenize = model.detokenize\n", "\n", " def forward(self, batch_size):\n", " outs = self.tokenizer(batch_size)[0]\n", "\n", " return torch.from_numpy(outs)\n", "\n", "\n", "class BackboneWrapper(torch.nn.Module):\n", " def __init__(self, backbone):\n", " super().__init__()\n", " self.backbone = backbone\n", "\n", " def forward(self, hidden_states, encoder_hidden_states):\n", " inputs = {\n", " \"hidden_states\": hidden_states,\n", " \"encoder_hidden_states\": encoder_hidden_states.detach().numpy(),\n", " }\n", "\n", " outs = self.backbone(inputs)[0]\n", "\n", " return torch.from_numpy(outs)\n", "\n", "\n", "class PostProcessorWrapper(torch.nn.Module):\n", " def __init__(self, post_processor):\n", " super().__init__()\n", " self.post_processor = post_processor\n", "\n", " def forward(self, triplanes):\n", " outs = self.post_processor(triplanes)[0]\n", "\n", " return torch.from_numpy(outs)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "6e18c61a-5a47-4ccb-a7dc-5b82caa1bf7d", "metadata": {}, "source": [ "Replace all models in the original model by wrappers instances:" ] }, { "cell_type": "code", "execution_count": null, "id": "04ada5f4-6a4a-418d-8227-379b76e3faf1", "metadata": {}, "outputs": [], "source": [ "model.image_tokenizer.model.embeddings.patch_embeddings = VitPatchEmdeddingsWrapper(\n", " compiled_vit_patch_embeddings,\n", " model.image_tokenizer.model.embeddings.patch_embeddings,\n", ")\n", "model.image_tokenizer.model.encoder = VitModelEncoderWrapper(compiled_vit_model_encoder)\n", "model.image_tokenizer.model.pooler = VitModelPoolerWrapper(compiled_vit_model_pooler)\n", "\n", "model.tokenizer = TokenizerWrapper(compiled_tokenizer, model.tokenizer)\n", "model.backbone = 
BackboneWrapper(compiled_backbone)\n", "model.post_processor = PostProcessorWrapper(compiled_post_processor)" ] }, { "attachments": {}, "cell_type": "markdown", "id": "f032d9d5-5ee7-4fad-b1b8-e6b494bf840e", "metadata": {}, "source": [ "## Interactive inference\n", "[back to top ⬆️](#Table-of-contents:)" ] }, { "cell_type": "code", "execution_count": null, "id": "044bc9cf-3141-433d-a267-f5da84c9aa11", "metadata": {}, "outputs": [], "source": [ "import tempfile\n", "\n", "import gradio as gr\n", "import numpy as np\n", "import rembg\n", "from PIL import Image\n", "\n", "from tsr.utils import remove_background, resize_foreground, to_gradio_3d_orientation\n", "\n", "\n", "rembg_session = rembg.new_session()\n", "\n", "\n", "def check_input_image(input_image):\n", " if input_image is None:\n", " raise gr.Error(\"No image uploaded!\")\n", "\n", "\n", "def preprocess(input_image, do_remove_background, foreground_ratio):\n", " def fill_background(image):\n", " image = np.array(image).astype(np.float32) / 255.0\n", " image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5\n", " image = Image.fromarray((image * 255.0).astype(np.uint8))\n", " return image\n", "\n", " if do_remove_background:\n", " image = input_image.convert(\"RGB\")\n", " image = remove_background(image, rembg_session)\n", " image = resize_foreground(image, foreground_ratio)\n", " image = fill_background(image)\n", " else:\n", " image = input_image\n", " if image.mode == \"RGBA\":\n", " image = fill_background(image)\n", " return image\n", "\n", "\n", "def generate(image):\n", " scene_codes = model(image, \"cpu\") # the device is provided for the image processor\n", " mesh = model.extract_mesh(scene_codes)[0]\n", " mesh = to_gradio_3d_orientation(mesh)\n", " mesh_path = tempfile.NamedTemporaryFile(suffix=\".obj\", delete=False)\n", " mesh.export(mesh_path.name)\n", " return mesh_path.name\n", "\n", "\n", "with gr.Blocks() as demo:\n", " with gr.Row(variant=\"panel\"):\n", " with gr.Column():\n", " with gr.Row():\n", " input_image = gr.Image(\n", " label=\"Input Image\",\n", " image_mode=\"RGBA\",\n", " sources=\"upload\",\n", " type=\"pil\",\n", " elem_id=\"content_image\",\n", " )\n", " processed_image = gr.Image(label=\"Processed Image\", interactive=False)\n", " with gr.Row():\n", " with gr.Group():\n", " do_remove_background = gr.Checkbox(label=\"Remove Background\", value=True)\n", " foreground_ratio = gr.Slider(\n", " label=\"Foreground Ratio\",\n", " minimum=0.5,\n", " maximum=1.0,\n", " value=0.85,\n", " step=0.05,\n", " )\n", " with gr.Row():\n", " submit = gr.Button(\"Generate\", elem_id=\"generate\", variant=\"primary\")\n", " with gr.Column():\n", " with gr.Tab(\"Model\"):\n", " output_model = gr.Model3D(\n", " label=\"Output Model\",\n", " interactive=False,\n", " )\n", " with gr.Row(variant=\"panel\"):\n", " gr.Examples(\n", " examples=[os.path.join(\"TripoSR/examples\", img_name) for img_name in sorted(os.listdir(\"TripoSR/examples\"))],\n", " inputs=[input_image],\n", " outputs=[processed_image, output_model],\n", " label=\"Examples\",\n", " examples_per_page=20,\n", " )\n", " submit.click(fn=check_input_image, inputs=[input_image]).success(\n", " fn=preprocess,\n", " inputs=[input_image, do_remove_background, foreground_ratio],\n", " outputs=[processed_image],\n", " ).success(\n", " fn=generate,\n", " inputs=[processed_image],\n", " outputs=[output_model],\n", " )\n", "\n", "try:\n", " demo.launch(debug=True, height=680)\n", "except Exception:\n", " demo.queue().launch(share=True, debug=True, 
height=680)\n", "# if you are launching remotely, specify server_name and server_port\n", "# demo.launch(server_name='your server name', server_port='server port in int')\n", "# Read more in the docs: https://gradio.app/docs/" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" }, "openvino_notebooks": { "imageUrl": "https://github.com/VAST-AI-Research/TripoSR/blob/main/figures/teaser800.gif?raw=true", "tags": { "categories": [ "Model Demos", "AI Trends" ], "libraries": [], "other": [], "tasks": [ "Image-to-3D" ] } } }, "nbformat": 4, "nbformat_minor": 5 }