diff --git "a/DenseAV/demo.ipynb" "b/DenseAV/demo.ipynb" new file mode 100644--- /dev/null +++ "b/DenseAV/demo.ipynb" @@ -0,0 +1,757 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# DenseAV Demonstration Notebook\n", + "\n", + "> ⚠️ Change your collab runtime to T4 GPU before running this notebook\n", + "\n", + "In this notebook we will walk through how to load, visualize, and work with our catalog of pre-trained models." + ], + "metadata": { + "collapsed": false, + "id": "c413e5bb192c72eb" + }, + "id": "c413e5bb192c72eb" + }, + { + "cell_type": "markdown", + "source": [ + "## Set up Google Collab\n", + "> ⚠️ Skip this section if you are not on Google Collab\n" + ], + "metadata": { + "collapsed": false, + "id": "7c65e267ad0b57b2" + }, + "id": "7c65e267ad0b57b2" + }, + { + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "fatal: destination path 'DenseAV' already exists and is not an empty directory.\n" + ] + } + ], + "source": [ + "!git clone https://github.com/mhamilton723/DenseAV" + ], + "metadata": { + "id": "8e0c798342f1699", + "outputId": "a04482a0-f368-48b3-8645-85d5602a7bec", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "id": "8e0c798342f1699", + "execution_count": 1 + }, + { + "cell_type": "code", + "source": [ + "!pip install av" + ], + "metadata": { + "id": "wXbCdwNkk4zF", + "outputId": "089bdd3c-9501-461e-91ab-4b33e480637f", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "id": "wXbCdwNkk4zF", + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Requirement already satisfied: av in /usr/local/lib/python3.10/dist-packages (12.1.0)\n" + ] + } + ] + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "import os\n", + "os.chdir(\"DenseAV/\")" + ], + "metadata": { + "id": "397cf48fa3832a2b" + }, + "id": "397cf48fa3832a2b", + "execution_count": 3 + }, + { + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Obtaining file:///content/DenseAV\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (2.3.0+cu121)\n", + "Requirement already satisfied: kornia in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (0.7.2)\n", + "Requirement already satisfied: omegaconf in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (2.3.0)\n", + "Requirement already satisfied: pytorch-lightning in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (2.2.5)\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (0.18.0+cu121)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (4.66.4)\n", + "Requirement already satisfied: torchmetrics in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (1.4.0.post0)\n", + "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (1.2.2)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (1.25.2)\n", + "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (3.7.1)\n", + "Requirement already satisfied: timm==0.4.12 in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (0.4.12)\n", + "Requirement already satisfied: moviepy in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (1.0.3)\n", + "Requirement already satisfied: hydra-core in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (1.3.2)\n", + "Requirement already satisfied: peft==0.5.0 in /usr/local/lib/python3.10/dist-packages (from denseav==0.1.0) (0.5.0)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft==0.5.0->denseav==0.1.0) (24.0)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft==0.5.0->denseav==0.1.0) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft==0.5.0->denseav==0.1.0) (6.0.1)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft==0.5.0->denseav==0.1.0) (4.41.2)\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from peft==0.5.0->denseav==0.1.0) (0.30.1)\n", + "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft==0.5.0->denseav==0.1.0) (0.4.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (3.14.0)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (4.12.1)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (1.12.1)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (3.3)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (2023.6.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (2.20.5)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (12.1.105)\n", + "Requirement already satisfied: triton==2.3.0 in /usr/local/lib/python3.10/dist-packages (from torch->denseav==0.1.0) (2.3.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->denseav==0.1.0) (12.5.40)\n", + "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.10/dist-packages (from hydra-core->denseav==0.1.0) (4.9.3)\n", + "Requirement already satisfied: kornia-rs>=0.1.0 in /usr/local/lib/python3.10/dist-packages (from kornia->denseav==0.1.0) (0.1.3)\n", + "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->denseav==0.1.0) (1.2.1)\n", + "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->denseav==0.1.0) (0.12.1)\n", + "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->denseav==0.1.0) (4.53.0)\n", + "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->denseav==0.1.0) (1.4.5)\n", + "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->denseav==0.1.0) (9.4.0)\n", + "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->denseav==0.1.0) (3.1.2)\n", + "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->denseav==0.1.0) (2.8.2)\n", + "Requirement already satisfied: decorator<5.0,>=4.0.2 in /usr/local/lib/python3.10/dist-packages (from moviepy->denseav==0.1.0) (4.4.2)\n", + "Requirement already satisfied: requests<3.0,>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from moviepy->denseav==0.1.0) (2.31.0)\n", + "Requirement already satisfied: proglog<=1.0.0 in /usr/local/lib/python3.10/dist-packages (from moviepy->denseav==0.1.0) (0.1.10)\n", + "Requirement already satisfied: imageio<3.0,>=2.5 in /usr/local/lib/python3.10/dist-packages (from moviepy->denseav==0.1.0) (2.31.6)\n", + "Requirement already satisfied: imageio-ffmpeg>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from moviepy->denseav==0.1.0) (0.5.1)\n", + "Requirement already satisfied: lightning-utilities>=0.8.0 in /usr/local/lib/python3.10/dist-packages (from pytorch-lightning->denseav==0.1.0) (0.11.2)\n", + "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->denseav==0.1.0) (1.11.4)\n", + "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->denseav==0.1.0) (1.4.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->denseav==0.1.0) (3.5.0)\n", + "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /usr/local/lib/python3.10/dist-packages (from fsspec->torch->denseav==0.1.0) (3.9.5)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from imageio-ffmpeg>=0.2.0->moviepy->denseav==0.1.0) (67.7.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->denseav==0.1.0) (1.16.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3.0,>=2.8.1->moviepy->denseav==0.1.0) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3.0,>=2.8.1->moviepy->denseav==0.1.0) (3.7)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3.0,>=2.8.1->moviepy->denseav==0.1.0) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3.0,>=2.8.1->moviepy->denseav==0.1.0) (2024.6.2)\n", + "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (from accelerate->peft==0.5.0->denseav==0.1.0) (0.23.2)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch->denseav==0.1.0) (2.1.5)\n", + "Requirement already satisfied: mpmath<1.4.0,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->denseav==0.1.0) (1.3.0)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft==0.5.0->denseav==0.1.0) (2024.5.15)\n", + "Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers->peft==0.5.0->denseav==0.1.0) (0.19.1)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch->denseav==0.1.0) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch->denseav==0.1.0) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch->denseav==0.1.0) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch->denseav==0.1.0) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch->denseav==0.1.0) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec->torch->denseav==0.1.0) (4.0.3)\n", + "Installing collected packages: denseav\n", + " Attempting uninstall: denseav\n", + " Found existing installation: denseav 0.1.0\n", + " Uninstalling denseav-0.1.0:\n", + " Successfully uninstalled denseav-0.1.0\n", + " Running setup.py develop for denseav\n", + "Successfully installed denseav-0.1.0\n" + ] + } + ], + "source": [ + "!pip install -e ." + ], + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-06T17:03:20.413866Z", + "start_time": "2024-06-06T17:03:20.296186Z" + }, + "id": "19d3129b03459c94", + "outputId": "aaa2790c-8855-4be3-f993-9a6dbb7aa525", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "id": "19d3129b03459c94", + "execution_count": 4 + }, + { + "cell_type": "markdown", + "source": [ + "## Import dependencies and load a pretrained DenseAV Model\n" + ], + "metadata": { + "collapsed": false, + "id": "800b72c026c98194" + }, + "id": "800b72c026c98194" + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2024-06-06T17:07:27.801018Z", + "start_time": "2024-06-06T17:07:24.055483Z" + }, + "id": "initial_id" + }, + "outputs": [], + "source": [ + "from os.path import join\n", + "\n", + "import torch\n", + "import torchvision\n", + "import torchvision.transforms as T\n", + "from PIL import Image\n", + "from torchaudio.functional import resample\n", + "\n", + "from denseav.plotting import plot_attention_video, plot_2head_attention_video, plot_feature_video, display_video_in_notebook\n", + "from denseav.shared import norm, crop_to_divisor, blur_dim" + ] + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "model_name = \"sound_and_language\"\n", + "video_path = \"samples/puppies.mp4\"\n", + "result_dir = \"results\"\n", + "load_size = 224\n", + "plot_size = 224" + ], + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-06T17:07:27.806669Z", + "start_time": "2024-06-06T17:07:27.803188Z" + }, + "id": "e0de70a3865c7239" + }, + "id": "e0de70a3865c7239", + "execution_count": 6 + }, + { + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "Using cache found in /root/.cache/torch/hub/mhamilton723_DenseAV_main\n", + "INFO:pytorch_lightning.utilities.migration.utils:Lightning automatically upgraded your loaded checkpoint from v1.9.4 to v2.2.5. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint https:/marhamilresearch4.blob.core.windows.net/denseav-public/hub/denseav_2head.ckpt`\n", + "Using cache found in /root/.cache/torch/hub/facebookresearch_dino_main\n", + "WARNING:py.warnings:/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n", + "\n", + "Some weights of HubertModel were not initialized from the model checkpoint at facebook/hubert-large-ls960-ft and are newly initialized: ['hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "Some weights of HubertModel were not initialized from the model checkpoint at facebook/hubert-large-ls960-ft and are newly initialized: ['hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'hubert.encoder.pos_conv_embed.conv.parametrizations.weight.original1']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "trainable params: 147,456 || all params: 21,817,728 || trainable%: 0.6758540577644016\n" + ] + } + ], + "source": [ + "model = torch.hub.load('mhamilton723/DenseAV', model_name).cuda()" + ], + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-06T17:07:37.721422Z", + "start_time": "2024-06-06T17:07:27.808035Z" + }, + "id": "e35605083dbeeb1d", + "outputId": "8f573a06-33c2-40b5-870e-5125a9e30ac2", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "id": "e35605083dbeeb1d", + "execution_count": 7 + }, + { + "cell_type": "markdown", + "source": [ + "## Load a sample video and prepare it for DenseAV" + ], + "metadata": { + "collapsed": false, + "id": "742cfc52ee8d0aad" + }, + "id": "742cfc52ee8d0aad" + }, + { + "cell_type": "code", + "outputs": [], + "source": [ + "original_frames, audio, info = torchvision.io.read_video(video_path, pts_unit='sec')\n", + "sample_rate = 16000\n", + "\n", + "if info[\"audio_fps\"] != sample_rate:\n", + " audio = resample(audio, info[\"audio_fps\"], sample_rate)\n", + "audio = audio[0].unsqueeze(0)\n", + "\n", + "img_transform = T.Compose([\n", + " T.Resize(load_size, Image.BILINEAR),\n", + " lambda x: crop_to_divisor(x, 8),\n", + " lambda x: x.to(torch.float32) / 255,\n", + " norm])\n", + "\n", + "frames = torch.cat([img_transform(f.permute(2, 0, 1)).unsqueeze(0) for f in original_frames], axis=0)\n", + "\n", + "plotting_img_transform = T.Compose([\n", + " T.Resize(plot_size, Image.BILINEAR),\n", + " lambda x: crop_to_divisor(x, 8),\n", + " lambda x: x.to(torch.float32) / 255])\n", + "\n", + "frames_to_plot = plotting_img_transform(original_frames.permute(0, 3, 1, 2))" + ], + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-06T17:07:40.993012Z", + "start_time": "2024-06-06T17:07:37.724341Z" + }, + "id": "2d5b8553fc0e372" + }, + "id": "2d5b8553fc0e372", + "execution_count": 8 + }, + { + "cell_type": "markdown", + "source": [ + "## Use DenseAV to obtain dense AV-aligned features" + ], + "metadata": { + "collapsed": false, + "id": "203ebe0f66dde1" + }, + "id": "203ebe0f66dde1" + }, + { + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([181, 2, 14, 14, 33])\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " audio_feats = model.forward_audio({\"audio\": audio.cuda()})\n", + " audio_feats = {k: v.cpu() for k,v in audio_feats.items()}\n", + " image_feats = model.forward_image({\"frames\": frames.unsqueeze(0).cuda()}, max_batch_size=2)\n", + " image_feats = {k: v.cpu() for k,v in image_feats.items()}\n", + "\n", + "\n", + " sim_by_head = model.sim_agg.get_pairwise_sims(\n", + " {**image_feats, **audio_feats},\n", + " raw=False,\n", + " agg_sim=False,\n", + " agg_heads=False\n", + " ).mean(dim=-2).cpu()\n", + "\n", + " sim_by_head = blur_dim(sim_by_head, window=3, dim=-1)\n", + " print(sim_by_head.shape)" + ], + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-06T17:07:51.348730Z", + "start_time": "2024-06-06T17:07:40.995122Z" + }, + "id": "a26feec6533ad7ec", + "outputId": "c58def18-cb9b-4dc0-f101-1d375e8e5d10", + "colab": { + "base_uri": "https://localhost:8080/" + } + }, + "id": "a26feec6533ad7ec", + "execution_count": 9 + }, + { + "cell_type": "markdown", + "source": [ + "## Visualize Cross-Modal Attention" + ], + "metadata": { + "collapsed": false, + "id": "719b17171b1d9703" + }, + "id": "719b17171b1d9703" + }, + { + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Moviepy - Building video results/attention.mp4.\n", + "MoviePy - Writing audio in attentionTEMP_MPY_wvf_snd.mp3\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "MoviePy - Done.\n", + "Moviepy - Writing video results/attention.mp4\n", + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Moviepy - Done !\n", + "Moviepy - video ready results/attention.mp4\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + } + ], + "source": [ + "plot_attention_video(\n", + " sim_by_head,\n", + " frames_to_plot,\n", + " audio,\n", + " info[\"video_fps\"],\n", + " sample_rate,\n", + " \"results/attention.mp4\")\n", + "display_video_in_notebook(\"results/attention.mp4\")" + ], + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-06T17:08:03.781030Z", + "start_time": "2024-06-06T17:07:51.350768Z" + }, + "id": "99c46e5f3a50de3c", + "outputId": "04a4b558-2816-4fdb-c5fc-f060f1cd1057", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 543 + } + }, + "id": "99c46e5f3a50de3c", + "execution_count": 28 + }, + { + "cell_type": "markdown", + "source": [ + "## Visualize Cross Modal Attention by Head to Disentangle Sound and Language" + ], + "metadata": { + "collapsed": false, + "id": "a3f4d96cce322692" + }, + "id": "a3f4d96cce322692" + }, + { + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Moviepy - Building video results/2head_attention.mp4.\n", + "MoviePy - Writing audio in 2head_attentionTEMP_MPY_wvf_snd.mp3\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "MoviePy - Done.\n", + "Moviepy - Writing video results/2head_attention.mp4\n", + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Moviepy - Done !\n", + "Moviepy - video ready results/2head_attention.mp4\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + } + ], + "source": [ + "if model_name == \"sound_and_language\":\n", + " plot_2head_attention_video(\n", + " sim_by_head,\n", + " frames_to_plot,\n", + " audio,\n", + " info[\"video_fps\"],\n", + " sample_rate,\n", + " \"results/2head_attention.mp4\")\n", + " display_video_in_notebook(\"results/2head_attention.mp4\")" + ], + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-06T17:08:21.087052Z", + "start_time": "2024-06-06T17:08:03.782840Z" + }, + "id": "91d0eec42a35de9b", + "outputId": "bad4bf64-9258-49a7-e0d8-72e1f2c5d404", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 543 + } + }, + "id": "91d0eec42a35de9b", + "execution_count": 29 + }, + { + "cell_type": "markdown", + "source": [ + "## Plot Deep Features" + ], + "metadata": { + "collapsed": false, + "id": "9a886cfeaf91e0ec" + }, + "id": "9a886cfeaf91e0ec" + }, + { + "cell_type": "code", + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Moviepy - Building video results/visual_features.mp4.\n", + "MoviePy - Writing audio in visual_featuresTEMP_MPY_wvf_snd.mp3\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "MoviePy - Done.\n", + "Moviepy - Writing video results/visual_features.mp4\n", + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Moviepy - Done !\n", + "Moviepy - video ready results/visual_features.mp4\n", + "Moviepy - Building video results/audio_features.mp4.\n", + "MoviePy - Writing audio in audio_featuresTEMP_MPY_wvf_snd.mp3\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "MoviePy - Done.\n", + "Moviepy - Writing video results/audio_features.mp4\n", + "\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Moviepy - Done !\n", + "Moviepy - video ready results/audio_features.mp4\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + " \n", + " " + ] + }, + "metadata": {} + } + ], + "source": [ + "plot_feature_video(\n", + " image_feats[\"image_feats\"].cpu(),\n", + " audio_feats['audio_feats'].cpu(),\n", + " frames_to_plot,\n", + " audio,\n", + " info[\"video_fps\"],\n", + " sample_rate,\n", + " \"results/visual_features.mp4\",\n", + " \"results/audio_features.mp4\",\n", + ")\n", + "display_video_in_notebook(\"results/visual_features.mp4\")\n", + "display_video_in_notebook(\"results/audio_features.mp4\")" + ], + "metadata": { + "ExecuteTime": { + "end_time": "2024-06-06T17:08:30.187416Z", + "start_time": "2024-06-06T17:08:21.090287Z" + }, + "id": "d244fec7aaa340cd", + "outputId": "c5dc4a3d-1867-4c06-dea1-c7ae49ab3607", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 779 + } + }, + "id": "d244fec7aaa340cd", + "execution_count": 30 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + }, + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "accelerator": "GPU" + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file