#include #include using torch::Tensor; using torch::IntArrayRef; std::tuple custom_ctc_loss_gpu( const Tensor& log_probs, const Tensor& targets, const Tensor& realval, const Tensor& targets_realval, IntArrayRef input_lengths, IntArrayRef target_lengths, double const sigma, int64_t BLANK, int64_t BLANK_1 ); std::tuple custom_ctc_loss_backward_gpu( const Tensor& grad, const Tensor& log_probs, const Tensor& targets, const Tensor& realval, const Tensor& targets_realval, IntArrayRef input_lengths, IntArrayRef target_lengths, const Tensor& neg_log_likelihood, const Tensor& log_alpha, double const sigma, int64_t BLANK, int64_t BLANK_1, bool zero_infinity ); std::tuple custom_ctc_loss_gpu_driver( const Tensor& log_probs, const Tensor& targets, const Tensor& realval, const Tensor& targets_realval, const Tensor& input_lengths, const Tensor& target_lengths, double const sigma, int64_t BLANK, int64_t BLANK_1, bool zero_infinity ) { (void)zero_infinity; Tensor ilc = input_lengths.contiguous(); Tensor tlc = target_lengths.contiguous(); IntArrayRef il(ilc.data_ptr(), ilc.numel()); IntArrayRef tl(tlc.data_ptr(), tlc.numel()); return custom_ctc_loss_gpu(log_probs, targets, realval, targets_realval, il, tl, sigma, BLANK, BLANK_1); } std::tuple custom_ctc_loss_backward_gpu_driver( const Tensor& grad, const Tensor& log_probs, const Tensor& targets, const Tensor& realval, const Tensor& targets_realval, const Tensor& input_lengths, const Tensor& target_lengths, const Tensor& neg_log_likelihood, const Tensor& log_alpha, double const sigma, int64_t BLANK, int64_t BLANK_1, bool zero_infinity ) { Tensor ilc = input_lengths.contiguous(); Tensor tlc = target_lengths.contiguous(); IntArrayRef il(ilc.data_ptr(), ilc.numel()); IntArrayRef tl(tlc.data_ptr(), tlc.numel()); return custom_ctc_loss_backward_gpu(grad, log_probs, targets, realval, targets_realval, il, tl, neg_log_likelihood, log_alpha, sigma, BLANK, BLANK_1, zero_infinity); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("forward", &custom_ctc_loss_gpu_driver, "custom CTC forward (CUDA)"); m.def("backward", &custom_ctc_loss_backward_gpu_driver, "custom CTC backward (CUDA)"); }