diff --git "a/Video_Captioning.ipynb" "b/Video_Captioning.ipynb" deleted file mode 100644--- "a/Video_Captioning.ipynb" +++ /dev/null @@ -1,383 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "55df6d0d-71cf-4110-81ed-7c0d3ce58e43", - "metadata": {}, - "source": [ - "## Import" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "0abe9574-05f7-4684-b586-033827b89c32", - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "74e70729-b658-4ffd-9d8b-ae42a2d1b212", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import numpy as np\n", - "from fairseq import utils, tasks\n", - "from fairseq import checkpoint_utils\n", - "from utils.eval_utils import eval_step\n", - "from tasks.mm_tasks.caption import CaptionTask\n", - "from models.unival import UnIVALModel\n", - "from PIL import Image\n", - "\n", - "import random\n", - "from torchvision.transforms import functional as F\n", - "from torchvision.transforms import InterpolationMode\n", - "\n", - "from matplotlib import pyplot as plt\n", - "\n", - "# turn on cuda if GPU is available\n", - "use_cuda = torch.cuda.is_available()\n", - "# use fp16 only when GPU is available\n", - "use_fp16 = False\n", - "import os " - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "ce03a870-2852-410e-97c4-59461d08f60a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ".register_task_cls(cls)>" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Register refcoco task\n", - "tasks.register_task('video_caption', CaptionTask)" - ] - }, - { - "cell_type": "markdown", - "id": "58361680-3e90-4fff-962e-2ff67c1e7289", - "metadata": {}, - "source": [ - "### Load model" - ] - }, - { - "cell_type": "code", - "execution_count": 80, - "id": "adb79611-7563-4fb6-a576-f31050f8438e", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "self.sample_patch_num 784\n", - "self.sample_audio_patch_num None\n", - "self.sample_video_patch_num None\n", - "self.with_cls False\n", - "Loading: all_resnext101\n", - "use bn: \n", - "load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth\n", - "_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])\n", - "unival\n", - "getattr(args, \"stop_on_max_len\", False) False\n" - ] - } - ], - "source": [ - "# Load pretrained ckpt & config\n", - "\n", - "checkpoint_path = '/data/mshukor/logs/ofa/best_models/unival_video_caption_stage_1/checkpoint_best.pt'\n", - "video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'\n", - "\n", - "overrides={\"eval_cider\":False, \"beam\":5, \"max_len_b\":22, \"no_repeat_ngram_size\":3, \"seed\":7, \"unnormalized\": False,\n", - " \"bpe_dir\":\"utils/BPE\", \"video_model_path\": video_model_path,}\n", - "\n", - "models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(\n", - " utils.split_paths(checkpoint_path),\n", - " arg_overrides=overrides\n", - " )\n", - "\n", - "# Move models to GPU\n", - "for model in models:\n", - " model.eval()\n", - " if use_fp16:\n", - " model.half()\n", - " if use_cuda and not cfg.distributed_training.pipeline_model_parallel:\n", - " model.cuda()\n", - " model.prepare_for_inference_(cfg)\n", - "\n", - "# Initialize generator\n", - "generator = task.build_generator(models, cfg.generation)" - ] - }, - { - 
"cell_type": "markdown", - "id": "e79aad39-1424-47d5-8cd4-6ab77ea46fb4", - "metadata": {}, - "source": [ - "### Preprocess" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "id": "576a3e84-a6aa-446d-adab-fef9499318fc", - "metadata": {}, - "outputs": [], - "source": [ - "# Image transform\n", - "from torchvision import transforms\n", - "mean = [0.5, 0.5, 0.5]\n", - "std = [0.5, 0.5, 0.5]\n", - "\n", - "\n", - "\n", - "type_transform = transforms.Lambda(lambda x: x.float().div(255.0))\n", - "patch_video_resize_transform = transforms.Compose([\n", - " transforms.CenterCrop(cfg.task.patch_frame_size),\n", - " type_transform, \n", - " transforms.Normalize(mean=mean, std=std),\n", - " ])\n", - "\n", - "# video process\n", - "from data.video_utils import VIDEO_READER_FUNCS\n", - "\n", - "video_reader = VIDEO_READER_FUNCS['decord'] \n", - "\n", - "def process_video(video_path, max_num_frames=16, num_frames=16, sample_type='rand',):\n", - " \n", - " # video \n", - " data_path = os.path.join(video_path)\n", - "\n", - " frames, frame_indices, video_duration = video_reader(\n", - " data_path, num_frames, sample_type, max_num_frames=max_num_frames\n", - " )\n", - "\n", - " patch_video = patch_video_resize_transform(frames)\n", - " patch_video = patch_video.permute(1, 0, 2, 3) # -> (C, T, h, w)\n", - "\n", - " return patch_video.unsqueeze(0)\n", - " \n", - "\n", - "# Text preprocess\n", - "bos_item = torch.LongTensor([task.src_dict.bos()])\n", - "eos_item = torch.LongTensor([task.src_dict.eos()])\n", - "pad_idx = task.src_dict.pad()\n", - "def encode_text(text, length=None, append_bos=False, append_eos=False):\n", - " s = task.tgt_dict.encode_line(\n", - " line=task.bpe.encode(text),\n", - " add_if_not_exist=False,\n", - " append_eos=False\n", - " ).long()\n", - " if length is not None:\n", - " s = s[:length]\n", - " if append_bos:\n", - " s = torch.cat([bos_item, s])\n", - " if append_eos:\n", - " s = torch.cat([s, eos_item])\n", - " return s\n", - "\n", - "# Construct input for caption task\n", - "def construct_sample(video_path):\n", - " \n", - " patch_video = process_video(video_path, max_num_frames=16, num_frames=cfg.task.num_frames, sample_type=cfg.task.sample_type,)\n", - " patch_image = torch.zeros((3, cfg.task.patch_image_size, cfg.task.patch_image_size)) \n", - " \n", - " patch_type = torch.tensor([1])\n", - " patch_mask = torch.tensor([True])\n", - " src_text = encode_text(\" what does the video describe?\", append_bos=True, append_eos=True).unsqueeze(0)\n", - " src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])\n", - " sample = {\n", - " \"id\":np.array(['42']),\n", - " \"net_input\": {\n", - " \"src_tokens\": src_text,\n", - " \"src_lengths\": src_length,\n", - " \"patch_videos\": patch_video,\n", - " \"patch_images\": patch_image,\n", - " \"patch_masks\": patch_mask,\n", - " \"patch_types\": patch_type,\n", - " }\n", - " }\n", - " return sample\n", - " \n", - "# Function to turn FP32 to FP16\n", - "def apply_half(t):\n", - " if t.dtype is torch.float32:\n", - " return t.to(dtype=torch.half)\n", - " return t" - ] - }, - { - "cell_type": "markdown", - "id": "f96f776e-9aa0-4271-b881-311851cc033c", - "metadata": {}, - "source": [ - "### Inference" - ] - }, - { - "cell_type": "code", - "execution_count": 157, - "id": "6f8ddf8c-82e2-411c-baa3-850da02f1996", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 3, 16, 384, 384])\n" - ] - }, - { - "data": { - "text/html": [ - "" - ], - 
"text/plain": [ - "" - ] - }, - "execution_count": 157, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "save_dir = '/home/mshukor/ofa_adastra'\n", - "\n", - "\n", - "\n", - "\n", - "video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7019.mp4' # a man is sitting in a chair and talking\n", - "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7038.mp4' # a person is cooking something in a pan\n", - "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7021.mp4' # a group of people are playing baseball\n", - "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7068.mp4' # a man and a woman are talking to each other\n", - "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7017.mp4' # a person is playing a video game\n", - "# video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7014.mp4' # a girl is singing on the voice\n", - "\n", - "\n", - "\n", - "# video_path = '/data/mshukor/data/video/msrvtt/examples/video1065.mp4'\n", - "\n", - "# limitations\n", - "video_path = '/data/mshukor/data/video/msrvtt/examples/test/video7055.mp4' # a man is driving a car\n", - "\n", - "\n", - "sample = construct_sample(video_path)\n", - "sample = utils.move_to_cuda(sample) if use_cuda else sample\n", - "sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3690f53b-3594-4d8f-81c8-c8ed0931c00b", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 158, - "id": "4651039c-b8c0-4687-871e-b42cb13b2984", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "tensor([1], device='cuda:0')\n", - "torch.Size([1, 2048, 1, 12, 12])\n" - ] - } - ], - "source": [ - "from utils.eval_utils import eval_caption\n", - "\n", - "with torch.no_grad():\n", - " result, scores = eval_caption(task, generator, models, sample)" - ] - }, - { - "cell_type": "code", - "execution_count": 159, - "id": "712150d4-f28c-4538-870f-b33f775725d5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "a man is driving a car\n" - ] - } - ], - "source": [ - "caption = result[0]['caption']\n", - "print(caption)\n", - "\n", - "from IPython.display import Video\n", - "Video(video_path, embed=True)\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "303d531f-dba3-40b9-a1ff-1be92d8c188a", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d2db0cc0-5cd2-48dd-b900-56331d53b1df", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "ofa", - "language": "python", - "name": "ofa" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -}