|
|
|
|
|
|
|
|
|
import os |
|
import numpy as np |
|
import pickle |
|
import time |
|
|
|
try: |
|
import faiss |
|
except ImportError: |
|
pass |
|
|
|
from collections import defaultdict |
|
|
|
from ..utils import get_local_rank, print_on_rank0 |
|
|
|
|
|
class VectorRetriever(object):
    """FAISS-backed nearest-neighbor retriever for How2 video clips.

    Each clip is identified by a hashable `video_id` and mapped to one
    row of a FAISS index. Incoming vectors are cached until enough
    examples have been seen to train the index; afterwards additions go
    straight into the index.

    Reference usage of FAISS:
    https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
    """

    def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
        """
        Args:
            hidden_size: dimensionality of the stored vectors.
            cent: number of IVF centroids (coarse clusters).
            db_type: "flatl2" (IVF + exact L2) or "pq"
                (IVF + HNSW coarse quantizer + 32-byte product quantizer).
            examples_per_cent_to_train: training examples per centroid;
                index training is triggered once
                `cent * examples_per_cent_to_train` vectors are cached.

        Raises:
            ValueError: if `db_type` is not one of the supported kinds.
        """
        if db_type == "flatl2":
            quantizer = faiss.IndexFlatL2(hidden_size)  # coarse quantizer
            self.db = faiss.IndexIVFFlat(
                quantizer, hidden_size, cent, faiss.METRIC_L2)
        elif db_type == "pq":
            self.db = faiss.index_factory(
                hidden_size, f"IVF{cent}_HNSW32,PQ32"
            )
        else:
            raise ValueError("unknown type of db", db_type)
        # number of cached examples required before training the index.
        self.train_thres = cent * examples_per_cent_to_train
        self.train_cache = []
        self.train_len = 0
        self.videoid_to_vectoridx = {}
        # inverse mapping; built lazily on first search.
        self.vectoridx_to_videoid = None
        self.make_direct_maps_done = False

    def make_direct_maps(self):
        """Enable `reconstruct` on the IVF index by building a direct map."""
        faiss.downcast_index(self.db).make_direct_map()
        # BUGFIX: record completion so `search_by_video_ids` does not
        # rebuild the direct map on every call (the flag was checked but
        # never set here; the DM subclass shows the intended pattern).
        self.make_direct_maps_done = True

    def __len__(self):
        # number of vectors stored in the FAISS index.
        return self.db.ntotal

    def save(self, out_dir):
        """Persist the FAISS index and the id mapping under `out_dir`."""
        faiss.write_index(
            self.db,
            os.path.join(out_dir, "faiss_idx")
        )
        with open(
                os.path.join(
                    out_dir, "videoid_to_vectoridx.pkl"),
                "wb") as fw:
            pickle.dump(
                self.videoid_to_vectoridx, fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )

    def load(self, out_dir):
        """Restore an index and id mapping previously written by `save`."""
        fn = os.path.join(out_dir, "faiss_idx")
        self.db = faiss.read_index(fn)
        with open(
                os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
            self.videoid_to_vectoridx = pickle.load(fr)

    def add(self, hidden_states, video_ids, last=False):
        """Add vectors keyed by `video_ids`, skipping ids already stored.

        Vectors are cached until `train_thres` examples have been seen,
        at which point the index is trained and the cache flushed.

        Args:
            hidden_states: float32 array of shape (batch, hidden_size).
            video_ids: sequence of hashable ids, parallel to the rows.
            last: unused; kept for caller compatibility.
        """
        assert len(hidden_states) == len(video_ids), "{}, {}".format(
            str(len(hidden_states)), str(len(video_ids)))
        assert len(hidden_states.shape) == 2
        assert hidden_states.dtype == np.float32

        # keep only rows whose video_id has not been registered yet.
        valid_idx = []
        for idx, video_id in enumerate(video_ids):
            if video_id not in self.videoid_to_vectoridx:
                valid_idx.append(idx)
                self.videoid_to_vectoridx[video_id] = \
                    len(self.videoid_to_vectoridx)

        hidden_states = hidden_states[valid_idx]
        if not self.db.is_trained:
            self.train_cache.append(hidden_states)
            self.train_len += hidden_states.shape[0]
            if self.train_len < self.train_thres:
                return
            self.finalize_training()
        else:
            self.db.add(hidden_states)

    def finalize_training(self):
        """Train the index on the cached vectors, then add all of them."""
        hidden_states = np.concatenate(self.train_cache, axis=0)
        del self.train_cache
        local_rank = get_local_rank()
        if local_rank == 0:
            start = time.time()
            print("training db on", self.train_thres, "/", self.train_len)
        self.db.train(hidden_states[:self.train_thres])
        if local_rank == 0:
            print("training db for", time.time() - start)
        self.db.add(hidden_states)

    def search(
        self,
        query_hidden_states,
        orig_dist,
    ):
        """1-NN search over the index.

        Returns an int32 array with one mapped video id per query; rows
        whose retrieved distance is <= the corresponding `orig_dist`
        entry are overwritten with -1.

        Raises:
            ValueError: if the id mapping and the index disagree in size.
        """
        if len(self.videoid_to_vectoridx) != self.db.ntotal:
            raise ValueError(
                "cannot search: size mismatch in-between index and db",
                len(self.videoid_to_vectoridx),
                self.db.ntotal
            )

        if self.vectoridx_to_videoid is None:
            # lazily build the inverse mapping.
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)

        queried_dist, index = self.db.search(query_hidden_states, 1)
        queried_dist, index = queried_dist[:, 0], index[:, 0]

        # video ids appear to be 3-int tuples here, hence the
        # (-1, -1, -1) placeholder for FAISS's "no result" (-1) index
        # -- TODO confirm against callers.
        outputs = np.array(
            [self.vectoridx_to_videoid[_index]
             if _index != -1 else (-1, -1, -1) for _index in index],
            dtype=np.int32)
        outputs[queried_dist <= orig_dist] = -1
        return outputs

    def search_by_video_ids(
        self,
        video_ids,
        retri_factor
    ):
        """Retrieve neighbors for ids already stored in the index.

        For each id, reconstruct its vector and return a candidate list
        starting with the query id itself, followed by up to
        `retri_factor` neighbor ids (self-matches and FAISS -1 slots
        are skipped).

        Raises:
            ValueError: if the id mapping and the index disagree in size.
        """
        if len(self.videoid_to_vectoridx) != self.db.ntotal:
            raise ValueError(
                len(self.videoid_to_vectoridx),
                self.db.ntotal
            )

        if not self.make_direct_maps_done:
            # reconstruct() below needs the direct map.
            self.make_direct_maps()

        if self.vectoridx_to_videoid is None:
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)

        query_hidden_states = []
        vector_ids = []
        for video_id in video_ids:
            vector_id = self.videoid_to_vectoridx[video_id]
            vector_ids.append(vector_id)
            query_hidden_state = self.db.reconstruct(vector_id)
            query_hidden_states.append(query_hidden_state)
        query_hidden_states = np.stack(query_hidden_states)

        _, index = self.db.search(query_hidden_states, retri_factor)
        outputs = []
        for sample_idx, sample in enumerate(index):
            cands = [video_ids[sample_idx]]
            for vector_idx in sample:
                if vector_idx >= 0 \
                        and vector_ids[sample_idx] != vector_idx:
                    cands.append(
                        self.vectoridx_to_videoid[vector_idx]
                    )
            outputs.append(cands)
        return outputs
|
|
|
|
|
class VectorRetrieverDM(VectorRetriever):
    """VectorRetriever with a FAISS direct map.

    The direct map allows stored vectors to be reconstructed from the
    index, and the `make_direct_maps_done` flag ensures it is built only
    once.

    How2 Video Retriever.
    Reference usage of FAISS:
    https://github.com/fairinternal/fairseq-py/blob/paraphrase_pretraining/fairseq/data/multilingual_faiss_dataset.py
    """

    def __init__(
        self,
        hidden_size,
        cent,
        db_type,
        examples_per_cent_to_train
    ):
        super().__init__(
            hidden_size, cent, db_type, examples_per_cent_to_train)
        self.make_direct_maps_done = False

    def make_direct_maps(self):
        """Build the direct map once and remember that it is done."""
        faiss.downcast_index(self.db).make_direct_map()
        self.make_direct_maps_done = True

    def search(
        self,
        query_hidden_states,
        orig_dist,
    ):
        """1-NN search returning, per query, the matched video id when
        the hit is strictly closer than `orig_dist[i]`, else None.

        Raises:
            ValueError: if the id mapping and the index disagree in size.
        """
        if len(self.videoid_to_vectoridx) != self.db.ntotal:
            raise ValueError(
                len(self.videoid_to_vectoridx),
                self.db.ntotal
            )

        if not self.make_direct_maps_done:
            self.make_direct_maps()
        if self.vectoridx_to_videoid is None:
            # lazily build the inverse mapping.
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)

        queried_dist, index = self.db.search(query_hidden_states, 1)
        # BUGFIX: flatten the (n, 1) search results to scalars. The
        # original iterated rows of `index`, so `sample` was a length-1
        # ndarray; using it as a dict key raised TypeError (unhashable).
        queried_dist, index = queried_dist[:, 0], index[:, 0]
        outputs = []
        for sample_idx, sample in enumerate(index):
            # -1 means FAISS found no neighbor for this query.
            if sample >= 0 \
                    and queried_dist[sample_idx] < orig_dist[sample_idx]:
                outputs.append(self.vectoridx_to_videoid[sample])
            else:
                outputs.append(None)
        return outputs

    def search_by_video_ids(
        self,
        video_ids,
        retri_factor=8
    ):
        """Same as the base class, but with a default `retri_factor`.

        (The previous line-for-line copy of the base method is replaced
        by delegation; behavior is unchanged, including direct-map
        handling, which dispatches to this class's `make_direct_maps`.)
        """
        return super().search_by_video_ids(video_ids, retri_factor)
|
|
|
|
|
class MMVectorRetriever(VectorRetrieverDM):
    """Multimodal vector retriever:
    text retrieves video or video retrieves text.

    Maintains two parallel FAISS indexes (one per modality) that share a
    single video-id -> vector-index mapping, so row i of the "video"
    index and row i of the "text" index describe the same clip.
    """

    def __init__(self, hidden_size, cent, db_type, examples_per_cent_to_train):
        # build the video index via the parent constructor ...
        super().__init__(
            hidden_size, cent, db_type, examples_per_cent_to_train)
        video_db = self.db
        # ... then rebuild an identically configured index for text.
        super().__init__(
            hidden_size, cent, db_type, examples_per_cent_to_train)
        text_db = self.db
        self.db = {"video": video_db, "text": text_db}
        # lazy grouping of clip keys by video id; filled on first use in
        # get_clips_by_video_id.
        self.video_to_videoid = defaultdict(list)

    def __len__(self):
        # both indexes must stay in lock-step.
        assert self.db["video"].ntotal == self.db["text"].ntotal
        return self.db["video"].ntotal

    def make_direct_maps(self):
        """Enable `reconstruct` on both modality indexes."""
        faiss.downcast_index(self.db["video"]).make_direct_map()
        faiss.downcast_index(self.db["text"]).make_direct_map()
        # BUGFIX: record completion so `search` does not rebuild both
        # direct maps on every call.
        self.make_direct_maps_done = True

    def save(self, out_dir):
        """Persist both FAISS indexes and the shared id mapping."""
        faiss.write_index(
            self.db["video"],
            os.path.join(out_dir, "video_faiss_idx")
        )
        faiss.write_index(
            self.db["text"],
            os.path.join(out_dir, "text_faiss_idx")
        )

        with open(
                os.path.join(
                    out_dir, "videoid_to_vectoridx.pkl"),
                "wb") as fw:
            pickle.dump(
                self.videoid_to_vectoridx, fw,
                protocol=pickle.HIGHEST_PROTOCOL
            )

    def load(self, out_dir):
        """Restore both indexes and the shared id mapping."""
        fn = os.path.join(out_dir, "video_faiss_idx")
        video_db = faiss.read_index(fn)
        fn = os.path.join(out_dir, "text_faiss_idx")
        text_db = faiss.read_index(fn)
        self.db = {"video": video_db, "text": text_db}
        with open(
                os.path.join(out_dir, "videoid_to_vectoridx.pkl"), "rb") as fr:
            self.videoid_to_vectoridx = pickle.load(fr)
        # invalidate the per-video grouping; it is rebuilt on demand.
        self.video_to_videoid = defaultdict(list)

    def add(self, hidden_states, video_ids):
        """hidden_states is a pair `(video, text)` per sample:
        an array of shape (batch, 2, hidden_size).

        Unseen ids get a new shared row index; both modalities are
        cached until the indexes are trained, then added to both.
        """
        assert len(hidden_states) == len(video_ids), "{}, {}".format(
            str(len(hidden_states)), str(len(video_ids)))
        assert len(hidden_states.shape) == 3
        # adding after get_clips_by_video_id built its cache would make
        # that cache stale.
        assert len(self.video_to_videoid) == 0

        # keep only rows whose video_id has not been registered yet.
        valid_idx = []
        for idx, video_id in enumerate(video_ids):
            if video_id not in self.videoid_to_vectoridx:
                valid_idx.append(idx)
                self.videoid_to_vectoridx[video_id] = \
                    len(self.videoid_to_vectoridx)

        batch_size = hidden_states.shape[0]
        hidden_states = hidden_states[valid_idx]

        # (batch, 2, hidden) -> (2, batch, hidden): axis 0 is modality.
        hidden_states = np.transpose(hidden_states, (1, 0, 2)).copy()
        if not self.db["video"].is_trained:
            self.train_cache.append(hidden_states)
            # NOTE: counts pre-dedup batch sizes, so this may slightly
            # overestimate the number of cached (unique) examples.
            train_len = batch_size * len(self.train_cache)
            if train_len < self.train_thres:
                return

            hidden_states = np.concatenate(self.train_cache, axis=1)
            del self.train_cache
            self.db["video"].train(hidden_states[0, :self.train_thres])
            self.db["text"].train(hidden_states[1, :self.train_thres])
            self.db["video"].add(hidden_states[0])
            self.db["text"].add(hidden_states[1])
        else:
            # BUGFIX: the original had no else-branch here, so vectors
            # added after the indexes were trained were silently
            # dropped; mirror VectorRetriever.add.
            self.db["video"].add(hidden_states[0])
            self.db["text"].add(hidden_states[1])

    def get_clips_by_video_id(self, video_id):
        """Return all clip keys that belong to `video_id`.

        Keys of `videoid_to_vectoridx` are assumed to be
        (video_id, video_clip, text_clip) triples -- TODO confirm with
        callers.
        """
        if not self.video_to_videoid:
            # BUGFIX: the loop variable used to shadow the `video_id`
            # argument, so the final lookup returned the group of the
            # last-iterated id instead of the requested one.
            for vid, video_clip, text_clip in self.videoid_to_vectoridx:
                self.video_to_videoid[vid].append(
                    (vid, video_clip, text_clip))
        return self.video_to_videoid[video_id]

    def search(
        self,
        video_ids,
        target_modality,
        retri_factor=8
    ):
        """Cross-modal retrieval for ids already stored in the indexes.

        For each id in `video_ids`, reconstruct its vector from the
        *source* modality index (the opposite of `target_modality`) and
        return up to `retri_factor` neighbor ids from the target index
        (FAISS -1 slots are skipped).

        NOTE: intentionally replaces the parent `search` with an
        id-based, cross-modal signature.

        Raises:
            ValueError: if the id mapping and the indexes disagree in size.
        """
        if len(self.videoid_to_vectoridx) != len(self):
            raise ValueError(
                len(self.videoid_to_vectoridx),
                len(self)
            )

        if not self.make_direct_maps_done:
            # reconstruct() below needs the direct maps.
            self.make_direct_maps()
        if self.vectoridx_to_videoid is None:
            self.vectoridx_to_videoid = {
                self.videoid_to_vectoridx[videoid]: videoid
                for videoid in self.videoid_to_vectoridx
            }
            assert len(self.vectoridx_to_videoid) \
                == len(self.videoid_to_vectoridx)

        src_modality = "text" if target_modality == "video" else "video"

        query_hidden_states = []
        vector_ids = []
        for video_id in video_ids:
            vector_id = self.videoid_to_vectoridx[video_id]
            vector_ids.append(vector_id)
            query_hidden_state = self.db[src_modality].reconstruct(vector_id)
            query_hidden_states.append(query_hidden_state)
        query_hidden_states = np.stack(query_hidden_states)

        _, index = self.db[target_modality].search(
            query_hidden_states, retri_factor)
        outputs = []
        for sample_idx, sample in enumerate(index):
            cands = []
            for vector_idx in sample:
                if vector_idx >= 0:
                    cands.append(
                        self.vectoridx_to_videoid[vector_idx]
                    )
            outputs.append(cands)
        return outputs
|
|