// Page header captured with this file -- Spaces status: Runtime error.
// File size: 4,455 Bytes; identifier 58627fa (presumably the commit hash).
#include <pthread.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <numeric>
#include <queue>
#include <vector>

#include <torch/extension.h>
// Argument bundle passed (as void*) to each lookup<T> worker thread.
// Every pointer is borrowed: the struct owns nothing and frees nothing.
struct lookup_args_t {
    int tid;                      // index of the worker this bundle belongs to
    pthread_mutex_t* mutex;       // lock taken around every access to `queue`
    std::queue<int>* queue;       // segment indices still waiting to be copied
    int64_t ndocs;                // number of segments requested
    int64_t noutputs;             // total number of output rows
    int64_t dim;                  // elements per row
    void* input;                  // source buffer; workers cast it to T*
    int64_t* lengths;             // rows to copy for each segment
    int64_t* offsets;             // first source row of each segment
    int64_t* cumulative_lengths;  // first output row of each segment
    void* output;                 // destination buffer; workers cast it to T*
};
template <typename T>
void* lookup(void* args) {
lookup_args_t* lookup_args = (lookup_args_t*)args;
int64_t* lengths = lookup_args->lengths;
int64_t* cumulative_lengths = lookup_args->cumulative_lengths;
int64_t* offsets = lookup_args->offsets;
int64_t dim = lookup_args->dim;
T* input = static_cast<T*>(lookup_args->input);
T* output = static_cast<T*>(lookup_args->output);
while (1) {
pthread_mutex_lock(lookup_args->mutex);
if (lookup_args->queue->empty()) {
pthread_mutex_unlock(lookup_args->mutex);
return NULL;
}
int i = lookup_args->queue->front();
lookup_args->queue->pop();
pthread_mutex_unlock(lookup_args->mutex);
std::memcpy(output + (cumulative_lengths[i] * dim),
input + (offsets[i] * dim), lengths[i] * dim * sizeof(T));
}
}
/**
 * Gathers variable-length row segments of `input` into one packed tensor.
 *
 * For each i in [0, pids.size(0)), `lengths[i]` rows starting at row
 * `offsets[i]` of `input` are copied to the output. Segments are laid out
 * back to back in index order. The copies are done by worker threads running
 * lookup<T>, which pull segment indices from a mutex-protected queue.
 *
 * @tparam T       element type; must match the dtype of `input`.
 * @param input    contiguous 1-D or 2-D tensor to copy rows from.
 * @param pids     only its first-dimension size (the segment count) is used.
 * @param lengths  contiguous int64 tensor: rows to copy per segment.
 * @param offsets  contiguous int64 tensor: first source row per segment.
 * @return a tensor of shape {sum(lengths)} for 1-D input, or
 *         {sum(lengths), input.size(1)} for 2-D input, with input's options.
 * @throws c10::Error (via TORCH_CHECK) on invalid arguments or when the
 *         mutex cannot be initialised.
 */
template <typename T>
torch::Tensor segmented_lookup_impl(const torch::Tensor input,
                                    const torch::Tensor pids,
                                    const torch::Tensor lengths,
                                    const torch::Tensor offsets) {
    // TORCH_CHECK rather than assert: assert is compiled out under NDEBUG.
    TORCH_CHECK(input.dim() == 1 || input.dim() == 2,
                "segmented_lookup: input must be 1-D or 2-D, got ",
                input.dim(), " dimensions");
    // The workers use raw pointers and memcpy, which is only valid on
    // contiguous storage.
    TORCH_CHECK(input.is_contiguous() && lengths.is_contiguous() &&
                    offsets.is_contiguous(),
                "segmented_lookup: input, lengths and offsets must be "
                "contiguous");

    const int64_t ndocs = pids.size(0);
    TORCH_CHECK(lengths.numel() >= ndocs && offsets.numel() >= ndocs,
                "segmented_lookup: lengths and offsets need at least ", ndocs,
                " entries");

    int64_t* lengths_a = lengths.data_ptr<int64_t>();
    int64_t* offsets_a = offsets.data_ptr<int64_t>();

    // Reject segments that would make memcpy read outside `input`.
    const int64_t nrows = input.size(0);
    for (int64_t i = 0; i < ndocs; ++i) {
        TORCH_CHECK(lengths_a[i] >= 0 && offsets_a[i] >= 0 &&
                        lengths_a[i] <= nrows - offsets_a[i],
                    "segmented_lookup: segment ", i, " (offset ", offsets_a[i],
                    ", length ", lengths_a[i], ") is outside the input's ",
                    nrows, " rows");
    }

    // cumulative_lengths[i] is the first output row of segment i and the last
    // entry is the total row count. Summing in int64_t avoids the truncation
    // an `int` accumulator would cause; heap storage avoids a stack VLA.
    std::vector<int64_t> cumulative_lengths(static_cast<size_t>(ndocs) + 1, 0);
    std::partial_sum(lengths_a, lengths_a + ndocs,
                     cumulative_lengths.begin() + 1);
    const int64_t noutputs = cumulative_lengths.back();

    int64_t dim;
    torch::Tensor output;
    if (input.dim() == 1) {
        dim = 1;
        output = torch::zeros({noutputs}, input.options());
    } else {
        dim = input.size(1);
        output = torch::zeros({noutputs, dim}, input.options());
    }
    if (ndocs == 0) {
        return output;  // nothing to copy, so no threads are needed
    }

    std::queue<int> queue;
    for (int i = 0; i < ndocs; i++) {
        queue.push(i);
    }

    // Never start more workers than there are segments.
    const int nthreads = static_cast<int>(std::min<int64_t>(
        std::max(1, at::get_num_threads()), ndocs));

    pthread_mutex_t mutex;

    // Fill every argument bundle before any thread starts. The vector is not
    // resized afterwards, so the addresses handed to the workers stay valid.
    std::vector<lookup_args_t> args(static_cast<size_t>(nthreads));
    for (int i = 0; i < nthreads; i++) {
        args[i].tid = i;
        args[i].mutex = &mutex;
        args[i].queue = &queue;
        args[i].ndocs = ndocs;
        args[i].noutputs = noutputs;
        args[i].dim = dim;
        args[i].input = static_cast<void*>(input.data_ptr<T>());
        args[i].lengths = lengths_a;
        args[i].offsets = offsets_a;
        args[i].cumulative_lengths = cumulative_lengths.data();
        args[i].output = static_cast<void*>(output.data_ptr<T>());
    }

    // Reserve now so nothing can throw between mutex init and destroy.
    std::vector<pthread_t> threads;
    threads.reserve(static_cast<size_t>(nthreads));

    int rc = pthread_mutex_init(&mutex, nullptr);
    TORCH_CHECK(rc == 0, "Unable to init mutex: ", rc);

    for (int i = 0; i < nthreads; i++) {
        pthread_t handle;
        rc = pthread_create(&handle, nullptr, lookup<T>, &args[i]);
        if (rc) {
            fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
        } else {
            // Keep only valid handles; joining one from a failed create is
            // undefined behaviour.
            threads.push_back(handle);
        }
    }

    if (threads.empty()) {
        // No worker could be started: drain the queue on this thread so the
        // output is still filled in.
        lookup<T>(&args[0]);
    }

    for (pthread_t& handle : threads) {
        pthread_join(handle, nullptr);
    }

    rc = pthread_mutex_destroy(&mutex);
    if (rc) {
        fprintf(stderr, "Unable to destroy mutex: %d\n", rc);
    }

    return output;
}
/**
 * Dtype dispatcher for segmented_lookup_impl.
 *
 * Forwards to the segmented_lookup_impl<T> instantiation that matches the
 * dtype of `input`. Supported dtypes: uint8, int32, int64, float32, float16.
 *
 * @param input    tensor to copy row segments from.
 * @param pids     passed through to the implementation.
 * @param lengths  passed through to the implementation.
 * @param offsets  passed through to the implementation.
 * @return the tensor produced by segmented_lookup_impl<T>.
 * @throws c10::Error (via TORCH_CHECK) when the dtype is not supported.
 */
torch::Tensor segmented_lookup(const torch::Tensor input,
                               const torch::Tensor pids,
                               const torch::Tensor lengths,
                               const torch::Tensor offsets) {
    if (input.dtype() == torch::kUInt8) {
        return segmented_lookup_impl<uint8_t>(input, pids, lengths, offsets);
    }
    if (input.dtype() == torch::kInt32) {
        return segmented_lookup_impl<int>(input, pids, lengths, offsets);
    }
    if (input.dtype() == torch::kInt64) {
        return segmented_lookup_impl<int64_t>(input, pids, lengths, offsets);
    }
    if (input.dtype() == torch::kFloat32) {
        return segmented_lookup_impl<float>(input, pids, lengths, offsets);
    }
    if (input.dtype() == torch::kFloat16) {
        return segmented_lookup_impl<at::Half>(input, pids, lengths, offsets);
    }
    // Always throws. An assert here would vanish under NDEBUG and let control
    // fall off the end of a non-void function (undefined behaviour).
    TORCH_CHECK(false, "segmented_lookup: unsupported input dtype ",
                input.scalar_type());
}
// Python bindings: registers segmented_lookup on module `m` under the Python
// name "segmented_lookup_cpp", with the docstring "Segmented lookup".
// NOTE(review): TORCH_EXTENSION_NAME is not defined in this file; presumably
// it is supplied by the extension build -- confirm.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("segmented_lookup_cpp", &segmented_lookup, "Segmented lookup");
}
// End of file (stray table delimiter left by page extraction removed).