diff --git "a/Audio_Captioning.ipynb" "b/Audio_Captioning.ipynb" new file mode 100644--- /dev/null +++ "b/Audio_Captioning.ipynb" @@ -0,0 +1,387 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "55df6d0d-71cf-4110-81ed-7c0d3ce58e43", + "metadata": {}, + "source": [ + "## Import" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "0abe9574-05f7-4684-b586-033827b89c32", + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "74e70729-b658-4ffd-9d8b-ae42a2d1b212", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import numpy as np\n", + "from fairseq import utils, tasks\n", + "from fairseq import checkpoint_utils\n", + "from utils.eval_utils import eval_step\n", + "from tasks.mm_tasks.caption import CaptionTask\n", + "from models.unival import UnIVALModel\n", + "from PIL import Image\n", + "\n", + "import random\n", + "from torchvision.transforms import functional as F\n", + "from torchvision.transforms import InterpolationMode\n", + "\n", + "from matplotlib import pyplot as plt\n", + "\n", + "# turn on cuda if GPU is available\n", + "use_cuda = torch.cuda.is_available()\n", + "# use fp16 only when GPU is available\n", + "use_fp16 = False" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ce03a870-2852-410e-97c4-59461d08f60a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + ".register_task_cls(cls)>" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Register refcoco task\n", + "tasks.register_task('audio_caption', CaptionTask)" + ] + }, + { + "cell_type": "markdown", + "id": "58361680-3e90-4fff-962e-2ff67c1e7289", + "metadata": {}, + "source": [ + "### Load model" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "id": "adb79611-7563-4fb6-a576-f31050f8438e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "self.sample_patch_num 784\n", + "self.sample_audio_patch_num None\n", + "self.sample_video_patch_num None\n", + "self.with_cls False\n", + "Loading: all_resnext101\n", + "use bn: \n", + "load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth\n", + "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])\n", + "Loading: pann_cnn14\n", + "load pretrained_model /data/mshukor/logs/ofa/best_models/Cnn14_mAP_0.431.pth\n", + "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc1.weight', 'fc1.bias', 'fc_audioset.weight', 'fc_audioset.bias'])\n", + "load resnet /data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth\n", + "\n", + "unival\n", + "task\n", + "getattr(args, \"stop_on_max_len\", False) False\n" + ] + } + ], + "source": [ + "# Load pretrained ckpt & config\n", + "\n", + "checkpoint_path = '/data/mshukor/logs/ofa/best_models/unival_audio_caption/checkpoint_best.pt'\n", + "\n", + "video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n", + "audio_model_path = '/data/mshukor/logs/ofa/best_models/Cnn14_mAP_0.431.pth'\n", + "resnet_model_path = '/data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth'\n", + "\n", + "\n", + "\n", + "overrides={\"eval_cider\":False, \"beam\":5, \"max_len_b\":22, \"no_repeat_ngram_size\":3, \"seed\":7, \"unnormalized\": False,\n", + " \"bpe_dir\":\"utils/BPE\", \"video_model_path\": video_model_path, \"audio_model_path\": audio_model_path, \"resnet_model_path\": resnet_model_path,}\n", + "\n", + "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n", + " utils.split_paths(checkpoint_path),\n", + " arg_overrides=overrides\n", + " )\n", + "\n", + "# Move models to GPU\n", + "for model in models:\n", + " model.eval()\n", + " if use_fp16:\n", + " model.half()\n", + " if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n", + " model.cuda()\n", + " model.prepare_for_inference_(cfg)\n", + "\n", + "# Initialize generator\n", + "generator = task.build_generator(models, cfg.generation)" + ] + }, + { + "cell_type": "markdown", + "id": "e79aad39-1424-47d5-8cd4-6ab77ea46fb4", + "metadata": {}, + "source": [ + "### Preprocess" + ] + }, + { + "cell_type": "code", + "execution_count": 65, + "id": "576a3e84-a6aa-446d-adab-fef9499318fc", + "metadata": {}, + "outputs": [], + "source": [ + "# Image transform\n", + "from torchvision import transforms\n", + "import torchaudio\n", + "\n", + "from data.audio_utils import get_audio_features, int16_to_float32, float32_to_int16, AUDIO_CFG\n", + "\n", + "\n", + "mean = [0.5, 0.5, 0.5]\n", + "std = [0.5, 0.5, 0.5]\n", + "\n", + "\n", + "\n", + "def process_audio(audio_path, sample_rate=48000, max_audio_len=480000, audio_cfg=AUDIO_CFG):\n", + "\n", + " # audio \n", + " data_path = audio_path\n", + "\n", + "\n", + "\n", + " audio_data, orig_sr = torchaudio.load(data_path)\n", + " audio_data = torchaudio.transforms.Resample(orig_sr, sample_rate)(audio_data[0])\n", + "\n", + " sample = {}\n", + "\n", + " sample = get_audio_features(\n", + " sample, audio_data, max_audio_len, \n", + " data_truncating='rand_trunc', \n", + " data_filling='repeatpad',\n", + " audio_cfg=audio_cfg\n", + " )\n", + "\n", + "\n", + " waveform = sample['waveform']\n", + " patch_audio = waveform\n", + " \n", + " return patch_audio.unsqueeze(0)\n", + "\n", + " \n", + " \n", + " \n", + "# Text preprocess\n", + "bos_item = torch.LongTensor([task.src_dict.bos()])\n", + "eos_item = torch.LongTensor([task.src_dict.eos()])\n", + "pad_idx = task.src_dict.pad()\n", + "def encode_text(text, length=None, append_bos=False, append_eos=False):\n", + " s = task.tgt_dict.encode_line(\n", + " line=task.bpe.encode(text),\n", + " add_if_not_exist=False,\n", + " append_eos=False\n", + " ).long()\n", + " if length is not None:\n", + " s = s[:length]\n", + " if append_bos:\n", + " s = torch.cat([bos_item, s])\n", + " if append_eos:\n", + " s = torch.cat([s, eos_item])\n", + " return s\n", + "\n", + "\n", + "# Construct input for caption task\n", + "def construct_sample(audio_path):\n", + " \n", + " \n", + " patch_audio = process_audio(audio_path, sample_rate=48000, max_audio_len=480000, audio_cfg=AUDIO_CFG)\n", + " patch_image = torch.zeros((3, cfg.task.patch_image_size, cfg.task.patch_image_size)) \n", + " \n", + " patch_type = torch.tensor([2])\n", + " patch_mask = torch.tensor([True])\n", + " src_text = encode_text(\" what does the image describe?\", append_bos=True, append_eos=True).unsqueeze(0)\n", + " src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])\n", + " sample = {\n", + " \"id\":np.array(['42']),\n", + " \"net_input\": {\n", + " \"src_tokens\": src_text,\n", + " \"src_lengths\": src_length,\n", + " \"patch_images\": patch_image,\n", + " \"patch_audios\": patch_audio,\n", + " \"patch_masks\": patch_mask,\n", + " \"patch_types\": patch_type,\n", + " }\n", + " }\n", + " return sample\n", + " \n", + "# Function to turn FP32 to FP16\n", + "def apply_half(t):\n", + " if t.dtype is torch.float32:\n", + " return t.to(dtype=torch.half)\n", + " return t" + ] + }, + { + "cell_type": "markdown", + "id": "f96f776e-9aa0-4271-b881-311851cc033c", + "metadata": {}, + "source": [ + "### Inference" + ] + }, + { + "cell_type": "code", + "execution_count": 114, + "id": "6f8ddf8c-82e2-411c-baa3-850da02f1996", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([1, 480000])\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "execution_count": 114, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "save_dir = '/home/mshukor/ofa_adastra'\n", + "\n", + "\n", + "\n", + "audio_path = '/data/mshukor/data/audiocaps/test/KSHpYhuTotY.wav' # A man talks while bees fly\n", + "# audio_path = '/data/mshukor/data/audiocaps/test/6cS0FsUM-cQ.wav' # A cat is meowing and a man is speaking\n", + "# audio_path = '/data/mshukor/data/audiocaps/test/6CDl4CqOgMg.wav' # A dog pants and whimpers\n", + "# audio_path = '/data/mshukor/data/audiocaps/test/_BSmz3SEW1w.wav' # Pigeons coo and flap their wings\n", + "# audio_path = '/data/mshukor/data/audiocaps/test/ZsTZ7jqbd9M.wav' # A man speaking with birds chirping in the background\n", + "# audio_path = '/data/mshukor/data/audiocaps/test/5OM3tJh51pE.wav' # A woman giving a speech\n", + "\n", + "# audio_path = '/data/mshukor/data/audiocaps/test/AJtNitYMa1I.wav' # Food sizzling in a pan\n", + "\n", + "audio_path = '/data/mshukor/data/audiocaps/test/3MoF8myFs8Y.wav' # Wind blows hard and waves crash against a shoreline\n", + "audio_path = '/data/mshukor/data/audiocaps/test/350OCezayrk.wav' # A motor vehicle engine is idling and vibrating\n", + "\n", + "\n", + "## limitations\n", + "# audio_path = '/data/mshukor/data/audiocaps/test/EBCH7TPgiPc.wav' # A motor vehicle engine is running and revving and an adult male speaks in the background\n", + "\n", + "\n", + "sample = construct_sample(audio_path)\n", + "sample = utils.move_to_cuda(sample) if use_cuda else sample\n", + "sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample\n", + "print(sample['net_input']['patch_audios'].shape)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 115, + "id": "4651039c-b8c0-4687-871e-b42cb13b2984", + "metadata": {}, + "outputs": [], + "source": [ + "from utils.eval_utils import eval_caption\n", + "\n", + "with torch.no_grad():\n", + " result, scores = eval_caption(task, generator, models, sample)" + ] + }, + { + "cell_type": "code", + "execution_count": 116, + "id": "712150d4-f28c-4538-870f-b33f775725d5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A motor vehicle engine is idling and vibrating\n" + ] + } + ], + "source": [ + "caption = result[0]['caption']\n", + "print(caption)\n", + "\n", + "from IPython.display import Audio\n", + "Audio(audio_path, embed=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "303d531f-dba3-40b9-a1ff-1be92d8c188a", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c5eb1f0-bca3-4b9a-a8f1-2f4468c54025", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ofa", + "language": "python", + "name": "ofa" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}