{ "cells": [ { "cell_type": "markdown", "id": "55df6d0d-71cf-4110-81ed-7c0d3ce58e43", "metadata": {}, "source": [ "## Import" ] }, { "cell_type": "code", "execution_count": 1, "id": "0abe9574-05f7-4684-b586-033827b89c32", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "id": "74e70729-b658-4ffd-9d8b-ae42a2d1b212", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import numpy as np\n", "from fairseq import utils, tasks\n", "from fairseq import checkpoint_utils\n", "from utils.eval_utils import eval_step\n", "from tasks.mm_tasks.caption import CaptionTask\n", "from models.unival import UnIVALModel\n", "from PIL import Image\n", "\n", "import random\n", "from torchvision.transforms import functional as F\n", "from torchvision.transforms import InterpolationMode\n", "\n", "from matplotlib import pyplot as plt\n", "\n", "# turn on cuda if GPU is available\n", "use_cuda = torch.cuda.is_available()\n", "# use fp16 only when GPU is available\n", "use_fp16 = False" ] }, { "cell_type": "code", "execution_count": 3, "id": "ce03a870-2852-410e-97c4-59461d08f60a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ ".register_task_cls(cls)>" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Register refcoco task\n", "tasks.register_task('audio_caption', CaptionTask)" ] }, { "cell_type": "markdown", "id": "58361680-3e90-4fff-962e-2ff67c1e7289", "metadata": {}, "source": [ "### Load model" ] }, { "cell_type": "code", "execution_count": 64, "id": "adb79611-7563-4fb6-a576-f31050f8438e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "self.sample_patch_num 784\n", "self.sample_audio_patch_num None\n", "self.sample_video_patch_num None\n", "self.with_cls False\n", "Loading: all_resnext101\n", "use bn: \n", "load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth\n", "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])\n", "Loading: pann_cnn14\n", "load pretrained_model /data/mshukor/logs/ofa/best_models/Cnn14_mAP_0.431.pth\n", "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc1.weight', 'fc1.bias', 'fc_audioset.weight', 'fc_audioset.bias'])\n", "load resnet /data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth\n", "\n", "unival\n", "task\n", "getattr(args, \"stop_on_max_len\", False) False\n" ] } ], "source": [ "# Load pretrained ckpt & config\n", "\n", "checkpoint_path = '/data/mshukor/logs/ofa/best_models/unival_audio_caption/checkpoint_best.pt'\n", "\n", "video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n", "audio_model_path = '/data/mshukor/logs/ofa/best_models/Cnn14_mAP_0.431.pth'\n", "resnet_model_path = '/data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth'\n", "\n", "\n", "\n", "overrides={\"eval_cider\":False, \"beam\":5, \"max_len_b\":22, \"no_repeat_ngram_size\":3, \"seed\":7, \"unnormalized\": False,\n", " \"bpe_dir\":\"utils/BPE\", \"video_model_path\": video_model_path, \"audio_model_path\": audio_model_path, \"resnet_model_path\": resnet_model_path,}\n", "\n", "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n", " utils.split_paths(checkpoint_path),\n", " arg_overrides=overrides\n", " )\n", "\n", "# Move models to GPU\n", "for model in models:\n", " model.eval()\n", " if use_fp16:\n", " model.half()\n", " if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n", " model.cuda()\n", " model.prepare_for_inference_(cfg)\n", "\n", "# Initialize generator\n", "generator = task.build_generator(models, cfg.generation)" ] }, { "cell_type": "markdown", "id": "e79aad39-1424-47d5-8cd4-6ab77ea46fb4", "metadata": {}, "source": [ "### Preprocess" ] }, { "cell_type": "code", "execution_count": 65, "id": "576a3e84-a6aa-446d-adab-fef9499318fc", "metadata": {}, "outputs": [], "source": [ "# Image transform\n", "from torchvision import transforms\n", "import torchaudio\n", "\n", "from data.audio_utils import get_audio_features, int16_to_float32, float32_to_int16, AUDIO_CFG\n", "\n", "\n", "mean = [0.5, 0.5, 0.5]\n", "std = [0.5, 0.5, 0.5]\n", "\n", "\n", "\n", "def process_audio(audio_path, sample_rate=48000, max_audio_len=480000, audio_cfg=AUDIO_CFG):\n", "\n", " # audio \n", " data_path = audio_path\n", "\n", "\n", "\n", " audio_data, orig_sr = torchaudio.load(data_path)\n", " audio_data = torchaudio.transforms.Resample(orig_sr, sample_rate)(audio_data[0])\n", "\n", " sample = {}\n", "\n", " sample = get_audio_features(\n", " sample, audio_data, max_audio_len, \n", " data_truncating='rand_trunc', \n", " data_filling='repeatpad',\n", " audio_cfg=audio_cfg\n", " )\n", "\n", "\n", " waveform = sample['waveform']\n", " patch_audio = waveform\n", " \n", " return patch_audio.unsqueeze(0)\n", "\n", " \n", " \n", " \n", "# Text preprocess\n", "bos_item = torch.LongTensor([task.src_dict.bos()])\n", "eos_item = torch.LongTensor([task.src_dict.eos()])\n", "pad_idx = task.src_dict.pad()\n", "def encode_text(text, length=None, append_bos=False, append_eos=False):\n", " s = task.tgt_dict.encode_line(\n", " line=task.bpe.encode(text),\n", " add_if_not_exist=False,\n", " append_eos=False\n", " ).long()\n", " if length is not None:\n", " s = s[:length]\n", " if append_bos:\n", " s = torch.cat([bos_item, s])\n", " if append_eos:\n", " s = torch.cat([s, eos_item])\n", " return s\n", "\n", "\n", "# Construct input for caption task\n", "def construct_sample(audio_path):\n", " \n", " \n", " patch_audio = process_audio(audio_path, sample_rate=48000, max_audio_len=480000, audio_cfg=AUDIO_CFG)\n", " patch_image = torch.zeros((3, cfg.task.patch_image_size, cfg.task.patch_image_size)) \n", " \n", " patch_type = torch.tensor([2])\n", " patch_mask = torch.tensor([True])\n", " src_text = encode_text(\" what does the image describe?\", append_bos=True, append_eos=True).unsqueeze(0)\n", " src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])\n", " sample = {\n", " \"id\":np.array(['42']),\n", " \"net_input\": {\n", " \"src_tokens\": src_text,\n", " \"src_lengths\": src_length,\n", " \"patch_images\": patch_image,\n", " \"patch_audios\": patch_audio,\n", " \"patch_masks\": patch_mask,\n", " \"patch_types\": patch_type,\n", " }\n", " }\n", " return sample\n", " \n", "# Function to turn FP32 to FP16\n", "def apply_half(t):\n", " if t.dtype is torch.float32:\n", " return t.to(dtype=torch.half)\n", " return t" ] }, { "cell_type": "markdown", "id": "f96f776e-9aa0-4271-b881-311851cc033c", "metadata": {}, "source": [ "### Inference" ] }, { "cell_type": "code", "execution_count": 114, "id": "6f8ddf8c-82e2-411c-baa3-850da02f1996", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 480000])\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "execution_count": 114, "metadata": {}, "output_type": "execute_result" } ], "source": [ "save_dir = '/home/mshukor/ofa_adastra'\n", "\n", "\n", "\n", "audio_path = '/data/mshukor/data/audiocaps/test/KSHpYhuTotY.wav' # A man talks while bees fly\n", "# audio_path = '/data/mshukor/data/audiocaps/test/6cS0FsUM-cQ.wav' # A cat is meowing and a man is speaking\n", "# audio_path = '/data/mshukor/data/audiocaps/test/6CDl4CqOgMg.wav' # A dog pants and whimpers\n", "# audio_path = '/data/mshukor/data/audiocaps/test/_BSmz3SEW1w.wav' # Pigeons coo and flap their wings\n", "# audio_path = '/data/mshukor/data/audiocaps/test/ZsTZ7jqbd9M.wav' # A man speaking with birds chirping in the background\n", "# audio_path = '/data/mshukor/data/audiocaps/test/5OM3tJh51pE.wav' # A woman giving a speech\n", "\n", "# audio_path = '/data/mshukor/data/audiocaps/test/AJtNitYMa1I.wav' # Food sizzling in a pan\n", "\n", "audio_path = '/data/mshukor/data/audiocaps/test/3MoF8myFs8Y.wav' # Wind blows hard and waves crash against a shoreline\n", "audio_path = '/data/mshukor/data/audiocaps/test/350OCezayrk.wav' # A motor vehicle engine is idling and vibrating\n", "\n", "\n", "## limitations\n", "# audio_path = '/data/mshukor/data/audiocaps/test/EBCH7TPgiPc.wav' # A motor vehicle engine is running and revving and an adult male speaks in the background\n", "\n", "\n", "sample = construct_sample(audio_path)\n", "sample = utils.move_to_cuda(sample) if use_cuda else sample\n", "sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample\n", "print(sample['net_input']['patch_audios'].shape)\n" ] }, { "cell_type": "code", "execution_count": 115, "id": "4651039c-b8c0-4687-871e-b42cb13b2984", "metadata": {}, "outputs": [], "source": [ "from utils.eval_utils import eval_caption\n", "\n", "with torch.no_grad():\n", " result, scores = eval_caption(task, generator, models, sample)" ] }, { "cell_type": "code", "execution_count": 116, "id": "712150d4-f28c-4538-870f-b33f775725d5", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "A motor vehicle engine is idling and vibrating\n" ] } ], "source": [ "caption = result[0]['caption']\n", "print(caption)\n", "\n", "from IPython.display import Audio\n", "Audio(audio_path, embed=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "303d531f-dba3-40b9-a1ff-1be92d8c188a", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "8c5eb1f0-bca3-4b9a-a8f1-2f4468c54025", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "ofa", "language": "python", "name": "ofa" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.4" } }, "nbformat": 4, "nbformat_minor": 5 }