# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .how2processor import (
    ShardedHow2MetaProcessor,
    ShardedVideoProcessor,
    ShardedTextProcessor,
    VariedLenAligner,
    OverlappedAligner,
)


class ShardedHow2VideoRetriMetaProcessor(ShardedHow2MetaProcessor):
    def __init__(self, config):
        super().__init__(config)
        self.num_video_per_batch = config.num_video_per_batch
        # Group video ids into candidate lists of `num_video_per_batch`,
        # truncating the tail so the total number of videos used is a
        # multiple of 8 * num_video_per_batch.
        self.cands = [
            self.data[batch_offset:batch_offset + self.num_video_per_batch]
            for batch_offset in range(
                0,
                (len(self.data) // (8 * self.num_video_per_batch))
                * 8 * self.num_video_per_batch,
                self.num_video_per_batch,
            )
        ]

    def __len__(self):
        return len(self.cands)

    def set_candidates(self, cands):
        # The number of batches is expected to stay the same.
        print(len(self.cands), "->", len(cands))
        # assert len(self.cands) == len(cands)
        self.cands = cands

    def __getitem__(self, idx):
        video_ids = self.cands[idx]
        assert isinstance(video_ids, list)
        sharded_video_idxs = []
        for video_id in video_ids:
            shard_id, video_idx = self.video_id_to_shard[video_id]
            sharded_video_idxs.append((video_id, -1, shard_id, video_idx))
        # The same index list is consumed by both the video and the text
        # processor.
        return sharded_video_idxs, sharded_video_idxs


class ShardedVideoRetriVideoProcessor(ShardedVideoProcessor):
    """In the retrieval case the `video_id` is a list of tuples
    `(shard_id, video_idx)`."""

    def __call__(self, sharded_video_idxs):
        assert isinstance(sharded_video_idxs, list)
        cand_feats = []
        for sharded_video_idx in sharded_video_idxs:
            feat = super().__call__(sharded_video_idx)
            cand_feats.append(feat)
        return cand_feats


class ShardedVideoRetriTextProcessor(ShardedTextProcessor):
    """In the retrieval case the `video_id` is a list of tuples
    `(shard_id, video_idx)`."""

    def __call__(self, sharded_video_idxs):
        assert isinstance(sharded_video_idxs, list)
        cand_caps = []
        for sharded_video_idx in sharded_video_idxs:
            caps = super().__call__(sharded_video_idx)
            cand_caps.append(caps)
        return cand_caps


class VideoRetriAligner(VariedLenAligner):
    # The retrieval task will trim dim-0.
    def __call__(self, sharded_video_idxs, video_features, text_features):
        # Lazy import: `transformers` is only needed when the aligner runs.
        from transformers import default_data_collator
        batch, video_ids = [], []
        for video_id, video_feature, text_feature in \
                zip(sharded_video_idxs, video_features, text_features):
            sub_batch = super().__call__(video_id, video_feature, text_feature)
            batch.append(sub_batch)
            if isinstance(video_id, tuple):
                video_id = video_id[0]
            video_ids.append(video_id)
        batch = default_data_collator(batch)
        batch["video_id"] = video_ids
        return batch


class VideoRetriOverlappedAligner(OverlappedAligner):
    # The retrieval task will trim dim-0.
    def __call__(self, sharded_video_idxs, video_features, text_features):
        # Lazy import: `transformers` is only needed when the aligner runs.
        from transformers import default_data_collator
        batch, video_ids = [], []
        for video_id, video_feature, text_feature in \
                zip(sharded_video_idxs, video_features, text_features):
            sub_batch = super().__call__(video_id, video_feature, text_feature)
            batch.append(sub_batch)
            if isinstance(video_id, tuple):
                video_id = video_id[0]
            video_ids.append(video_id)
        batch = default_data_collator(batch)
        batch["video_id"] = video_ids
        return batch
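

# ---------------------------------------------------------------------------
# Usage sketch (illustration only): the candidate construction in
# `ShardedHow2VideoRetriMetaProcessor.__init__` chunks the meta data into
# groups of `num_video_per_batch` video ids and drops the tail so the total
# number of videos is a multiple of 8 * num_video_per_batch. The toy ids and
# the `num_video_per_batch = 2` setting below are hypothetical; real values
# come from the dataset config, not from this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    num_video_per_batch = 2  # hypothetical value for the sketch.
    data = ["vid%02d" % i for i in range(37)]  # stand-in for `self.data`.

    # Number of videos kept: largest multiple of 8 * num_video_per_batch.
    usable = (len(data) // (8 * num_video_per_batch)) * 8 * num_video_per_batch
    cands = [
        data[offset:offset + num_video_per_batch]
        for offset in range(0, usable, num_video_per_batch)
    ]
    # 37 videos -> 32 usable -> 16 candidate groups of 2 video ids each.
    print(len(cands), cands[0], cands[-1])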