# UnIVAL: data/mm_data/audio_caption_dataset.py
# (snapshot of commit 26fd00c "init" by mshukor, 7.77 kB)
# Modified from OFA code.
# Copyright 2022 The OFA-Sys Team.
# All rights reserved.
# This source code is licensed under the Apache 2.0 license
# found in the LICENSE file in the root directory.
from io import BytesIO
import logging
import warnings
import string
import numpy as np
import torch
from PIL import Image, ImageFile
from data import data_utils
from data.ofa_dataset import OFADataset
import os
import random
import soundfile as sf
import torchaudio
from data.audio_utils import get_audio_features, int16_to_float32, float32_to_int16, AUDIO_CFG
logger = logging.getLogger(__name__)

# Pillow configuration: accept truncated image files and disable the
# decompression-bomb pixel-count limit.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None
# NOTE(review): Pillow's size check appears to read Image.MAX_IMAGE_PIXELS;
# this ImageFile attribute seems to have no effect. Kept for parity.
ImageFile.MAX_IMAGE_PIXELS = None

# Ignore UserWarnings whose message matches "(Possibly )?corrupt EXIF data"
# (presumably emitted by Pillow while decoding image metadata).
warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)

# Per-channel (R, G, B) normalization statistics named after ImageNet.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def collate(samples, pad_idx, eos_idx):
    """Merge per-example dicts into one mini-batch dict.

    Token sequences ("source", "target", "prev_output_tokens") are padded
    with ``data_utils.collate_tokens``. Fixed-size tensors are combined
    per key: "patch_image", "patch_video" and "patch_audio" are stacked on
    a new leading dimension; "patch_mask" and "patch_type" are concatenated.

    Args:
        samples: list of example dicts as produced by the dataset.
        pad_idx: padding token id; also used to count non-pad tokens.
        eos_idx: end-of-sentence token id forwarded to ``collate_tokens``.

    Returns:
        ``{}`` for an empty ``samples``; otherwise a dict with "id",
        "nsentences", "ntokens", "net_input" and "target".
        "ntokens" counts non-pad target tokens when the first sample has a
        target, and non-pad source tokens otherwise.
    """
    if not len(samples):
        return {}

    def padded(field):
        # Pad the variable-length token tensors stored under `field`.
        return data_utils.collate_tokens(
            [example[field] for example in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    def stacked(field):
        return torch.stack([example[field] for example in samples], dim=0)

    def concatenated(field):
        return torch.cat([example[field] for example in samples])

    def non_pad_counts(field):
        return torch.LongTensor(
            [example[field].ne(pad_idx).long().sum() for example in samples]
        )

    ids = np.array([example["id"] for example in samples])
    src_tokens = padded("source")
    src_lengths = non_pad_counts("source")
    patch_images = stacked("patch_image")
    patch_masks = concatenated("patch_mask")
    patch_videos = stacked("patch_video")
    patch_types = concatenated("patch_type")
    patch_audios = stacked("patch_audio")

    first = samples[0]
    target = None
    prev_output_tokens = None
    if first.get("target", None) is None:
        ntokens = src_lengths.sum().item()
    else:
        target = padded("target")
        ntokens = non_pad_counts("target").sum().item()
        if first.get("prev_output_tokens", None) is not None:
            prev_output_tokens = padded("prev_output_tokens")

    net_input = {
        "src_tokens": src_tokens,
        "src_lengths": src_lengths,
        "patch_images": patch_images,
        "patch_masks": patch_masks,
        "prev_output_tokens": prev_output_tokens,
        "patch_videos": patch_videos,
        "patch_types": patch_types,
        "patch_audios": patch_audios,
    }
    return {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": net_input,
        "target": target,
    }
class CaptionDataset(OFADataset):
    """Audio captioning dataset built on ``OFADataset``.

    Each entry of ``dataset`` is unpacked as ``(uniq_id, path, caption)``.
    ``path`` is joined onto ``image_dir`` and loaded as an audio file. The
    image and video inputs are filled with zero tensors so samples share one
    layout with the other modalities. ``patch_type`` is set to 2 (presumably
    the audio modality id -- TODO confirm against the model code).
    """

    # Maximum number of random replacement samples tried by ``__getitem__``
    # after an audio load failure before giving up. This bounds what used to
    # be unbounded recursion.
    max_load_retries = 100

    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=128,
        max_tgt_length=30,
        patch_image_size=224,
        imagenet_default_mean_and_std=False,
        scst=False,
        image_dir='/gpfsscratch/rech/dyf/ugz83ue/data',
        audio_cfg=AUDIO_CFG,
        max_audio_len=480000,
        num_frames=4,
        sample_rate=48000,
        audio_sample_rate=False,
        ast=False,
        mode='train',
        mel_bins=64,
    ):
        """Configure the dataset.

        Args:
            split: split name. The 'train' split (with ``scst`` False)
                strips punctuation and truncates captions to
                ``max_tgt_length`` words.
            dataset: indexable whose items unpack to
                ``(uniq_id, relative_path, caption)``.
            bpe: tokenizer; only a ``GPT2BPE`` instance is supported.
            src_dict, tgt_dict: forwarded to ``OFADataset``.
            max_src_length: stored; not used elsewhere in this class.
            max_tgt_length: maximum number of whitespace-separated words kept
                in training captions.
            patch_image_size: side length of the zero image/video placeholders.
            imagenet_default_mean_and_std: accepted for interface compatibility;
                unused here.
            scst: when True, training captions are processed like eval captions.
            image_dir: root directory joined with each relative audio path.
            audio_cfg, max_audio_len: forwarded to ``get_audio_features``.
            num_frames: temporal length of the zero video placeholder.
            sample_rate: target rate for resampling; only used when
                ``audio_sample_rate`` is True.
            audio_sample_rate: if True, load with torchaudio and resample to
                ``sample_rate``. Otherwise load with soundfile at the file's
                native rate.
            ast, mel_bins: stored for an AST encoder; not used in this class.
            mode: ignored; ``self.mode`` is set from ``split``.

        Raises:
            NotImplementedError: if ``bpe`` is not a ``GPT2BPE``.
        """
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size
        self.scst = scst
        self.image_dir = image_dir
        self.sample_rate = sample_rate
        # Translation table that deletes every ASCII punctuation character.
        self.transtab = str.maketrans({key: None for key in string.punctuation})
        # video
        self.num_frames = num_frames
        # audio
        self.audio_cfg = audio_cfg
        self.max_audio_len = max_audio_len
        self.audio_sample_rate = audio_sample_rate

        if type(bpe).__name__ == 'GPT2BPE':
            self.prompt = " what does the video describe?"
        else:
            # `raise NotImplemented` would itself fail with a TypeError,
            # because NotImplemented is a constant, not an exception class.
            raise NotImplementedError(
                f"CaptionDataset only supports GPT2BPE, got {type(bpe).__name__}"
            )

        # Settings for an AST encoder (not read elsewhere in this class).
        self.ast = ast
        self.target_length = 1024
        # NOTE(review): the `mode` argument is ignored; `self.mode` mirrors
        # `split`. Kept as-is for backward compatibility.
        self.mode = split
        self.freqm_p = 24
        self.timem_p = 96
        self.skip_norm = False
        self.noise = False
        self.norm_mean = -4.2677393
        self.norm_std = 4.5689974
        self.freqm = torchaudio.transforms.FrequencyMasking(self.freqm_p)
        self.timem = torchaudio.transforms.TimeMasking(self.timem_p)
        self.mel_bins = mel_bins

    def _load_audio_features(self, data_path):
        """Load the audio at ``data_path`` and return the ``get_audio_features`` dict."""
        if not self.audio_sample_rate:
            # Read at the file's native rate. Despite an older comment, nothing
            # here resamples to 48 kHz.
            # NOTE(review): files are presumably pre-resampled -- TODO confirm.
            audio_data, _orig_sr = sf.read(data_path)
            if audio_data.ndim > 1:
                # Down-mix multi-channel audio by averaging across channels.
                audio_data = np.mean(audio_data, axis=1)
            # Round-trip through int16 to quantize to 16-bit precision.
            audio_data = int16_to_float32(float32_to_int16(audio_data))
            audio_data = torch.tensor(audio_data).float()  # (T,)
        else:
            audio_data, orig_sr = torchaudio.load(data_path)
            # Only the first channel is kept on this path, unlike the
            # channel averaging done above.
            audio_data = torchaudio.transforms.Resample(orig_sr, self.sample_rate)(audio_data[0])
        return get_audio_features(
            {}, audio_data, self.max_audio_len,
            data_truncating='rand_trunc',
            data_filling='repeatpad',
            audio_cfg=self.audio_cfg,
        )

    def _prepare_target_caption(self, caption):
        """Normalize ``caption`` into the target text for this split."""
        if self.split == 'train' and not self.scst:
            # Strip punctuation and keep at most `max_tgt_length` words.
            caption = caption.translate(self.transtab).strip()
            return ' '.join(caption.strip().split()[:self.max_tgt_length])
        # Eval / SCST: collapse whitespace, then strip punctuation from each
        # '&&'-separated reference caption.
        caption = ' '.join(caption.strip().split())
        caption_list = [cap.translate(self.transtab).strip() for cap in caption.strip().split('&&')]
        return '&&'.join(caption_list)

    def __getitem__(self, index):
        """Return one example dict for ``collate``.

        If the audio for ``index`` cannot be loaded, a random replacement
        index is tried instead, up to ``max_load_retries`` times.

        Raises:
            RuntimeError: if every load attempt fails.
        """
        data_path = None
        for _attempt in range(self.max_load_retries + 1):
            uniq_id, image, caption = self.dataset[index]
            data_path = os.path.join(self.image_dir, image)
            try:
                sample = self._load_audio_features(data_path)
                break
            except Exception as e:
                # Best-effort: unreadable files are skipped rather than fatal.
                index = random.randint(0, len(self) - 1)
                logger.warning(
                    f"Caught exception {e} when loading audio {data_path}, "
                    f"randomly sample a new audio as replacement"
                )
        else:
            raise RuntimeError(
                f"Failed to load audio after {self.max_load_retries + 1} attempts; "
                f"last path tried: {data_path}"
            )

        patch_audio = sample['waveform']
        patch_type = torch.tensor([2])
        # Zero placeholders keep the image/video slots shape-compatible.
        patch_image = torch.zeros((3, self.patch_image_size, self.patch_image_size))
        patch_video = torch.zeros((3, self.num_frames, self.patch_image_size, self.patch_image_size))
        patch_mask = torch.tensor([True])

        tgt_caption = self._prepare_target_caption(caption)
        src_item = self.encode_text(self.prompt)
        tgt_item = self.encode_text(" {}".format(tgt_caption))
        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, tgt_item])

        return {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "patch_type": patch_type,
            "patch_video": patch_video,
            "patch_audio": patch_audio,
        }

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate
            pad_to_length: accepted for interface compatibility; unused.

        Returns:
            dict: a mini-batch containing the data of the task
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)