diff --git "a/Audio_Captioning.ipynb" "b/Audio_Captioning.ipynb" deleted file mode 100644--- "a/Audio_Captioning.ipynb" +++ /dev/null @@ -1,387 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "55df6d0d-71cf-4110-81ed-7c0d3ce58e43", - "metadata": {}, - "source": [ - "## Import" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "0abe9574-05f7-4684-b586-033827b89c32", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "74e70729-b658-4ffd-9d8b-ae42a2d1b212", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import numpy as np\n", - "from fairseq import utils, tasks\n", - "from fairseq import checkpoint_utils\n", - "from utils.eval_utils import eval_step\n", - "from tasks.mm_tasks.caption import CaptionTask\n", - "from models.unival import UnIVALModel\n", - "from PIL import Image\n", - "\n", - "import random\n", - "from torchvision.transforms import functional as F\n", - "from torchvision.transforms import InterpolationMode\n", - "\n", - "from matplotlib import pyplot as plt\n", - "\n", - "# turn on cuda if GPU is available\n", - "use_cuda = torch.cuda.is_available()\n", - "# use fp16 only when GPU is available\n", - "use_fp16 = False" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "ce03a870-2852-410e-97c4-59461d08f60a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ".register_task_cls(cls)>" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Register refcoco task\n", - "tasks.register_task('audio_caption', CaptionTask)" - ] - }, - { - "cell_type": "markdown", - "id": "58361680-3e90-4fff-962e-2ff67c1e7289", - "metadata": {}, - "source": [ - "### Load model" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "id": "adb79611-7563-4fb6-a576-f31050f8438e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "self.sample_patch_num 784\n", - "self.sample_audio_patch_num None\n", - "self.sample_video_patch_num None\n", - "self.with_cls False\n", - "Loading: all_resnext101\n", - "use bn: \n", - "load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth\n", - "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])\n", - "Loading: pann_cnn14\n", - "load pretrained_model /data/mshukor/logs/ofa/best_models/Cnn14_mAP_0.431.pth\n", - "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc1.weight', 'fc1.bias', 'fc_audioset.weight', 'fc_audioset.bias'])\n", - "load resnet /data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth\n", - "\n", - "unival\n", - "task\n", - "getattr(args, \"stop_on_max_len\", False) False\n" - ] - } - ], - "source": [ - "# Load pretrained ckpt & config\n", - "\n", - "checkpoint_path = '/data/mshukor/logs/ofa/best_models/unival_audio_caption/checkpoint_best.pt'\n", - "\n", - "video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n", - "audio_model_path = '/data/mshukor/logs/ofa/best_models/Cnn14_mAP_0.431.pth'\n", - "resnet_model_path = '/data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth'\n", - "\n", - "\n", - "\n", - "overrides={\"eval_cider\":False, \"beam\":5, \"max_len_b\":22, \"no_repeat_ngram_size\":3, \"seed\":7, \"unnormalized\": False,\n", - " \"bpe_dir\":\"utils/BPE\", \"video_model_path\": video_model_path, \"audio_model_path\": audio_model_path, 
\"resnet_model_path\": resnet_model_path,}\n", - "\n", - "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n", - " utils.split_paths(checkpoint_path),\n", - " arg_overrides=overrides\n", - " )\n", - "\n", - "# Move models to GPU\n", - "for model in models:\n", - " model.eval()\n", - " if use_fp16:\n", - " model.half()\n", - " if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n", - " model.cuda()\n", - " model.prepare_for_inference_(cfg)\n", - "\n", - "# Initialize generator\n", - "generator = task.build_generator(models, cfg.generation)" - ] - }, - { - "cell_type": "markdown", - "id": "e79aad39-1424-47d5-8cd4-6ab77ea46fb4", - "metadata": {}, - "source": [ - "### Preprocess" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "id": "576a3e84-a6aa-446d-adab-fef9499318fc", - "metadata": {}, - "outputs": [], - "source": [ - "# Image transform\n", - "from torchvision import transforms\n", - "import torchaudio\n", - "\n", - "from data.audio_utils import get_audio_features, int16_to_float32, float32_to_int16, AUDIO_CFG\n", - "\n", - "\n", - "mean = [0.5, 0.5, 0.5]\n", - "std = [0.5, 0.5, 0.5]\n", - "\n", - "\n", - "\n", - "def process_audio(audio_path, sample_rate=48000, max_audio_len=480000, audio_cfg=AUDIO_CFG):\n", - "\n", - " # audio \n", - " data_path = audio_path\n", - "\n", - "\n", - "\n", - " audio_data, orig_sr = torchaudio.load(data_path)\n", - " audio_data = torchaudio.transforms.Resample(orig_sr, sample_rate)(audio_data[0])\n", - "\n", - " sample = {}\n", - "\n", - " sample = get_audio_features(\n", - " sample, audio_data, max_audio_len, \n", - " data_truncating='rand_trunc', \n", - " data_filling='repeatpad',\n", - " audio_cfg=audio_cfg\n", - " )\n", - "\n", - "\n", - " waveform = sample['waveform']\n", - " patch_audio = waveform\n", - " \n", - " return patch_audio.unsqueeze(0)\n", - "\n", - " \n", - " \n", - " \n", - "# Text preprocess\n", - "bos_item = torch.LongTensor([task.src_dict.bos()])\n", - "eos_item = torch.LongTensor([task.src_dict.eos()])\n", - "pad_idx = task.src_dict.pad()\n", - "def encode_text(text, length=None, append_bos=False, append_eos=False):\n", - " s = task.tgt_dict.encode_line(\n", - " line=task.bpe.encode(text),\n", - " add_if_not_exist=False,\n", - " append_eos=False\n", - " ).long()\n", - " if length is not None:\n", - " s = s[:length]\n", - " if append_bos:\n", - " s = torch.cat([bos_item, s])\n", - " if append_eos:\n", - " s = torch.cat([s, eos_item])\n", - " return s\n", - "\n", - "\n", - "# Construct input for caption task\n", - "def construct_sample(audio_path):\n", - " \n", - " \n", - " patch_audio = process_audio(audio_path, sample_rate=48000, max_audio_len=480000, audio_cfg=AUDIO_CFG)\n", - " patch_image = torch.zeros((3, cfg.task.patch_image_size, cfg.task.patch_image_size)) \n", - " \n", - " patch_type = torch.tensor([2])\n", - " patch_mask = torch.tensor([True])\n", - " src_text = encode_text(\" what does the image describe?\", append_bos=True, append_eos=True).unsqueeze(0)\n", - " src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])\n", - " sample = {\n", - " \"id\":np.array(['42']),\n", - " \"net_input\": {\n", - " \"src_tokens\": src_text,\n", - " \"src_lengths\": src_length,\n", - " \"patch_images\": patch_image,\n", - " \"patch_audios\": patch_audio,\n", - " \"patch_masks\": patch_mask,\n", - " \"patch_types\": patch_type,\n", - " }\n", - " }\n", - " return sample\n", - " \n", - "# Function to turn FP32 to FP16\n", - "def apply_half(t):\n", - " if t.dtype is 
torch.float32:\n", - " return t.to(dtype=torch.half)\n", - " return t" - ] - }, - { - "cell_type": "markdown", - "id": "f96f776e-9aa0-4271-b881-311851cc033c", - "metadata": {}, - "source": [ - "### Inference" - ] - }, - { - "cell_type": "code", - "execution_count": 114, - "id": "6f8ddf8c-82e2-411c-baa3-850da02f1996", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 480000])\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - " \n", - " " - ], - "text/plain": [ - "" - ] - }, - "execution_count": 114, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "save_dir = '/home/mshukor/ofa_adastra'\n", - "\n", - "\n", - "\n", - "audio_path = '/data/mshukor/data/audiocaps/test/KSHpYhuTotY.wav' # A man talks while bees fly\n", - "# audio_path = '/data/mshukor/data/audiocaps/test/6cS0FsUM-cQ.wav' # A cat is meowing and a man is speaking\n", - "# audio_path = '/data/mshukor/data/audiocaps/test/6CDl4CqOgMg.wav' # A dog pants and whimpers\n", - "# audio_path = '/data/mshukor/data/audiocaps/test/_BSmz3SEW1w.wav' # Pigeons coo and flap their wings\n", - "# audio_path = '/data/mshukor/data/audiocaps/test/ZsTZ7jqbd9M.wav' # A man speaking with birds chirping in the background\n", - "# audio_path = '/data/mshukor/data/audiocaps/test/5OM3tJh51pE.wav' # A woman giving a speech\n", - "\n", - "# audio_path = '/data/mshukor/data/audiocaps/test/AJtNitYMa1I.wav' # Food sizzling in a pan\n", - "\n", - "audio_path = '/data/mshukor/data/audiocaps/test/3MoF8myFs8Y.wav' # Wind blows hard and waves crash against a shoreline\n", - "audio_path = '/data/mshukor/data/audiocaps/test/350OCezayrk.wav' # A motor vehicle engine is idling and vibrating\n", - "\n", - "\n", - "## limitations\n", - "# audio_path = '/data/mshukor/data/audiocaps/test/EBCH7TPgiPc.wav' # A motor vehicle engine is running and revving and an adult male speaks in the background\n", - "\n", - "\n", - "sample = construct_sample(audio_path)\n", - "sample = utils.move_to_cuda(sample) if use_cuda else sample\n", - "sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample\n", - "print(sample['net_input']['patch_audios'].shape)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 115, - "id": "4651039c-b8c0-4687-871e-b42cb13b2984", - "metadata": {}, - "outputs": [], - "source": [ - "from utils.eval_utils import eval_caption\n", - "\n", - "with torch.no_grad():\n", - " result, scores = eval_caption(task, generator, models, sample)" - ] - }, - { - "cell_type": "code", - "execution_count": 116, - "id": "712150d4-f28c-4538-870f-b33f775725d5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "A motor vehicle engine is idling and vibrating\n" - ] - } - ], - "source": [ - "caption = result[0]['caption']\n", - "print(caption)\n", - "\n", - "from IPython.display import Audio\n", - "Audio(audio_path, embed=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "303d531f-dba3-40b9-a1ff-1be92d8c188a", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8c5eb1f0-bca3-4b9a-a8f1-2f4468c54025", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "ofa", - "language": "python", - "name": "ofa" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - 
"nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}