import os
import tqdm
import time
import ujson
import torch
import random

try:
    import faiss
except ImportError as e:
    print("WARNING: faiss must be imported for indexing")

import numpy as np
import torch.multiprocessing as mp
from colbert.infra.config.config import ColBERTConfig

import colbert.utils.distributed as distributed

from colbert.infra.run import Run
from colbert.infra.launcher import print_memory_stats
from colbert.modeling.checkpoint import Checkpoint
from colbert.data.collection import Collection

from colbert.indexing.collection_encoder import CollectionEncoder
from colbert.indexing.index_saver import IndexSaver
from colbert.indexing.utils import optimize_ivf
from colbert.utils.utils import flatten, print_message

from colbert.indexing.codecs.residual import ResidualCodec


def encode(config, collection, shared_lists, shared_queues):
    encoder = CollectionIndexer(config=config, collection=collection)
    encoder.run(shared_lists)
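
# For context: `encode` is the per-process entry point for indexing. In ColBERT it
# is spawned across ranks by the infra launcher, roughly as in this sketch (a
# minimal illustration, not part of this module; verify against
# colbert.infra.launcher before relying on it):
#
#   from colbert.infra.launcher import Launcher
#   import torch.multiprocessing as mp
#
#   manager = mp.Manager()
#   shared_lists = [manager.list() for _ in range(config.nranks)]
#   shared_queues = [manager.Queue() for _ in range(config.nranks)]
#
#   launcher = Launcher(encode)
#   launcher.launch(config, collection, shared_lists, shared_queues)
#
# Note that `shared_queues` is accepted for interface compatibility but unused here.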


class CollectionIndexer():
    '''
    Given a collection and a config, encodes the collection into a ColBERT index
    and stores it on disk in chunks.
    '''

    def __init__(self, config: ColBERTConfig, collection):
        self.config = config
        self.rank, self.nranks = self.config.rank, self.config.nranks
        self.use_gpu = self.config.total_visible_gpus > 0

        if self.config.rank == 0:
            self.config.help()

        self.collection = Collection.cast(collection)
        self.checkpoint = Checkpoint(self.config.checkpoint, colbert_config=self.config)
        if self.use_gpu:
            self.checkpoint = self.checkpoint.cuda()

        self.encoder = CollectionEncoder(config, self.checkpoint)
        self.saver = IndexSaver(config)

        print_memory_stats(f'RANK:{self.rank}')

    def run(self, shared_lists):
        '''
        Runs the full indexing pipeline: setup() plans the index, train() fits the
        k-means centroids and the residual codec, index() encodes and saves every
        passage, and finalize() builds the IVF and metadata. All ranks synchronize
        at a barrier after each stage.
        '''
        with torch.inference_mode():
            self.setup()
            distributed.barrier(self.rank)
            print_memory_stats(f'RANK:{self.rank}')

            if not self.config.resume or not self.saver.try_load_codec():
                self.train(shared_lists)
            distributed.barrier(self.rank)
            print_memory_stats(f'RANK:{self.rank}')

            self.index()
            distributed.barrier(self.rank)
            print_memory_stats(f'RANK:{self.rank}')

            self.finalize()
            distributed.barrier(self.rank)
            print_memory_stats(f'RANK:{self.rank}')

    def setup(self):
        '''
        Computes and saves plan.json: the number of chunks, the number of k-means
        partitions, and estimates of the total embedding count and average doclen.
        '''
        if self.config.resume:
            if self._try_load_plan():
                Run().print_main(f"#> Loaded plan from {self.plan_path}:")
                Run().print_main(f"#> num_chunks = {self.num_chunks}")
                Run().print_main(f"#> num_partitions = {self.num_partitions}")
                Run().print_main(f"#> num_embeddings_est = {self.num_embeddings_est}")
                Run().print_main(f"#> avg_doclen_est = {self.avg_doclen_est}")
                return

        self.num_chunks = int(np.ceil(len(self.collection) / self.collection.get_chunksize()))

        sampled_pids = self._sample_pids()
        avg_doclen_est = self._sample_embeddings(sampled_pids)

        # Select the number of partitions
        num_passages = len(self.collection)
        self.num_embeddings_est = num_passages * avg_doclen_est
        self.num_partitions = int(2 ** np.floor(np.log2(16 * np.sqrt(self.num_embeddings_est))))
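        # Worked example (illustrative numbers): with ~1M passages at an estimated
        # ~80 tokens each, num_embeddings_est ~ 8e7; then 16 * sqrt(8e7) ~ 143,108
        # and num_partitions = 2 ** floor(log2(143108)) = 2 ** 17 = 131,072.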

        Run().print_main(f'Creating {self.num_partitions:,} partitions.')
        Run().print_main(f'*Estimated* {int(self.num_embeddings_est):,} embeddings.')

        self._save_plan()

    def _sample_pids(self):
        num_passages = len(self.collection)

        # Simple alternative: < 100k: 100%, < 1M: 15%, < 10M: 7%, < 100M: 3%, > 100M: 1%
        # Keep in mind that, say, 15% still means at least 100k.
        # So the formula is max(100% * min(total, 100k), 15% * min(total, 1M), ...)
        # Then we subsample the vectors to 100 * num_partitions

        typical_doclen = 120  # let's keep sampling independent of the actual doc_maxlen

        sampled_pids = 16 * np.sqrt(typical_doclen * num_passages)
        # sampled_pids = int(2 ** np.floor(np.log2(1 + sampled_pids)))
        sampled_pids = min(1 + int(sampled_pids), num_passages)
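        # Worked example (illustrative): for num_passages = 1,000,000,
        # 16 * sqrt(120 * 1e6) ~ 175,271, i.e., roughly 17.5% of the collection.
        # The min() clamp means collections with up to 256 * 120 = 30,720 passages
        # are sampled in full.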

        sampled_pids = random.sample(range(num_passages), sampled_pids)
        Run().print_main(f"# of sampled PIDs = {len(sampled_pids)} \t sampled_pids[:3] = {sampled_pids[:3]}")

        return set(sampled_pids)

    def _sample_embeddings(self, sampled_pids):
        local_pids = self.collection.enumerate(rank=self.rank)
        local_sample = [passage for pid, passage in local_pids if pid in sampled_pids]

        local_sample_embs, doclens = self.encoder.encode_passages(local_sample)

        if torch.cuda.is_available():
            self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cuda()
            torch.distributed.all_reduce(self.num_sample_embs)

            avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
            avg_doclen_est = torch.tensor([avg_doclen_est]).cuda()
            torch.distributed.all_reduce(avg_doclen_est)

            nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).cuda()
            torch.distributed.all_reduce(nonzero_ranks)
        else:
            if torch.distributed.is_initialized():
                self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cpu()
                torch.distributed.all_reduce(self.num_sample_embs)

                avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
                avg_doclen_est = torch.tensor([avg_doclen_est]).cpu()
                torch.distributed.all_reduce(avg_doclen_est)

                nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).cpu()
                torch.distributed.all_reduce(nonzero_ranks)
            else:
                self.num_sample_embs = torch.tensor([local_sample_embs.size(0)]).cpu()

                avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
                avg_doclen_est = torch.tensor([avg_doclen_est]).cpu()

                nonzero_ranks = torch.tensor([float(len(local_sample) > 0)]).cpu()

        avg_doclen_est = avg_doclen_est.item() / nonzero_ranks.item()
        self.avg_doclen_est = avg_doclen_est

        Run().print(f'avg_doclen_est = {avg_doclen_est} \t len(local_sample) = {len(local_sample):,}')

        torch.save(local_sample_embs.half(), os.path.join(self.config.index_path_, f'sample.{self.rank}.pt'))

        return avg_doclen_est

    def _try_load_plan(self):
        config = self.config
        self.plan_path = os.path.join(config.index_path_, 'plan.json')
        if os.path.exists(self.plan_path):
            with open(self.plan_path, 'r') as f:
                try:
                    plan = ujson.load(f)
                except Exception as e:
                    return False
                if not ('num_chunks' in plan and
                        'num_partitions' in plan and
                        'num_embeddings_est' in plan and
                        'avg_doclen_est' in plan):
                    return False

                # TODO: Verify config matches
                self.num_chunks = plan['num_chunks']
                self.num_partitions = plan['num_partitions']
                self.num_embeddings_est = plan['num_embeddings_est']
                self.avg_doclen_est = plan['avg_doclen_est']

            return True
        else:
            return False

    def _save_plan(self):
        if self.rank < 1:
            config = self.config
            self.plan_path = os.path.join(config.index_path_, 'plan.json')
            Run().print("#> Saving the indexing plan to", self.plan_path, "..")

            with open(self.plan_path, 'w') as f:
                d = {'config': config.export()}
                d['num_chunks'] = self.num_chunks
                d['num_partitions'] = self.num_partitions
                d['num_embeddings_est'] = self.num_embeddings_est
                d['avg_doclen_est'] = self.avg_doclen_est

                f.write(ujson.dumps(d, indent=4) + '\n')
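
    # For reference, plan.json then looks roughly like this (values illustrative):
    #
    #   {
    #       "config": { ... },
    #       "num_chunks": 40,
    #       "num_partitions": 131072,
    #       "num_embeddings_est": 80000000.0,
    #       "avg_doclen_est": 80.0
    #   }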

    def train(self, shared_lists):
        '''
        Trains the k-means centroids on the sampled embeddings (rank 0 only) and
        derives the residual-quantization buckets from a heldout split.
        '''
        if self.rank > 0:
            return

        sample, heldout = self._concatenate_and_split_sample()

        centroids = self._train_kmeans(sample, shared_lists)

        print_memory_stats(f'RANK:{self.rank}')
        del sample

        bucket_cutoffs, bucket_weights, avg_residual = self._compute_avg_residual(centroids, heldout)

        print_message(f'avg_residual = {avg_residual}')

        codec = ResidualCodec(config=self.config, centroids=centroids, avg_residual=avg_residual,
                              bucket_cutoffs=bucket_cutoffs, bucket_weights=bucket_weights)
        self.saver.save_codec(codec)

    def _concatenate_and_split_sample(self):
        print_memory_stats(f'***1*** \t RANK:{self.rank}')

        # TODO: Allocate a float16 array. Load the samples from disk, copy to array.
        sample = torch.empty(self.num_sample_embs, self.config.dim, dtype=torch.float16)

        offset = 0
        for r in range(self.nranks):
            sub_sample_path = os.path.join(self.config.index_path_, f'sample.{r}.pt')
            sub_sample = torch.load(sub_sample_path)
            os.remove(sub_sample_path)

            endpos = offset + sub_sample.size(0)
            sample[offset:endpos] = sub_sample
            offset = endpos

        assert endpos == sample.size(0), (endpos, sample.size())

        print_memory_stats(f'***2*** \t RANK:{self.rank}')

        # Shuffle and split out a 5% "heldout" sub-sample [up to 50k elements]
        sample = sample[torch.randperm(sample.size(0))]

        print_memory_stats(f'***3*** \t RANK:{self.rank}')

        heldout_fraction = 0.05
        heldout_size = int(min(heldout_fraction * sample.size(0), 50_000))
        sample, sample_heldout = sample.split([sample.size(0) - heldout_size, heldout_size], dim=0)

        print_memory_stats(f'***4*** \t RANK:{self.rank}')

        return sample, sample_heldout

    def _train_kmeans(self, sample, shared_lists):
        if self.use_gpu:
            torch.cuda.empty_cache()

        do_fork_for_faiss = False  # set to True to free faiss GPU-0 memory at the cost of one more copy of `sample`.

        args_ = [self.config.dim, self.num_partitions, self.config.kmeans_niters]

        if do_fork_for_faiss:
            # For this to work reliably, write the sample to disk. Pickle may not handle >4GB of data.
            # Delete the sample file after work is done.
            shared_lists[0][0] = sample
            return_value_queue = mp.Queue()

            args_ = args_ + [shared_lists, return_value_queue]
            proc = mp.Process(target=compute_faiss_kmeans, args=args_)

            proc.start()
            centroids = return_value_queue.get()
            proc.join()
        else:
            args_ = args_ + [[[sample]]]
            centroids = compute_faiss_kmeans(*args_)

        centroids = torch.nn.functional.normalize(centroids, dim=-1)
        if self.use_gpu:
            centroids = centroids.half()
        else:
            centroids = centroids.float()

        return centroids

    def _compute_avg_residual(self, centroids, heldout):
        compressor = ResidualCodec(config=self.config, centroids=centroids, avg_residual=None)

        heldout_reconstruct = compressor.compress_into_codes(heldout, out_device='cuda' if self.use_gpu else 'cpu')
        heldout_reconstruct = compressor.lookup_centroids(heldout_reconstruct, out_device='cuda' if self.use_gpu else 'cpu')
        if self.use_gpu:
            heldout_avg_residual = heldout.cuda() - heldout_reconstruct
        else:
            heldout_avg_residual = heldout - heldout_reconstruct

        avg_residual = torch.abs(heldout_avg_residual).mean(dim=0).cpu()
        print([round(x, 3) for x in avg_residual.squeeze().tolist()])

        num_options = 2 ** self.config.nbits
        quantiles = torch.arange(0, num_options, device=heldout_avg_residual.device) * (1 / num_options)
        bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[1:], quantiles + (0.5 / num_options)
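        # Worked example: with nbits=2, num_options = 4 and quantiles = [0, 0.25, 0.5, 0.75],
        # so bucket_cutoffs_quantiles = [0.25, 0.5, 0.75] (the three boundaries between
        # four buckets) and bucket_weights_quantiles = [0.125, 0.375, 0.625, 0.875]
        # (each bucket's midpoint): every residual entry is later decompressed to the
        # weight of the bucket it falls into.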
        bucket_cutoffs = heldout_avg_residual.float().quantile(bucket_cutoffs_quantiles)
        bucket_weights = heldout_avg_residual.float().quantile(bucket_weights_quantiles)

        print_message(
            f"#> Got bucket_cutoffs_quantiles = {bucket_cutoffs_quantiles} and bucket_weights_quantiles = {bucket_weights_quantiles}")
        print_message(f"#> Got bucket_cutoffs = {bucket_cutoffs} and bucket_weights = {bucket_weights}")

        return bucket_cutoffs, bucket_weights, avg_residual.mean()

        # EVENTUALLY: Compare the above with the non-heldout sample. If too different, we can do better!
        # sample = sample[subsample_idxs]
        # sample_reconstruct = get_centroids_for(centroids, sample)
        # sample_avg_residual = (sample - sample_reconstruct).mean(dim=0)

    def index(self):
        '''
        Encodes and saves all passages in the collection, one chunk per batch;
        with resume enabled, chunks already present in the index are skipped.
        '''
        with self.saver.thread():
            batches = self.collection.enumerate_batches(rank=self.rank)
            for chunk_idx, offset, passages in tqdm.tqdm(batches, disable=self.rank > 0):
                if self.config.resume and self.saver.check_chunk_exists(chunk_idx):
                    Run().print_main(f"#> Found chunk {chunk_idx} in the index already, skipping encoding...")
                    continue
                embs, doclens = self.encoder.encode_passages(passages)
                if self.use_gpu:
                    assert embs.dtype == torch.float16
                else:
                    assert embs.dtype == torch.float32
                    embs = embs.half()

                Run().print_main(f"#> Saving chunk {chunk_idx}: \t {len(passages):,} passages "
                                 f"and {embs.size(0):,} embeddings. From #{offset:,} onward.")

                self.saver.save_chunk(chunk_idx, offset, embs, doclens)
                del embs, doclens

    def finalize(self):
        '''
        Verifies that all chunks were saved, aggregates per-chunk metadata into
        cumulative offsets, builds the IVF (centroid-to-embedding mapping), and
        writes the index-level metadata. Only rank 0 performs this work.
        '''
        if self.rank > 0:
            return

        self._check_all_files_are_saved()
        self._collect_embedding_id_offset()

        self._build_ivf()
        self._update_metadata()

    def _check_all_files_are_saved(self):
        Run().print_main("#> Checking all files were saved...")
        success = True
        for chunk_idx in range(self.num_chunks):
            if not self.saver.check_chunk_exists(chunk_idx):
                success = False
                Run().print_main(f"#> ERROR: Could not find chunk {chunk_idx}!")
                # TODO: Fail here?
        if success:
            Run().print_main("Found all files!")

    def _collect_embedding_id_offset(self):
        passage_offset = 0
        embedding_offset = 0

        self.embedding_offsets = []

        for chunk_idx in range(self.num_chunks):
            metadata_path = os.path.join(self.config.index_path_, f'{chunk_idx}.metadata.json')

            with open(metadata_path) as f:
                chunk_metadata = ujson.load(f)

                chunk_metadata['embedding_offset'] = embedding_offset
                self.embedding_offsets.append(embedding_offset)

                assert chunk_metadata['passage_offset'] == passage_offset, (chunk_idx, passage_offset, chunk_metadata)

                passage_offset += chunk_metadata['num_passages']
                embedding_offset += chunk_metadata['num_embeddings']

            with open(metadata_path, 'w') as f:
                f.write(ujson.dumps(chunk_metadata, indent=4) + '\n')

        self.num_embeddings = embedding_offset
        assert len(self.embedding_offsets) == self.num_chunks
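
    # For reference, after this pass each per-chunk {chunk_idx}.metadata.json carries
    # cumulative offsets, e.g. (values illustrative):
    #   {"passage_offset": 50000, "num_passages": 25000,
    #    "embedding_offset": 4000000, "num_embeddings": 2000000, ...}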

    def _build_ivf(self):
        # Maybe we should use several small IVFs? Every 250M embeddings, so that's every 1 GB.
        # It would save *memory* here and *disk space* regarding the int64.
        # But we'd have to decide how many IVFs to use during retrieval: many (loop) or one?
        # A loop seems nice if we can find a size that's large enough for speed yet small enough to fit on GPU!
        # Then it would help nicely for batching later: 1GB.

        Run().print_main("#> Building IVF...")

        codes = torch.empty(self.num_embeddings,)
        print_memory_stats(f'RANK:{self.rank}')

        Run().print_main("#> Loading codes...")
        for chunk_idx in tqdm.tqdm(range(self.num_chunks)):
            offset = self.embedding_offsets[chunk_idx]
            chunk_codes = ResidualCodec.Embeddings.load_codes(self.config.index_path_, chunk_idx)

            codes[offset:offset+chunk_codes.size(0)] = chunk_codes

        assert offset+chunk_codes.size(0) == codes.size(0), (offset, chunk_codes.size(0), codes.size())

        Run().print_main(f"Sorting codes...")
        print_memory_stats(f'RANK:{self.rank}')

        codes = codes.sort()
        ivf, values = codes.indices, codes.values

        print_memory_stats(f'RANK:{self.rank}')

        Run().print_main(f"Getting unique codes...")
        partitions, ivf_lengths = values.unique_consecutive(return_counts=True)
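        # Worked example: if the sorted code values are [0, 0, 1, 1, 1, 3], then
        # partitions = [0, 1, 3] and ivf_lengths = [2, 3, 1], while `ivf` holds the
        # embedding IDs (pre-sort positions) grouped by partition. An empty partition
        # (2 here) is dropped by unique_consecutive, which would trip the assert below.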

        # All partitions should be non-empty. (We can use torch.histc otherwise.)
        assert partitions.size(0) == self.num_partitions, (partitions.size(), self.num_partitions)

        print_memory_stats(f'RANK:{self.rank}')

        _, _ = optimize_ivf(ivf, ivf_lengths, self.config.index_path_)

    def _update_metadata(self):
        config = self.config
        self.metadata_path = os.path.join(config.index_path_, 'metadata.json')
        Run().print("#> Saving the indexing metadata to", self.metadata_path, "..")

        with open(self.metadata_path, 'w') as f:
            d = {'config': config.export()}
            d['num_chunks'] = self.num_chunks
            d['num_partitions'] = self.num_partitions
            d['num_embeddings'] = self.num_embeddings
            d['avg_doclen'] = self.num_embeddings / len(self.collection)

            f.write(ujson.dumps(d, indent=4) + '\n')


def compute_faiss_kmeans(dim, num_partitions, kmeans_niters, shared_lists, return_value_queue=None):
    use_gpu = torch.cuda.is_available()
    kmeans = faiss.Kmeans(dim, num_partitions, niter=kmeans_niters, gpu=use_gpu, verbose=True, seed=123)

    sample = shared_lists[0][0]
    sample = sample.float().numpy()

    kmeans.train(sample)

    centroids = torch.from_numpy(kmeans.centroids)

    print_memory_stats(f'RANK:0*')

    if return_value_queue is not None:
        return_value_queue.put(centroids)

    return centroids
""" | |
TODOs: | |
1. Notice we're using self.config.bsize. | |
2. Consider saving/using heldout_avg_residual as a vector --- that is, using 128 averages! | |
3. Consider the operations with .cuda() tensors. Are all of them good for OOM? | |
""" | |