# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch
import math
import pickle
import random
import os
import numpy as np
from collections import deque
from typing import Optional, Tuple, List
from .processor import (
Processor,
MetaProcessor,
TextProcessor,
Aligner,
MMAttentionMask2DProcessor
)
from ..utils import ShardedTensor
class How2MetaProcessor(MetaProcessor):
def __init__(self, config):
super().__init__(config)
path = self._get_split_path(config)
with open(path) as fd:
self.data = [line.strip() for line in fd]
def __getitem__(self, idx):
video_id = self.data[idx]
return video_id, video_id
class ShardedHow2MetaProcessor(How2MetaProcessor):
def __init__(self, config):
super().__init__(config)
self.split = str(config.split)
self.vfeat_dir = config.vfeat_dir
self._init_shard()
def _init_shard(self):
if self.split == "train":
meta_fn = os.path.join(self.vfeat_dir, "train" + "_meta.pkl")
with open(meta_fn, "rb") as fr:
meta = pickle.load(fr)
elif self.split == "valid":
meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl")
with open(meta_fn, "rb") as fr:
meta = pickle.load(fr)
elif self.split == "test":
print("use how2 val as test.")
meta_fn = os.path.join(self.vfeat_dir, "val" + "_meta.pkl")
with open(meta_fn, "rb") as fr:
meta = pickle.load(fr)
else:
raise ValueError("unsupported for MetaProcessor:", self.split)
video_id_to_shard = {}
for shard_id in meta:
for video_idx, video_id in enumerate(meta[shard_id]):
video_id_to_shard[video_id] = (shard_id, video_idx)
self.video_id_to_shard = video_id_to_shard
def __getitem__(self, idx):
video_id, _ = super().__getitem__(idx)
shard_id, shard_idx = self.video_id_to_shard[video_id]
meta = (video_id, idx, shard_id, shard_idx)
return meta, meta
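# A hedged illustration of the mapping built in `_init_shard`, assuming the
# pickled meta has the layout {shard_id: [video_id, ...]} implied by the loop
# above (ids are hypothetical):
#   meta = {0: ["vid_a", "vid_b"], 1: ["vid_c"]}
#   => video_id_to_shard == {"vid_a": (0, 0), "vid_b": (0, 1), "vid_c": (1, 0)}
# `__getitem__` then returns (video_id, idx, shard_id, shard_idx), which the
# sharded processors below unpack to locate the right slot in the right shard.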
class ShardedVideoProcessor(Processor):
"""
Memory-mapped shards of numpy video features.
"""
def __init__(self, config):
self.split = str(config.split)
self.vfeat_dir = config.vfeat_dir
def __call__(self, video_id):
_, _, shard_id, video_idx = video_id
if self.split == "train":
shard = ShardedTensor.load(
os.path.join(self.vfeat_dir, "train" + "_" + str(shard_id)),
"r"
)
elif self.split == "valid":
shard = ShardedTensor.load(
os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)),
"r"
)
elif self.split == "test":
shard = ShardedTensor.load(
os.path.join(self.vfeat_dir, "val" + "_" + str(shard_id)),
"r"
)
else:
raise ValueError("unknown split", self.split)
feat = shard[video_idx]
return feat
class ShardedTextProcessor(Processor):
def __init__(self, config):
self.tfeat_dir = str(config.tfeat_dir)
self.split = str(config.split)
def __call__(self, video_id):
_, _, shard_id, shard_idx = video_id
if self.split == "train":
target_path = self.tfeat_dir + "train" + "_" + str(shard_id)
elif self.split == "valid":
target_path = self.tfeat_dir + "val" + "_" + str(shard_id)
elif self.split == "test":
target_path = self.tfeat_dir + "val" + "_" + str(shard_id)
else:
raise ValueError("unknown split", self.split)
startend = ShardedTensor.load(
target_path + ".startends", "r")[shard_idx]
cap_ids = ShardedTensor.load(
target_path + ".caps_ids", "r")[shard_idx]
cap = []
for clip_idx in range(len(cap_ids)):
clip = cap_ids[clip_idx]
cap.append(clip[clip != -1].tolist())
start, end = startend[:, 0].tolist(), startend[:, 1].tolist()
return {"start": start, "end": end, "cap": cap}
class FixedLenAligner(Aligner):
"""
In the model we assume text is on the left (closer to BERT formulation)
and video is on the right.
We fix the total length of text + video.
max_video_len is in number of secs.
max_text_len is in number of tokens.
special token format:
we use the format [CLS] [SEP] text tokens [SEP] [PAD] ...
[CLS] will later be split into:
[CLS] video tokens [SEP] text tokens [SEP] [PAD] ...
token_type_ids will be generated by the model (for now).
0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
| first sequence | second sequence |
so each sequence owns a [SEP] token for no-ops.
"""
def __init__(self, config):
super().__init__(config)
self.text_clip_sampler = TextClipSamplingProcessor(
self.max_len - self.max_video_len - 3
)
"""
decide subsampling:
`config.subsampling` will change batch_size in trainer.
`config.clip_per_video` (used by RetriTask) doesn't
change batch_size in trainer.
"""
subsampling = config.subsampling
if config.clip_per_video is not None:
subsampling = config.clip_per_video
self.subsampling = subsampling
def _get_text_maxlen(self):
# use max text len
return self.text_clip_sampler.max_text_len
def __call__(self, video_id, video_feature, text_feature):
from transformers import default_data_collator
video_idx = video_id[1]
if self.subsampling is not None and self.subsampling >= 1:
batch = []
for _ in range(self.subsampling):
centerclip_idx = random.randint(
0, len(text_feature["start"]) - 1)
batch.append(
self.sampling(
video_idx,
video_feature,
text_feature,
centerclip_idx,
self._get_text_maxlen()
))
batch = self.batch_post_processing(batch, video_feature)
batch = default_data_collator(batch)
else:
raise ValueError(
"dataset.subsampling must be >= 1 for efficient video loading.")
batch["video_id"] = video_id if isinstance(video_id, str) \
else video_id[0]
# e2e: make sure the frame features are already a tensor.
assert torch.is_tensor(batch["vfeats"])
return batch
def sampling(
self,
video_idx,
video_feature,
text_feature,
centerclip_idx=None,
sampled_max_text_len=None,
):
text_clip_indexs = self.text_clip_sampler(
text_feature, centerclip_idx,
sampled_max_text_len
)
if isinstance(video_feature, np.ndarray):
video_len = len(video_feature)
else:
video_len = math.ceil(text_feature["end"][-1])
video_end = min(
math.ceil(text_feature["end"][text_clip_indexs[-1]]),
video_len
)
video_start = max(
min(
math.floor(text_feature["start"][text_clip_indexs[0]]),
video_end),
0
)
video_clips = {"start": [video_start], "end": [video_end]}
# tensorize.
vfeats, vmasks = self._build_video_seq(
video_feature, video_clips
)
caps, cmasks = self._build_text_seq(
text_feature, text_clip_indexs
)
text_start = text_clip_indexs[0]
text_end = text_clip_indexs[-1] + 1
return {
"caps": caps,
"cmasks": cmasks,
"vfeats": vfeats,
"vmasks": vmasks,
"video_start": video_start,
"video_end": video_end,
"text_start": text_start,
"text_end": text_end,
}
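# A hedged sketch of one sampled item returned above (interpretation follows
# the class docstring): "caps"/"cmasks" hold the token ids and text mask in
# the [CLS] [SEP] text tokens [SEP] [PAD] ... layout (the model later splits
# in video tokens at [CLS]), "vfeats"/"vmasks" hold the zero-padded video
# features and their validity mask, and the *_start/*_end fields record the
# sampled second and clip ranges.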
class VariedLenAligner(FixedLenAligner):
def __init__(self, config):
super().__init__(config)
self.sampled_min_len = config.sampled_min_len
self.sampled_max_len = config.sampled_max_len
def _get_text_maxlen(self):
return random.randint(self.sampled_min_len, self.sampled_max_len)
class StartClipAligner(VariedLenAligner):
def sampling(
self,
video_idx,
video_feature,
text_feature,
centerclip_idx=None,
sampled_max_text_len=None,
):
return super().sampling(
video_idx, video_feature, text_feature, 0)
class OverlappedAligner(VariedLenAligner):
"""video clip and text clip has overlappings
but may not be the same start/end."""
def __init__(self, config):
super().__init__(config)
self.sampled_video_min_len = config.sampled_video_min_len
self.sampled_video_max_len = config.sampled_video_max_len
self.video_clip_sampler = VideoClipSamplingProcessor()
def _get_video_maxlen(self):
return random.randint(
self.sampled_video_min_len, self.sampled_video_max_len)
def sampling(
self,
video_idx,
video_feature,
text_feature,
centerclip_idx=None,
sampled_max_text_len=None,
):
text_clip_indexs = self.text_clip_sampler(
text_feature, centerclip_idx,
sampled_max_text_len
)
if isinstance(video_feature, np.ndarray):
video_len = len(video_feature)
else:
video_len = math.ceil(text_feature["end"][-1])
low = math.floor(text_feature["start"][text_clip_indexs[0]])
high = math.ceil(text_feature["end"][text_clip_indexs[-1]])
if low < high:
center = random.randint(low, high)
else:
center = int((low + high) // 2)
center = max(0, min(video_feature.shape[0] - 1, center))
assert 0 <= center < video_feature.shape[0]
video_clips = self.video_clip_sampler(
video_len, self._get_video_maxlen(), center
)
video_start = video_clips["start"][0]
video_end = video_clips["end"][0]
# tensorize.
vfeats, vmasks = self._build_video_seq(
video_feature, video_clips
)
caps, cmasks = self._build_text_seq(
text_feature, text_clip_indexs
)
text_start = text_clip_indexs[0]
text_end = text_clip_indexs[-1] + 1
return {
"caps": caps,
"cmasks": cmasks,
"vfeats": vfeats,
"vmasks": vmasks,
"video_start": video_start,
"video_end": video_end,
"text_start": text_start,
"text_end": text_end,
}
class MFMMLMAligner(FixedLenAligner):
"""
`FixedLenAligner` with Masked Language Model and Masked Frame Model.
"""
def __init__(self, config):
super().__init__(config)
keep_prob = config.keep_prob if config.keep_prob is not None else 1.0
self.text_clip_sampler = TextClipSamplingProcessor(
self.max_len - self.max_video_len - 3, keep_prob
)
self.sampled_min_len = config.sampled_min_len
self.sampled_max_len = config.sampled_max_len
self.masked_token_sampler = TextMaskingProcessor(config)
self.mm_type = config.mm_type \
if config.mm_type is not None else "full"
self.attnmasker = MMAttentionMask2DProcessor() \
if self.mm_type == "textgen" else None
self.masked_frame_sampler = FrameMaskingProcessor(config)
self.lazy_vfeat_mask = (
False if config.lazy_vfeat_mask is None else config.lazy_vfeat_mask
)
self.mm_prob = config.mm_prob if config.mm_prob is not None else 0.
def __call__(self, video_id, video_feature, text_feature):
from transformers import default_data_collator
if self.subsampling is not None and self.subsampling > 1:
batch = []
for _ in range(self.subsampling):
centerclip_idx = random.randint(
0, len(text_feature["start"]) - 1)
sampled_max_text_len = random.randint(
self.sampled_min_len, self.sampled_max_len
)
batch.append(
self.sampling(
video_id,
video_feature,
text_feature,
centerclip_idx,
sampled_max_text_len,
)
)
batch = self.batch_post_processing(batch, video_feature)
batch = default_data_collator(batch)
else:
batch = self.sampling(video_id, video_feature, text_feature)
batch = self.batch_post_processing(batch, video_feature)
batch["video_id"] = video_id if isinstance(video_id, str) \
else video_id[0]
return batch
def sampling(
self,
video_id,
video_feature,
text_feature,
centerclip_idx=None,
sampled_max_text_len=None,
):
output = FixedLenAligner.sampling(self,
video_id, video_feature, text_feature,
centerclip_idx, sampled_max_text_len)
masking_text, masking_video = None, None
if random.random() < self.mm_prob:
if random.random() > 0.5:
masking_text, masking_video = self.mm_type, "no"
else:
masking_text, masking_video = "no", "full"
video_feats = output["vfeats"] if not self.lazy_vfeat_mask else None
video_label = self.masked_frame_sampler(
output["vmasks"], masking_video, vfeats=video_feats)
caps, text_label = self.masked_token_sampler(
output["caps"], masking_text)
output.update({
"caps": caps,
"video_label": video_label,
"text_label": text_label,
})
if self.attnmasker is not None:
attention_mask = self.attnmasker(
output["vmasks"], output["cmasks"], masking_text)
output.update({
"attention_mask": attention_mask
})
return output
class FrameMaskingProcessor(Processor):
def __init__(self, config):
self.mfm_probability = 0.15
if config.mfm_probability is not None:
self.mfm_probability = config.mfm_probability
def __call__(self, vmasks, modality_masking=None, vfeats=None):
"""
We perform lazy masking to save data transfer time.
By default it only generates `video_label`; the MFM model
performs the actual masking.
Return: `video_label` is a binary mask.
"""
video_label = vmasks.clone()
if modality_masking is not None:
if modality_masking == "full":
probability_matrix = torch.full(video_label.shape, 1.)
elif modality_masking == "no":
probability_matrix = torch.full(video_label.shape, 0.)
elif modality_masking == "inverse":
probability_matrix = torch.full(
video_label.shape, 1. - self.mfm_probability)
else:
raise ValueError("unknown modality masking.", modality_masking)
else:
probability_matrix = torch.full(
video_label.shape, self.mfm_probability)
masked_indices = torch.bernoulli(probability_matrix).bool()
# We only compute loss on masked tokens
video_label[~masked_indices] = 0
if vfeats is not None:
vfeats[video_label, :] = 0.0
return video_label
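# A minimal sketch (not part of the original pipeline) of the lazy masking
# above. `SimpleNamespace` stands in for the real config object, which is an
# assumption made for illustration only.
def _demo_frame_masking():
    from types import SimpleNamespace
    proc = FrameMaskingProcessor(SimpleNamespace(mfm_probability=0.15))
    vmasks = torch.tensor([True, True, True, False])
    # "full" masking marks every frame, so the label equals the input mask;
    # the fourth entry stays False because vmasks marks that frame as padding.
    video_label = proc(vmasks, modality_masking="full")
    assert video_label.tolist() == [True, True, True, False]
    return video_label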
class TextGenerationProcessor(Processor):
def __init__(self, tokenizer):
self.bos_token_id = tokenizer.bos_token_id
self.pad_token_id = tokenizer.pad_token_id
def __call__(self, inputs):
labels = inputs.clone()
# [CLS] [SEP] for video
labels[:2] = -100
# keep [SEP] for text.
pad_mask = labels == self.pad_token_id
labels[pad_mask] = -100
inputs[2:] = torch.cat([
torch.LongTensor([self.bos_token_id]),
inputs[2:-1]])
inputs[pad_mask] = self.pad_token_id
assert len(inputs) == len(labels)
return inputs, labels
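# A minimal sketch (not part of the original pipeline) of the shift-right
# behavior above. `SimpleNamespace` mimics just the two tokenizer attributes
# the processor reads; the token ids below are hypothetical.
def _demo_text_generation_labels():
    from types import SimpleNamespace
    proc = TextGenerationProcessor(
        SimpleNamespace(bos_token_id=1, pad_token_id=0))
    # layout: [CLS]=101, [SEP]=102, three text tokens, one pad.
    inputs, labels = proc(torch.LongTensor([101, 102, 5, 6, 7, 0]))
    # the text span is shifted right behind <s>=1; pads become -100 labels.
    assert inputs.tolist() == [101, 102, 1, 5, 6, 0]
    assert labels.tolist() == [-100, -100, 5, 6, 7, -100]
    return inputs, labels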
class TextMaskingProcessor(Processor):
def __init__(self, config):
"""this function is borrowed from
`transformers/data/data_collator.DataCollatorForLanguageModeling`"""
self.mlm_probability = 0.15
if config.mlm_probability is not None:
self.mlm_probability = config.mlm_probability
self.bert_name = config.bert_name
# [CLS] is used as bos_token and [SEP] is used as eos_token.
# https://huggingface.co/transformers/master/model_doc/bertgeneration.html
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
self.bert_name, bos_token="[CLS]", eos_token="[SEP]")
self.textgen = TextGenerationProcessor(self.tokenizer)
def __call__(
self, inputs: torch.Tensor,
modality_masking=None,
special_tokens_mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
expand modality_masking into
None: traditional bert masking.
"no": no masking.
"full": all [MASK] token for generation.
"gen": autoregressive generation.
"""
"""
Prepare masked tokens inputs/labels for masked language modeling:
80% MASK, 10% random, 10% original.
"""
labels = inputs.clone()
# We sample a few tokens in each sequence for MLM training
# (with probability `self.mlm_probability`)
if modality_masking is not None:
if modality_masking == "full":
probability_matrix = torch.full(labels.shape, 1.)
elif modality_masking == "no":
probability_matrix = torch.full(labels.shape, 0.)
elif modality_masking.startswith("textgen"):
# [CLS] [SEP] <s> ...
inputs, labels = self.textgen(inputs)
if "mask" not in modality_masking:
return inputs, labels
inputs = self.mask_input(inputs, special_tokens_mask)
return inputs, labels
elif modality_masking == "mask":
inputs = self.mask_input(inputs, special_tokens_mask)
labels = torch.full(inputs.shape, -100)
return inputs, labels
elif modality_masking == "inverse":
probability_matrix = torch.full(labels.shape, 1. - self.mlm_probability)
else:
raise ValueError("unknown modality masking.", modality_masking)
else:
probability_matrix = torch.full(labels.shape, self.mlm_probability)
if special_tokens_mask is None:
special_tokens_mask = self.get_special_tokens_mask(
labels.tolist(), already_has_special_tokens=True
)
special_tokens_mask = torch.tensor(
special_tokens_mask, dtype=torch.bool)
else:
special_tokens_mask = special_tokens_mask.bool()
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
labels[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time,
# we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = (
torch.bernoulli(
torch.full(labels.shape, 0.8)).bool() & masked_indices
)
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
self.tokenizer.mask_token
)
# 10% of the time, we replace masked input tokens with random word
indices_random = (
torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
& masked_indices
& ~indices_replaced
)
random_words = torch.randint(
len(self.tokenizer), labels.shape, dtype=torch.long
)
inputs[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input
# tokens unchanged
return inputs, labels
def mask_input(self, inputs, special_tokens_mask=None):
# the following is new for masked autoregressive generation.
probability_matrix = torch.full(
inputs.shape, self.mlm_probability)
if special_tokens_mask is None:
special_tokens_mask = self.get_special_tokens_mask(
inputs.tolist(), already_has_special_tokens=True
)
special_tokens_mask = torch.tensor(
special_tokens_mask, dtype=torch.bool)
else:
special_tokens_mask = special_tokens_mask.bool()
probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
masked_indices = torch.bernoulli(probability_matrix).bool()
indices_replaced = (
torch.bernoulli(
torch.full(inputs.shape, 0.8)).bool() & masked_indices
)
inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
self.tokenizer.mask_token
)
# 10% of the time, we replace masked input tokens with random word
indices_random = (
torch.bernoulli(torch.full(inputs.shape, 0.5)).bool()
& masked_indices
& ~indices_replaced
)
random_words = torch.randint(
len(self.tokenizer), inputs.shape, dtype=torch.long
)
inputs[indices_random] = random_words[indices_random]
return inputs
def get_special_tokens_mask(
self, token_ids_0: List[int],
token_ids_1: Optional[List[int]] = None,
already_has_special_tokens: bool = False
) -> List[int]:
"""
Note: the transformers version does not consider [PAD]
a special token; this one does.
"""
if already_has_special_tokens:
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if"
"the provided sequence of "
"ids is already formated with special tokens "
"for the model."
)
return list(map(lambda x: 1 if x in [
self.tokenizer.sep_token_id,
self.tokenizer.cls_token_id,
self.tokenizer.pad_token_id] else 0, token_ids_0))
if token_ids_1 is not None:
return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
return [1] + ([0] * len(token_ids_0)) + [1]
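# For reference, the per-token probabilities implied by `__call__` above under
# default BERT-style masking (mlm_probability = 0.15):
#   selected for the MLM loss:        0.15
#   -> replaced with [MASK]:          0.15 * 0.8       = 0.12
#   -> replaced with a random token:  0.15 * 0.2 * 0.5 = 0.015
#   -> kept unchanged:                0.15 * 0.2 * 0.5 = 0.015
# Special tokens ([CLS], [SEP], [PAD]) are never selected because
# `probability_matrix` is zeroed at their positions via
# `get_special_tokens_mask`.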
class TextClipSamplingProcessor(Processor):
def __init__(self, max_text_len, keep_prob=1.0):
self.max_text_len = max_text_len
self.max_video_len = 256  # always holds.
self.keep_prob = keep_prob
def __call__(
self,
text_feature,
centerclip_idx=None,
sampled_max_text_len=None,
sampled_max_video_len=None,
):
# Let's use all caps for now and see if 256 can cover all of them.
if sampled_max_text_len is not None:
max_text_len = sampled_max_text_len
else:
max_text_len = self.max_text_len
if sampled_max_video_len is not None:
max_video_len = sampled_max_video_len
else:
max_video_len = self.max_video_len
t_num_clips = len(text_feature["start"])
if centerclip_idx is None:
centerclip_idx = random.randint(0, t_num_clips - 1)
start_idx, end_idx = centerclip_idx, centerclip_idx + 1
text_clip_indexs = deque()
text_clip_indexs.append(start_idx)
text_len = len(text_feature["cap"][start_idx])
video_len = max(
0,
text_feature["end"][start_idx]
- text_feature["start"][start_idx],
)
while (
(start_idx > 0 or end_idx < t_num_clips)
and text_len < max_text_len
and video_len < max_video_len
):
if random.random() > 0.5 and end_idx < t_num_clips:
# skip the next one?
if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
end_idx = end_idx + 1
text_clip_indexs.append(end_idx)
text_len += len(text_feature["cap"][end_idx])
end_idx += 1
elif start_idx > 0:
if random.random() > self.keep_prob and (start_idx - 1) > 0:
start_idx = start_idx - 1
start_idx -= 1
text_clip_indexs.insert(0, start_idx)
text_len += len(text_feature["cap"][start_idx])
else:
if end_idx < t_num_clips:
if random.random() > self.keep_prob and (end_idx + 1) < t_num_clips:
end_idx = end_idx + 1
text_clip_indexs.append(end_idx)
text_len += len(text_feature["cap"][end_idx])
end_idx += 1
else:
return text_clip_indexs
video_len = max(
0,
text_feature["end"][text_clip_indexs[-1]]
- text_feature["start"][text_clip_indexs[0]],
)
return text_clip_indexs
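# A minimal sketch (not part of the original pipeline) of the clip-window
# sampling above; the toy text_feature mirrors the {"start", "end", "cap"}
# dict produced by ShardedTextProcessor and the token ids are hypothetical.
def _demo_text_clip_sampling():
    sampler = TextClipSamplingProcessor(max_text_len=6)
    text_feature = {
        "start": [0.0, 4.0, 9.0],
        "end": [4.0, 9.0, 15.0],
        "cap": [[101, 7592], [2088, 102], [1996, 3899, 2003]],
    }
    # grow a window of clip indices around clip 1, capped at 6 text tokens.
    indexs = sampler(text_feature, centerclip_idx=1)
    assert 1 in indexs  # the center clip is always kept
    return indexs  # a deque of contiguous clip indices, e.g. deque([0, 1])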
class VideoClipSamplingProcessor(Processor):
def __call__(self, video_len, max_video_len, center):
"""
`video_len`: length of the video.
`max_video_len`: maximum number of video tokens allowed in a sequence.
`center`: initial starting index.
"""
assert center >= 0 and center < video_len
t_clip_len = 0
start, end = center, center
while (start > 0 or end < video_len) and t_clip_len < max_video_len:
# decide the direction to grow.
if start <= 0:
end += 1
elif end >= video_len:
start -= 1
elif random.random() > 0.5:
end += 1
else:
start -= 1
t_clip_len += 1
return {"start": [start], "end": [end]}
class How2MILNCEAligner(FixedLenAligner):
"""reference: `antoine77340/MIL-NCE_HowTo100M/video_loader.py`"""
def __init__(self, config):
super().__init__(config)
self.num_candidates = 4
self.min_time = 5.0
self.num_sec = 3.2
# num_sec = num_frames / fps = 16 / 5 = 3.2 (num_frames = 16, fps = 5).
def sampling(
self,
video_id,
video_feature,
text_feature,
centerclip_idx=None, # will be ignored.
sampled_max_text_len=None # will be ignored.
):
text, start, end = self._get_text(text_feature)
video = self._get_video(video_feature, start, end)
vfeats = torch.zeros((self.max_video_len, video_feature.shape[1]))
vmasks = torch.zeros((self.max_video_len,), dtype=torch.bool)
vfeats[: video.shape[0]] = torch.from_numpy(np.array(video))
vmasks[: video.shape[0]] = 1
caps, cmasks = [], []
for words in text:
cap, cmask = self._build_text_seq(text_feature, words)
caps.append(cap)
cmasks.append(cmask)
caps = torch.stack(caps)
cmasks = torch.stack(cmasks)
# video of shape: (video_len)
# text of shape (num_candidates, max_text_len)
return {
"caps": caps,
"cmasks": cmasks,
"vfeats": vfeats,
"vmasks": vmasks,
# "video_id": video_id,
}
def _get_video(self, video_feature, start, end):
start_seek = random.randint(start, int(max(start, end - self.num_sec)))
# duration = self.num_sec + 0.1
return video_feature[start_seek : int(start_seek + self.num_sec)]
def _get_text(self, cap):
ind = random.randint(0, len(cap["start"]) - 1)
if self.num_candidates == 1:
words = [ind]
else:
words = []
cap_start = self._find_nearest_candidates(cap, ind)
for i in range(self.num_candidates):
words.append([max(0, min(len(cap["cap"]) - 1, cap_start + i))])
start, end = cap["start"][ind], cap["end"][ind]
# TODO: May need to be improved for edge cases.
# expand the min time.
if end - start < self.min_time:
diff = self.min_time - end + start
start = max(0, start - diff / 2)
end = start + self.min_time
return words, int(start), int(end)
def _find_nearest_candidates(self, caption, ind):
"""find the range of the clips."""
start, end = ind, ind
#diff = caption["end"][end] - caption["start"][start]
n_candidate = 1
while n_candidate < self.num_candidates:
# the first clip
if start == 0:
return 0
# we add () in the following condition to fix the bug.
elif end == (len(caption["start"]) - 1):
return start - (self.num_candidates - n_candidate)
elif (caption["end"][end] - caption["start"][start - 1]) < (
caption["end"][end + 1] - caption["start"][start]
):
start -= 1
else:
end += 1
n_candidate += 1
return start
class PKLJSONStrTextProcessor(TextProcessor):
"""`caption.json` from howto100m are preprocessed as a
dict `[video_id, json_str]`.
Json parsing tokenization are conducted on-the-fly and cached into dict.
"""
def __init__(self, config, max_clip_text_len=96):
print("[Warning] PKLJSONStrTextProcessor is slow for num_workers > 0.")
self.caption_pkl_path = str(config.caption_pkl_path)
with open(self.caption_pkl_path, "rb") as fd:
self.data = pickle.load(fd)
self.max_clip_text_len = max_clip_text_len
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(
str(config.bert_name), use_fast=config.use_fast
)
def __call__(self, video_id):
caption = self.data[video_id]
if isinstance(caption, str):
import json
caption = json.loads(caption)
cap = []
for clip_idx, text_clip in enumerate(caption["text"]):
clip_ids = []
if isinstance(text_clip, str):
clip_ids = self.tokenizer(
text_clip[: self.max_clip_text_len],
add_special_tokens=False
)["input_ids"]
cap.append(clip_ids)
caption["cap"] = cap
caption.pop("text") # save space.
self.data[video_id] = caption
return caption
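# A hedged end-to-end sketch of how the processors above compose (the `config`
# object and indexing are assumptions for illustration, not a tested entry
# point):
#   meta = ShardedHow2MetaProcessor(config)  # idx -> (video_id, idx, shard_id, shard_idx)
#   vproc = ShardedVideoProcessor(config)    # meta tuple -> video feature array
#   tproc = ShardedTextProcessor(config)     # meta tuple -> {"start", "end", "cap"}
#   aligner = FixedLenAligner(config)        # -> model-ready tensor batch
#   meta_item, _ = meta[0]
#   batch = aligner(meta_item, vproc(meta_item), tproc(meta_item))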