## Import

In [1]:
# Load IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
from fairseq import utils, tasks
from fairseq import checkpoint_utils
from utils.eval_utils import eval_step
from tasks.mm_tasks.caption import CaptionTask
from models.unival import UnIVALModel
from PIL import Image

import random
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode

from matplotlib import pyplot as plt

# turn on cuda if GPU is available
use_cuda = torch.cuda.is_available()
# fp16 is disabled in this notebook; if you enable it, do so only when a GPU is available
use_fp16 = False

In [3]:
# Register the audio-caption task name with fairseq.
# NOTE(review): as the echoed output of this cell shows, this call only returns
# fairseq's inner `register_task_cls` decorator; the decorator is never applied
# to a class here, so this line does not register anything by itself.
# Presumably the task is already registered when the project's `tasks` package
# is imported — confirm before relying on (or "fixing") this call.
tasks.register_task('audio_caption', CaptionTask)

<function fairseq.tasks.register_task.<locals>.register_task_cls(cls)>

### Load model

In [64]:
# Load pretrained ckpt & config

# Captioning checkpoint plus the pretrained backbone weights that are handed
# to the loader through the overrides below.
checkpoint_path = '/data/mshukor/logs/ofa/best_models/unival_audio_caption/checkpoint_best.pt'

video_model_path = '/data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth'
audio_model_path = '/data/mshukor/logs/ofa/best_models/Cnn14_mAP_0.431.pth'
resnet_model_path = '/data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth'

# Passed as `arg_overrides` when the checkpoint is loaded.
overrides = dict(
    eval_cider=False,
    beam=5,
    max_len_b=22,
    no_repeat_ngram_size=3,
    seed=7,
    unnormalized=False,
    bpe_dir="utils/BPE",
    video_model_path=video_model_path,
    audio_model_path=audio_model_path,
    resnet_model_path=resnet_model_path,
)

models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
    utils.split_paths(checkpoint_path),
    arg_overrides=overrides,
)

# Switch each model to eval mode, optionally cast it to fp16 and move it to
# the GPU, then let it prepare itself for inference.
for model in models:
    model.eval()
    if use_fp16:
        model.half()
    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
        model.cuda()
    model.prepare_for_inference_(cfg)

# Sequence generator built from the generation section of the loaded config
generator = task.build_generator(models, cfg.generation)

self.sample_patch_num 784
self.sample_audio_patch_num None
self.sample_video_patch_num None
self.with_cls False
Loading:  all_resnext101
use bn: <class 'torch.nn.modules.batchnorm.BatchNorm3d'>
load pretrained_model /data/mshukor/logs/ofa/best_models/resnext-101-kinetics.pth
_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc.weight', 'fc.bias'])
Loading:  pann_cnn14
load pretrained_model /data/mshukor/logs/ofa/best_models/Cnn14_mAP_0.431.pth
_IncompatibleKeys(missing_keys=[], unexpected_keys=['fc1.weight', 'fc1.bias', 'fc_audioset.weight', 'fc_audioset.bias'])
load resnet /data/mshukor/logs/ofa/best_models/resnet101-5d3b4d8f.pth
<All keys matched successfully>
unival
task
getattr(args, "stop_on_max_len", False) False


### Preprocess

In [65]:
# Preprocessing imports (audio loading and feature extraction)
from torchvision import transforms
import torchaudio

from data.audio_utils import get_audio_features, int16_to_float32, float32_to_int16, AUDIO_CFG


# NOTE(review): presumably per-channel normalization constants left over from
# an image pipeline; no cell visible in this notebook references them —
# confirm before removing.
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]



def process_audio(audio_path, sample_rate=48000, max_audio_len=480000, audio_cfg=AUDIO_CFG):
    """Load an audio file and return its waveform with a leading batch axis.

    The file is read with torchaudio, the first entry of the loaded tensor
    (the first channel under torchaudio's default channels-first layout) is
    resampled to ``sample_rate``, and the result is handed to
    ``get_audio_features`` with 'rand_trunc' truncation and 'repeatpad'
    filling.

    Args:
        audio_path: path of the audio file to load.
        sample_rate: target sampling rate for the resampler.
        max_audio_len: length argument forwarded to ``get_audio_features``.
        audio_cfg: audio configuration forwarded to ``get_audio_features``.

    Returns:
        The ``'waveform'`` entry produced by ``get_audio_features`` with
        ``unsqueeze(0)`` applied.
    """
    loaded, native_sr = torchaudio.load(audio_path)
    first_channel = loaded[0]
    resampler = torchaudio.transforms.Resample(native_sr, sample_rate)
    resampled = resampler(first_channel)

    # NOTE(review): 'rand_trunc' presumably crops at a random offset when the
    # clip is longer than max_audio_len, which would make results vary between
    # runs for long clips — confirm in data.audio_utils.
    features = get_audio_features(
        {},
        resampled,
        max_audio_len,
        data_truncating='rand_trunc',
        data_filling='repeatpad',
        audio_cfg=audio_cfg,
    )

    return features['waveform'].unsqueeze(0)

        
        
        
# Text preprocess
# One-element LongTensors holding the source dictionary's BOS and EOS ids;
# encode_text concatenates them around the encoded tokens on request.
bos_item = torch.LongTensor([task.src_dict.bos()])
eos_item = torch.LongTensor([task.src_dict.eos()])
# Padding id of the source dictionary; construct_sample uses it to count non-pad tokens.
pad_idx = task.src_dict.pad()
def encode_text(text, length=None, append_bos=False, append_eos=False):
    """Turn a text string into a 1-D LongTensor of dictionary ids.

    The text is BPE-encoded with ``task.bpe`` and mapped to ids with
    ``task.tgt_dict`` (unknown symbols are not added, no EOS is appended by
    the dictionary itself).

    Args:
        text: the string to encode.
        length: if not None, keep only the first ``length`` ids; this is
            applied before BOS/EOS are attached.
        append_bos: prepend the module-level ``bos_item`` when True.
        append_eos: append the module-level ``eos_item`` when True.

    Returns:
        The encoded ids, optionally truncated and wrapped in BOS/EOS.
    """
    bpe_line = task.bpe.encode(text)
    token_ids = task.tgt_dict.encode_line(
        line=bpe_line,
        add_if_not_exist=False,
        append_eos=False,
    ).long()

    if length is not None:
        token_ids = token_ids[:length]

    # Collect the optional special tokens and concatenate once.
    prefix = [bos_item] if append_bos else []
    suffix = [eos_item] if append_eos else []
    if prefix or suffix:
        token_ids = torch.cat(prefix + [token_ids] + suffix)
    return token_ids


# Construct input for caption task
def construct_sample(audio_path):
    """Assemble the one-example input dictionary for captioning an audio file.

    Args:
        audio_path: path of the audio clip to caption.

    Returns:
        A dict with an ``"id"`` array and a ``"net_input"`` dict containing
        the prompt tokens and their length, an all-zero image tensor, the
        processed audio, and the patch mask / type tensors.
    """
    audio = process_audio(audio_path, sample_rate=48000, max_audio_len=480000, audio_cfg=AUDIO_CFG)

    # All-zero image tensor sized from the task config; no real image is supplied.
    side = cfg.task.patch_image_size
    blank_image = torch.zeros((3, side, side))

    # Prompt wrapped in BOS/EOS, with a leading batch axis.
    prompt = encode_text(" what does the image describe?", append_bos=True, append_eos=True)
    src_tokens = prompt.unsqueeze(0)
    # Number of non-padding tokens in each row of the prompt batch.
    src_lengths = torch.LongTensor([row.ne(pad_idx).long().sum() for row in src_tokens])

    net_input = {
        "src_tokens": src_tokens,
        "src_lengths": src_lengths,
        "patch_images": blank_image,
        "patch_audios": audio,
        "patch_masks": torch.tensor([True]),
        # NOTE(review): 2 is presumably the modality id the model uses for audio — confirm.
        "patch_types": torch.tensor([2]),
    }
    return {"id": np.array(['42']), "net_input": net_input}
  
# Function to turn FP32 to FP16
def apply_half(t):
    """Return ``t`` cast to half precision if it is float32, otherwise ``t`` itself.

    Args:
        t: a tensor.

    Returns:
        A ``torch.half`` copy of ``t`` when ``t.dtype`` is ``torch.float32``;
        for every other dtype the same tensor object is returned unchanged.
    """
    return t.to(dtype=torch.half) if t.dtype is torch.float32 else t

### Inference

In [114]:
save_dir = '/home/mshukor/ofa_adastra'  # NOTE(review): not used by any cell visible in this notebook

# Example clips from the AudioCaps test folder, each with a caption noted next to it.
# Exactly one `audio_path` assignment is active; to try another clip, comment it
# out and uncomment the one you want. (Earlier versions assigned `audio_path`
# several times in a row, so only the last assignment ever took effect.)
# audio_path = '/data/mshukor/data/audiocaps/test/KSHpYhuTotY.wav' # A man talks while bees fly
# audio_path = '/data/mshukor/data/audiocaps/test/6cS0FsUM-cQ.wav' # A cat is meowing and a man is speaking
# audio_path = '/data/mshukor/data/audiocaps/test/6CDl4CqOgMg.wav'  # A dog pants and whimpers
# audio_path = '/data/mshukor/data/audiocaps/test/_BSmz3SEW1w.wav' # Pigeons coo and flap their wings
# audio_path = '/data/mshukor/data/audiocaps/test/ZsTZ7jqbd9M.wav' # A man speaking with birds chirping in the background
# audio_path = '/data/mshukor/data/audiocaps/test/5OM3tJh51pE.wav' # A woman giving a speech

# audio_path = '/data/mshukor/data/audiocaps/test/AJtNitYMa1I.wav' # Food sizzling in a pan

# audio_path = '/data/mshukor/data/audiocaps/test/3MoF8myFs8Y.wav' # Wind blows hard and waves crash against a shoreline
audio_path = '/data/mshukor/data/audiocaps/test/350OCezayrk.wav' # A motor vehicle engine is idling and vibrating


## limitations
# audio_path = '/data/mshukor/data/audiocaps/test/EBCH7TPgiPc.wav' # A motor vehicle engine is running and revving and an adult male speaks in the background


# Build the model input, then move it to the GPU / cast it to fp16 according
# to the flags set in the import cell.
sample = construct_sample(audio_path)
if use_cuda:
    sample = utils.move_to_cuda(sample)
if use_fp16:
    sample = utils.apply_to_sample(apply_half, sample)
print(sample['net_input']['patch_audios'].shape)


torch.Size([1, 480000])


In [115]:
from utils.eval_utils import eval_caption

# Run caption generation on the prepared sample with gradient tracking disabled.
with torch.no_grad():
    result, scores = eval_caption(task, generator, models, sample)

In [116]:
# Print the caption stored in the first result entry.
caption = result[0]['caption']
print(caption)

# Embed a player for the same clip so the caption can be checked by ear;
# this must stay the last expression of the cell to be displayed.
from IPython.display import Audio
Audio(audio_path, embed=True)

A motor vehicle engine is idling and vibrating
