# Copyright (c) Facebook, Inc. All Rights Reserved

import torch
import os
import numpy as np
import pickle

from . import retri
from ..utils import get_local_rank


class VectorPool(object):
    """
    Base class of retrieval space.
    """

    def __init__(self, config):
        from transformers import AutoConfig
        self.hidden_size = AutoConfig.from_pretrained(
            config.dataset.bert_name).hidden_size
        self.retriever_cls = getattr(retri, config.retriever_cls)

    def __call__(self, sample, **kwargs):
        raise NotImplementedError

    def build_retriver(
        self,
        retriever_cls=None,
        hidden_size=None,
        centroids=512,
        db_type="flatl2",
        examples_per_cent_to_train=48
    ):
        """build a retriever over the pooled vectors."""
        self.retriver = retriever_cls(
            hidden_size, centroids, db_type, examples_per_cent_to_train)
        return self.retriver

    def __repr__(self):
        if hasattr(self, "retriver"):
            retriver_name = str(len(self.retriver))
        else:
            retriver_name = "no retriver field yet"
        return self.__class__.__name__ \
            + "(" + retriver_name + ")"


class VideoVectorPool(VectorPool):
    """
    Average clips of a video as the video representation.
    """

    def __init__(self, config):
        super().__init__(config)
        self.build_retriver(self.retriever_cls, self.hidden_size)

    def __call__(self, sample, subsampling, **kwargs):
        # average the video and text embeddings, then mean-pool the
        # `subsampling` clips of each video into one vector per video.
        hidden_states = (
            sample["pooled_video"] + sample["pooled_text"]) / 2.
        hidden_states = hidden_states.view(
            -1, subsampling, hidden_states.size(-1))
        hidden_states = torch.mean(hidden_states, dim=1)
        hidden_states = hidden_states.cpu().detach().numpy()
        video_ids = []
        for offset_idx, video_id in enumerate(sample["video_id"]):
            if isinstance(video_id, tuple) and len(video_id) == 3:
                # a sharded video_id.
                video_id = video_id[0]
            video_ids.append(video_id)
        assert len(video_ids) == len(hidden_states)
        self.retriver.add(
            hidden_states.astype("float32"),
            video_ids
        )
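# A minimal usage sketch for `VideoVectorPool` (assumptions: `config` is the
# task config exposing `dataset.bert_name`, `retriever_cls` and
# `dataset.subsampling`; `model` and `dataloader` stand in for the caller's
# own objects and are not defined in this module):
#
#     pool = VideoVectorPool(config)
#     for batch in dataloader:
#         outputs = model(**batch)  # must yield "pooled_video"/"pooled_text"
#         outputs["video_id"] = batch["video_id"]
#         pool(outputs, subsampling=config.dataset.subsampling)
#     print(pool)  # VideoVectorPool(<number of indexed videos>)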
class DistributedVectorPool(VectorPool):
    """
    Support sync of multiple gpus/nodes.
    """

    def __init__(self, config):
        super().__init__(config)
        self.out_dir = os.path.join(
            config.fairseq.checkpoint.save_dir,
            "retri")
        os.makedirs(self.out_dir, exist_ok=True)
        self.hidden_states = []
        self.video_ids = []

    def build_retriver(
        self,
        retriever_cls=None,
        hidden_size=None,
        centroids=4096,
        db_type="flatl2",
        examples_per_cent_to_train=48
    ):
        """merge results from multiple gpus and return a retriever."""
        if retriever_cls is None:
            retriever_cls = self.retriever_cls
        if hidden_size is None:
            hidden_size = self.hidden_size
        if torch.distributed.is_initialized():
            self.save()
            # sync saving.
            torch.distributed.barrier()
            world_size = torch.distributed.get_world_size()
        else:
            world_size = 1
        self.retriver = retriever_cls(
            hidden_size, centroids, db_type, examples_per_cent_to_train)
        # each gpu process has its own retriever: every rank loads all
        # saved shards so it ends up with the full index.
        for local_rank in range(world_size):
            if get_local_rank() == 0:
                print("load local_rank", local_rank)
            hidden_states, video_ids = self.load(local_rank)
            hidden_states = hidden_states.astype("float32")
            self.retriver.add(hidden_states, video_ids)
        return self.retriver

    def load(self, local_rank):
        hidden_states = np.load(
            os.path.join(
                self.out_dir,
                "hidden_state" + str(local_rank) + ".npy"
            )
        )

        with open(
            os.path.join(
                self.out_dir,
                "video_id" + str(local_rank) + ".pkl"),
            "rb") as fr:
            video_ids = pickle.load(fr)
        return hidden_states, video_ids

    def save(self):
        # write this rank's buffered vectors and ids to disk, named by rank.
        hidden_states = np.vstack(self.hidden_states)
        assert len(hidden_states) == len(self.video_ids), "{}, {}".format(
            len(hidden_states),
            len(self.video_ids)
        )
        local_rank = torch.distributed.get_rank() \
            if torch.distributed.is_initialized() else 0

        np.save(
            os.path.join(
                self.out_dir,
                "hidden_state" + str(local_rank) + ".npy"),
            hidden_states)

        with open(
            os.path.join(
                self.out_dir,
                "video_id" + str(local_rank) + ".pkl"),
            "wb") as fw:
            pickle.dump(
                self.video_ids,
                fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )


class DistributedVideoVectorPool(DistributedVectorPool):
    """
    Average clips of a video as the video representation.
    """

    def __call__(self, sample, subsampling, **kwargs):
        hidden_states = (
            sample["pooled_video"] + sample["pooled_text"]) / 2.
        hidden_states = hidden_states.view(
            -1, subsampling, hidden_states.size(-1))
        hidden_states = torch.mean(hidden_states, dim=1)
        hidden_states = hidden_states.cpu().detach().numpy()
        video_ids = []
        for offset_idx, video_id in enumerate(sample["video_id"]):
            if isinstance(video_id, tuple) and len(video_id) == 3:
                # a sharded video_id.
                video_id = video_id[0]
            video_ids.append(video_id)
        assert len(video_ids) == len(hidden_states)
        # buffer locally; `build_retriver` saves and merges across ranks.
        self.hidden_states.append(hidden_states)
        self.video_ids.extend(video_ids)
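# A minimal sketch of the distributed flow (assumption: one
# `DistributedVideoVectorPool` per process under `torch.distributed`, with
# `config`, `model` and `dataloader` standing in for the caller's objects):
#
#     pool = DistributedVideoVectorPool(config)
#     for batch in dataloader:
#         outputs = model(**batch)
#         outputs["video_id"] = batch["video_id"]
#         pool(outputs, subsampling=config.dataset.subsampling)  # buffer only
#     # each rank writes its shard to `<save_dir>/retri`, barriers, then loads
#     # every rank's shard so all processes hold the same full index.
#     retriver = pool.build_retriver()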
""" def __init__(self, out_dir): """use hidden_states to store `(video, text)`.""" """use video_ids to store `(video_id, start, end)`.""" super().__init__(out_dir) def __call__(self, sample, **kwargs): pooled_video = sample["pooled_video"].cpu().unsqueeze(1).numpy() pooled_text = sample["pooled_text"].cpu().unsqueeze(1).numpy() self.hidden_states.append( np.concatenate([pooled_video, pooled_text], axis=1) ) video_starts = sample["video_start"].cpu() video_ends = sample["video_end"].cpu() assert torch.all(torch.le(video_starts, video_ends)) text_starts = sample["text_start"].cpu() text_ends = sample["text_end"].cpu() assert torch.all(torch.le(text_starts, text_ends)) subsample_size = sample["pooled_video"].size(0) // len(sample["video_id"]) video_ids = [video_id for video_id in sample["video_id"] for _ in range(subsample_size) ] for video_id, video_start, video_end, text_start, text_end in zip( video_ids, video_starts, video_ends, text_starts, text_ends): self.video_ids.append(( video_id, (int(video_start), int(video_end)), (int(text_start), int(text_end)) ))