# ChatVID/model/vision/Vid2SeqCaptioner.py
import sys
from pathlib import Path

import torch

from config import vid2seq_config
from model.utils import (ScenicCall, ScenicModel,
                         extract_clip_feature_single_video_fps, generate)

# Make the bundled Scenic checkout importable before importing from it.
sys.path.append(str(Path(__file__).parent.parent.parent / "scenic"))
from scenic.projects.vid2seq.playground import generate as vid2seq_generate


class Flag(object):
    """Plain namespace used to pass flag-style attributes to ScenicModel."""
    pass


class Vid2SeqCaptioner:
    """Vid2SeqCaptioner generates a caption for a given video.

    Description:
        It uses Google Scenic's Vid2Seq as the base model. Note that Scenic is
        designed to run on TPUs, and GPU resources (at least 70 GB of video
        memory) may not be enough to run the model, so it may have to run on
        the CPU instead.
    """

    def __init__(self, config):
        self.config = config
        # Assemble the flag-style arguments that ScenicModel expects.
        flags = Flag()
        flags.workdir = self.config['work_dir']
        flags.config = vid2seq_config.get_config()
        # flags.config = self.config['config_path']
        flags.data_dir = self.config['output_path']
        flags.ckpt_dir = self.config['checkpoint_path']
        self.model = ScenicModel(flags)

    def __call__(self, video_path):
        self._preprocess(video_path)
        return self.model()
        # Alternative one-shot entry point:
        # call = ScenicCall(vid2seq_generate, flags)
        # return call()

    def _preprocess(self, video_path):
        """Preprocess the video into the tfrecord format Vid2Seq expects.

        Description:
            Pipeline: CLIP -> *.npy -> *.csv (video) -> generate_from_file.py
            -> file0000-0003.tfrecord

        Args:
            video_path: The path of the video.
        """
        # Extract per-frame CLIP features first.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        video_feat, video_info = extract_clip_feature_single_video_fps(
            video_path=video_path,
            clip_ckpt_path=self.config['clip_path'],
            device=device
        )
        # Move the features to host memory as a numpy array.
        video_feat = video_feat.cpu().numpy()
        # Collect the metadata that the tfrecord writer expects; the ASR
        # fields are unused here, and duration is in microseconds.
        video_info_dict = {
            'basename': 'test',
            'output_path': self.config['output_path'],
            'asr_start': None,
            'asr_end': None,
            'asr_string': None,
            'video_id': video_path.split('/')[-1].split('.')[0],
            'features': video_feat,
            'duration': video_info['total_frames'] / video_info['avg_fps'] * 1000000,
        }
        # Write the tfrecord file consumed by the Vid2Seq dataloader.
        generate(video_info_dict)
        print("tfrecord file generated at {}".format(self.config['output_path']))