#include #include typedef struct decompress_args { int tid; int nthreads; int npids; int dim; int packed_dim; int npacked_vals_per_byte; int* pids; int64_t* lengths; int64_t* offsets; float* bucket_weights; uint8_t* reversed_bit_map; uint8_t* bucket_weight_combinations; uint8_t* binary_residuals; int* codes; float* centroids; int64_t* cumulative_lengths; float* output; } decompress_args_t; void* decompress(void* args) { decompress_args_t* decompress_args = (decompress_args_t*)args; int npids_per_thread = (int)std::ceil(((float)decompress_args->npids) / decompress_args->nthreads); int start = decompress_args->tid * npids_per_thread; int end = std::min((decompress_args->tid + 1) * npids_per_thread, decompress_args->npids); // Iterate over all documents for (int i = start; i < end; i++) { int pid = decompress_args->pids[i]; // Offset into packed list of token vectors for the given document int64_t offset = decompress_args->offsets[pid]; // For each document, iterate over all token vectors for (int j = 0; j < decompress_args->lengths[pid]; j++) { const int code = decompress_args->codes[offset + j]; // For each token vector, iterate over the packed (8-bit) residual // values for (int k = 0; k < decompress_args->packed_dim; k++) { uint8_t x = decompress_args->binary_residuals [(offset + j) * decompress_args->packed_dim + k]; x = decompress_args->reversed_bit_map[x]; // For each packed residual value, iterate over the bucket // weight indices. If we use n-bit compression, that means there // will be (8 / n) indices per packed value. for (int l = 0; l < decompress_args->npacked_vals_per_byte; l++) { const int output_dim_idx = k * decompress_args->npacked_vals_per_byte + l; const int bucket_weight_idx = decompress_args->bucket_weight_combinations [x * decompress_args->npacked_vals_per_byte + l]; decompress_args ->output[(decompress_args->cumulative_lengths[i] + j) * decompress_args->dim + output_dim_idx] = decompress_args->bucket_weights[bucket_weight_idx] + decompress_args->centroids[code * decompress_args->dim + output_dim_idx]; } } } } return NULL; } torch::Tensor decompress_residuals( const torch::Tensor pids, const torch::Tensor lengths, const torch::Tensor offsets, const torch::Tensor bucket_weights, const torch::Tensor reversed_bit_map, const torch::Tensor bucket_weight_combinations, const torch::Tensor binary_residuals, const torch::Tensor codes, const torch::Tensor centroids, const int dim, const int nbits) { const int npacked_vals_per_byte = (8 / nbits); const int packed_dim = (int)(dim / npacked_vals_per_byte); int npids = pids.size(0); int* pids_a = pids.data_ptr(); int64_t* lengths_a = lengths.data_ptr(); int64_t* offsets_a = offsets.data_ptr(); float* bucket_weights_a = bucket_weights.data_ptr(); uint8_t* reversed_bit_map_a = reversed_bit_map.data_ptr(); uint8_t* bucket_weight_combinations_a = bucket_weight_combinations.data_ptr(); uint8_t* binary_residuals_a = binary_residuals.data_ptr(); int* codes_a = codes.data_ptr(); float* centroids_a = centroids.data_ptr(); int64_t cumulative_lengths[npids + 1]; int noutputs = 0; cumulative_lengths[0] = 0; for (int i = 0; i < npids; i++) { noutputs += lengths_a[pids_a[i]]; cumulative_lengths[i + 1] = cumulative_lengths[i] + lengths_a[pids_a[i]]; } auto options = torch::TensorOptions().dtype(torch::kFloat32).requires_grad(false); torch::Tensor output = torch::zeros({noutputs, dim}, options); float* output_a = output.data_ptr(); auto nthreads = at::get_num_threads(); pthread_t threads[nthreads]; decompress_args_t args[nthreads]; for (int i = 0; i < nthreads; i++) { args[i].tid = i; args[i].nthreads = nthreads; args[i].npids = npids; args[i].dim = dim; args[i].packed_dim = packed_dim; args[i].npacked_vals_per_byte = npacked_vals_per_byte; args[i].pids = pids_a; args[i].lengths = lengths_a; args[i].offsets = offsets_a; args[i].bucket_weights = bucket_weights_a; args[i].reversed_bit_map = reversed_bit_map_a; args[i].bucket_weight_combinations = bucket_weight_combinations_a; args[i].binary_residuals = binary_residuals_a; args[i].codes = codes_a; args[i].centroids = centroids_a; args[i].cumulative_lengths = cumulative_lengths; args[i].output = output_a; int rc = pthread_create(&threads[i], NULL, decompress, (void*)&args[i]); if (rc) { fprintf(stderr, "Unable to create thread %d: %d\n", i, rc); std::exit(1); } } for (int i = 0; i < nthreads; i++) { pthread_join(threads[i], NULL); } return output; } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("decompress_residuals_cpp", &decompress_residuals, "Decompress residuals"); }