# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import pickle
import time

try:
    import faiss
except ImportError:
    pass

from collections import defaultdict

from ..utils import get_local_rank, print_on_rank0


class VectorRetriever(object):
    """
    How2 Video Retriver.

    Wraps a FAISS IVF index plus a mapping from opaque video ids to
    FAISS vector indices. Vectors are cached until enough examples
    (`cent * examples_per_cent_to_train`) are available to train the
    index, then trained and added in one shot.

    Reference usage of FAISS:
    https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
    """

    def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
        """
        Args:
            hidden_size: dimensionality of the stored vectors.
            cent: number of IVF centroids.
            db_type: "flatl2" (IVFFlat over a flat L2 quantizer) or
                "pq" (index_factory IVF+HNSW coarse quantizer with PQ codes).
            examples_per_cent_to_train: training examples per centroid.

        Raises:
            ValueError: if `db_type` is not recognized.
        """
        if db_type == "flatl2":
            quantizer = faiss.IndexFlatL2(hidden_size)  # the other index
            self.db = faiss.IndexIVFFlat(
                quantizer, hidden_size, cent, faiss.METRIC_L2)
        elif db_type == "pq":
            self.db = faiss.index_factory(
                hidden_size, f"IVF{cent}_HNSW32,PQ32"
            )
        else:
            raise ValueError("unknown type of db", db_type)
        # Number of vectors to accumulate before training the index.
        self.train_thres = cent * examples_per_cent_to_train
        self.train_cache = []
        self.train_len = 0
        # video id -> position of its vector in the FAISS index.
        self.videoid_to_vectoridx = {}
        # Inverse map; built lazily on first search.
        self.vectoridx_to_videoid = None
        self.make_direct_maps_done = False

    def make_direct_maps(self):
        """Build the index's direct map so `reconstruct` works."""
        faiss.downcast_index(self.db).make_direct_map()

    def __len__(self):
        return self.db.ntotal

    def save(self, out_dir):
        """Write the FAISS index and the id mapping under `out_dir`."""
        faiss.write_index(
            self.db,
            os.path.join(out_dir, "faiss_idx")
        )
        with open(
                os.path.join(
                    out_dir, "videoid_to_vectoridx.pkl"),
                "wb") as fw:
            pickle.dump(
                self.videoid_to_vectoridx, fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )

    def load(self, out_dir):
        """Load the FAISS index and the id mapping from `out_dir`."""
        fn = os.path.join(out_dir, "faiss_idx")
        self.db = faiss.read_index(fn)
        with open(
                os.path.join(out_dir, "videoid_to_vectoridx.pkl"),
                "rb") as fr:
            self.videoid_to_vectoridx = pickle.load(fr)

    def add(self, hidden_states, video_ids, last=False):
        """Add `(n, hidden_size)` float32 vectors keyed by `video_ids`.

        Already-seen video ids are skipped. Until the index is trained,
        vectors are cached; training is triggered once the cache reaches
        `train_thres` vectors. `last` is accepted but unused here.
        """
        assert len(hidden_states) == len(video_ids), "{}, {}".format(
            str(len(hidden_states)), str(len(video_ids)))
        assert len(hidden_states.shape) == 2
        assert hidden_states.dtype == np.float32

        valid_idx = []
        for idx, video_id in enumerate(video_ids):
            if video_id not in self.videoid_to_vectoridx:
                valid_idx.append(idx)
                self.videoid_to_vectoridx[video_id] = \
                    len(self.videoid_to_vectoridx)

        hidden_states = hidden_states[valid_idx]
        if not self.db.is_trained:
            self.train_cache.append(hidden_states)
            self.train_len += hidden_states.shape[0]
            if self.train_len < self.train_thres:
                return
            self.finalize_training()
        else:
            self.db.add(hidden_states)

    def finalize_training(self):
        """Train the index on the cached vectors, then add all of them."""
        hidden_states = np.concatenate(self.train_cache, axis=0)
        del self.train_cache
        local_rank = get_local_rank()
        if local_rank == 0:
            start = time.time()
            print("training db on", self.train_thres, "/", self.train_len)
        self.db.train(hidden_states[:self.train_thres])
        if local_rank == 0:
            print("training db for", time.time() - start)
        self.db.add(hidden_states)

    def search(
        self,
        query_hidden_states,
        orig_dist,
    ):
        """1-NN search; returns an int32 array of retrieved video ids.

        Video ids appear to be 3-int tuples (the miss sentinel is
        (-1, -1, -1)) — TODO confirm against callers. Rows whose queried
        distance is <= the corresponding `orig_dist` entry are masked
        to -1.
        """
        if len(self.videoid_to_vectoridx) != self.db.ntotal:
            raise ValueError(
                "cannot search: size mismatch in-between index and db",
                len(self.videoid_to_vectoridx),
                self.db.ntotal
            )

        if self.vectoridx_to_videoid is None:
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)

        # MultilingualFaissDataset uses the following; not sure the purpose.
        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
        queried_dist, index = self.db.search(query_hidden_states, 1)
        queried_dist, index = queried_dist[:, 0], index[:, 0]

        outputs = np.array(
            [self.vectoridx_to_videoid[_index]
             if _index != -1 else (-1, -1, -1)
             for _index in index],
            dtype=np.int32)
        outputs[queried_dist <= orig_dist] = -1
        return outputs

    def search_by_video_ids(
        self,
        video_ids,
        retri_factor
    ):
        """Retrieve up to `retri_factor` neighbors for each known video id.

        Each output list starts with the query's own video id; the query
        vector itself is reconstructed from the index (requires direct
        maps).
        """
        if len(self.videoid_to_vectoridx) != self.db.ntotal:
            raise ValueError(
                len(self.videoid_to_vectoridx),
                self.db.ntotal
            )

        if not self.make_direct_maps_done:
            self.make_direct_maps()

        if self.vectoridx_to_videoid is None:
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)

        query_hidden_states = []
        vector_ids = []
        for video_id in video_ids:
            vector_id = self.videoid_to_vectoridx[video_id]
            vector_ids.append(vector_id)
            query_hidden_state = self.db.reconstruct(vector_id)
            query_hidden_states.append(query_hidden_state)
        query_hidden_states = np.stack(query_hidden_states)

        # MultilingualFaissDataset uses the following; not sure the reason.
        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
        _, index = self.db.search(query_hidden_states, retri_factor)

        outputs = []
        for sample_idx, sample in enumerate(index):
            # the first video_id is always the video itself.
            cands = [video_ids[sample_idx]]
            for vector_idx in sample:
                if vector_idx >= 0 \
                        and vector_ids[sample_idx] != vector_idx:
                    cands.append(
                        self.vectoridx_to_videoid[vector_idx]
                    )
            outputs.append(cands)
        return outputs


class VectorRetrieverDM(VectorRetriever):
    """
    with direct map.
    How2 Video Retriver.
    Reference usage of FAISS:
    https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
    """

    def __init__(
        self,
        hidden_size,
        cent,
        db_type,
        examples_per_cent_to_train
    ):
        super().__init__(
            hidden_size, cent, db_type, examples_per_cent_to_train)
        self.make_direct_maps_done = False

    def make_direct_maps(self):
        """Build the direct map once and remember that it was done."""
        faiss.downcast_index(self.db).make_direct_map()
        self.make_direct_maps_done = True

    def search(
        self,
        query_hidden_states,
        orig_dist,
    ):
        """1-NN search; returns a list of video ids (or None on miss).

        A hit is kept only when the queried distance improves on the
        corresponding `orig_dist` entry.
        """
        if len(self.videoid_to_vectoridx) != self.db.ntotal:
            raise ValueError(
                len(self.videoid_to_vectoridx),
                self.db.ntotal
            )

        if not self.make_direct_maps_done:
            self.make_direct_maps()
        if self.vectoridx_to_videoid is None:
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)

        # MultilingualFaissDataset uses the following; not sure the reason.
        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
        queried_dist, index = self.db.search(query_hidden_states, 1)
        # Bug fix: flatten the (n, 1) k=1 results. The original kept each
        # `sample` as a length-1 ndarray, which is unhashable and raised
        # TypeError when used as a dict key below.
        queried_dist, index = queried_dist[:, 0], index[:, 0]
        outputs = []
        for sample_idx, sample in enumerate(index):
            # and queried_dist[sample_idx] < thres \
            if sample >= 0 \
                    and queried_dist[sample_idx] < orig_dist[sample_idx]:
                outputs.append(self.vectoridx_to_videoid[sample])
            else:
                outputs.append(None)
        return outputs

    def search_by_video_ids(
        self,
        video_ids,
        retri_factor=8
    ):
        """Retrieve up to `retri_factor` neighbors per known video id.

        Same contract as the base class, with a default `retri_factor`.
        """
        if len(self.videoid_to_vectoridx) != self.db.ntotal:
            raise ValueError(
                len(self.videoid_to_vectoridx),
                self.db.ntotal
            )

        if not self.make_direct_maps_done:
            self.make_direct_maps()
        if self.vectoridx_to_videoid is None:
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)

        query_hidden_states = []
        vector_ids = []
        for video_id in video_ids:
            vector_id = self.videoid_to_vectoridx[video_id]
            vector_ids.append(vector_id)
            query_hidden_state = self.db.reconstruct(vector_id)
            query_hidden_states.append(query_hidden_state)
        query_hidden_states = np.stack(query_hidden_states)

        # MultilingualFaissDataset uses the following; not sure the reason.
        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
        _, index = self.db.search(query_hidden_states, retri_factor)
        outputs = []
        for sample_idx, sample in enumerate(index):
            # the first video_id is always the video itself.
            cands = [video_ids[sample_idx]]
            for vector_idx in sample:
                if vector_idx >= 0 \
                        and vector_ids[sample_idx] != vector_idx:
                    cands.append(
                        self.vectoridx_to_videoid[vector_idx]
                    )
            outputs.append(cands)
        return outputs


class MMVectorRetriever(VectorRetrieverDM):
    """
    multimodal vector retriver:
    text retrieve video or video retrieve text.

    Holds two parallel FAISS indices ("video" and "text") sharing one
    video-id-to-vector-index mapping.
    """

    def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
        # Run the parent constructor twice to build two identically
        # configured indices, one per modality; the second call also
        # resets the (shared) bookkeeping attributes.
        super().__init__(
            hidden_size, cent, db_type, examples_per_cent_to_train)
        video_db = self.db
        super().__init__(
            hidden_size, cent, db_type, examples_per_cent_to_train)
        text_db = self.db
        self.db = {"video": video_db, "text": text_db}
        self.video_to_videoid = defaultdict(list)

    def __len__(self):
        assert self.db["video"].ntotal == self.db["text"].ntotal
        return self.db["video"].ntotal

    def make_direct_maps(self):
        """Build direct maps on both sub-indices."""
        faiss.downcast_index(self.db["video"]).make_direct_map()
        faiss.downcast_index(self.db["text"]).make_direct_map()
        # Bug fix: record completion (as VectorRetrieverDM does) so the
        # direct maps are not rebuilt on every search call.
        self.make_direct_maps_done = True

    def save(self, out_dir):
        """Write both FAISS indices and the id mapping under `out_dir`."""
        faiss.write_index(
            self.db["video"],
            os.path.join(out_dir, "video_faiss_idx")
        )
        faiss.write_index(
            self.db["text"],
            os.path.join(out_dir, "text_faiss_idx")
        )
        with open(
                os.path.join(
                    out_dir, "videoid_to_vectoridx.pkl"),
                "wb") as fw:
            pickle.dump(
                self.videoid_to_vectoridx, fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )

    def load(self, out_dir):
        """Load both FAISS indices and the id mapping from `out_dir`."""
        fn = os.path.join(out_dir, "video_faiss_idx")
        video_db = faiss.read_index(fn)
        fn = os.path.join(out_dir, "text_faiss_idx")
        text_db = faiss.read_index(fn)
        self.db = {"video": video_db, "text": text_db}
        with open(
                os.path.join(out_dir, "videoid_to_vectoridx.pkl"),
                "rb") as fr:
            self.videoid_to_vectoridx = pickle.load(fr)
            self.video_to_videoid = defaultdict(list)

    def add(self, hidden_states, video_ids, last=False):
        """hidden_states is a pair `(video, text)`.

        Expects shape (batch, 2, hidden); axis 1 holds the two
        modalities. `last` is accepted for signature compatibility with
        the base class and is unused here. Note the training trigger
        counts `batch_size * len(train_cache)`, i.e. it ignores
        deduplication — presumably an acceptable approximation; verify
        against callers.
        """
        assert len(hidden_states) == len(video_ids), "{}, {}".format(
            str(len(hidden_states)), str(len(video_ids)))
        assert len(hidden_states.shape) == 3
        assert len(self.video_to_videoid) == 0

        valid_idx = []
        for idx, video_id in enumerate(video_ids):
            if video_id not in self.videoid_to_vectoridx:
                valid_idx.append(idx)
                self.videoid_to_vectoridx[video_id] = \
                    len(self.videoid_to_vectoridx)

        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states[valid_idx]
        # (batch, 2, hidden) -> (2, batch, hidden): modality-major layout.
        hidden_states = np.transpose(hidden_states, (1, 0, 2)).copy()
        if not self.db["video"].is_trained:
            self.train_cache.append(hidden_states)
            train_len = batch_size * len(self.train_cache)
            if train_len < self.train_thres:
                return

            hidden_states = np.concatenate(self.train_cache, axis=1)
            del self.train_cache
            self.db["video"].train(hidden_states[0, :self.train_thres])
            self.db["text"].train(hidden_states[1, :self.train_thres])
        self.db["video"].add(hidden_states[0])
        self.db["text"].add(hidden_states[1])

    def get_clips_by_video_id(self, video_id):
        """Return all (video_id, video_clip, text_clip) keys for `video_id`.

        Keys of `videoid_to_vectoridx` are assumed to be
        (video_id, video_clip, text_clip) triples — TODO confirm.
        Bug fix: the original loop variable shadowed the `video_id`
        argument, so the final lookup used the last iterated key's id
        rather than the requested one.
        """
        if not self.video_to_videoid:
            for vid, video_clip, text_clip in self.videoid_to_vectoridx:
                self.video_to_videoid[vid].append(
                    (vid, video_clip, text_clip))
        return self.video_to_videoid[video_id]

    def search(
        self,
        video_ids,
        target_modality,
        retri_factor=8
    ):
        """Cross-modal retrieval: query with the opposite modality.

        For each known video id, reconstruct its vector from the source
        modality index and retrieve up to `retri_factor` candidates from
        the `target_modality` ("video" or "text") index.
        """
        if len(self.videoid_to_vectoridx) != len(self):
            raise ValueError(
                len(self.videoid_to_vectoridx),
                len(self)
            )

        if not self.make_direct_maps_done:
            self.make_direct_maps()
        if self.vectoridx_to_videoid is None:
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)

        src_modality = "text" if target_modality == "video" else "video"

        query_hidden_states = []
        vector_ids = []
        for video_id in video_ids:
            vector_id = self.videoid_to_vectoridx[video_id]
            vector_ids.append(vector_id)
            query_hidden_state = self.db[src_modality].reconstruct(vector_id)
            query_hidden_states.append(query_hidden_state)
        query_hidden_states = np.stack(query_hidden_states)

        # MultilingualFaissDataset uses the following; not sure the reason.
        # faiss.ParameterSpace().set_index_parameter(self.db, "nprobe", 10)
        _, index = self.db[target_modality].search(
            query_hidden_states, retri_factor)
        outputs = []
        for sample_idx, sample in enumerate(index):
            cands = []
            for vector_idx in sample:
                if vector_idx >= 0:
                    cands.append(
                        self.vectoridx_to_videoid[vector_idx]
                    )
            outputs.append(cands)
        return outputs