|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Custom CTC loss with per-label real-valued attributes (CUDA).
//
// The overall structure (target-prime indexing, log-space alpha/beta
// recursions, the "large"/"small" backward paths) closely follows PyTorch's
// native CTC loss in aten/src/ATen/native/cuda/LossCTC.cu. In addition, every
// non-blank label is scored against predicted real values with a Gaussian
// log-density of width `sigma` (see the helper functions below).
#include <torch/extension.h>
|
|
|
#include <ATen/TensorUtils.h> |
|
#include <c10/util/Exception.h> |
|
#include <c10/util/MathConstants.h> |
|
#include <c10/macros/Macros.h> |
|
|
|
#include <ATen/ATen.h> |
|
#include <ATen/Dispatch.h> |
|
#include <ATen/cuda/CUDAApplyUtils.cuh> |
|
|
|
#include <THC/THCAtomics.cuh>        // gpuAtomicAddNoReturn (newer PyTorch exposes it via <ATen/cuda/Atomic.cuh>)

#include <ATen/cuda/CUDAContext.h>   // at::cuda::getCurrentCUDAStream
|
|
|
#include <type_traits> |
|
#include <numeric> |
|
|
|
using namespace c10; |
|
using namespace at; |
|
using namespace at::native; |
|
|
|
|
|
// Gaussian log-density log N(x; mu, sigma) of a real-valued target attribute x
// around the predicted value mu.
template<typename scalar_t>
|
__device__ inline scalar_t custom_distance_forward_log(scalar_t x, scalar_t mu, scalar_t sigma) noexcept { |
|
return -0.5 * std::log(2.0 * c10::pi<scalar_t>) - std::log(sigma) - 0.5 * (x - mu) * (x - mu) / (sigma * sigma); |
|
} |
|
|
|
|
|
// Derivative of the Gaussian density N(x; mu, sigma) with respect to mu.
template<typename scalar_t>
|
__device__ inline scalar_t custom_distance_backward(scalar_t x, scalar_t mu, scalar_t sigma) noexcept { |
|
scalar_t val = 1.0 / (sigma * std::sqrt(2 * c10::pi<scalar_t>)) * std::exp(-0.5 * (x - mu) * (x - mu) / (sigma * sigma)); |
|
return val * (x - mu) / (sigma * sigma); |
|
} |
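
// For reference, the two helpers above implement
//   log N(x; mu, sigma)        = -0.5*log(2*pi) - log(sigma) - (x - mu)^2 / (2*sigma^2)
//   d N(x; mu, sigma) / d mu   = N(x; mu, sigma) * (x - mu) / sigma^2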
|
|
|
|
|
// Laplace (L1) alternative to the Gaussian score:
// log p(x; mu, sigma) = -log(2*sigma) - |x - mu| / sigma.
template<typename scalar_t>
|
__device__ inline scalar_t custom_distance_forward_log_l1(scalar_t x, scalar_t mu, scalar_t sigma) noexcept { |
|
return - std::log(2 * sigma) - std::abs(x - mu) / sigma; |
|
} |
|
|
|
|
|
// Sign function with a small dead zone (epsilon) around zero.
template<typename scalar_t>
|
__device__ inline scalar_t sgn(scalar_t v) noexcept { |
|
if (std::abs(v) < std::numeric_limits<scalar_t>::epsilon()) |
|
return 0; |
|
return v / std::abs(v); |
|
} |
|
|
|
|
|
// Derivative of the Laplace density with respect to mu:
// d/dmu [exp(-|x - mu| / sigma) / (2*sigma)] = sgn(x - mu) * exp(-|x - mu| / sigma) / (2*sigma^2).
template<typename scalar_t>
|
__device__ inline scalar_t custom_distance_backward_l1(scalar_t x, scalar_t mu, scalar_t sigma) noexcept { |
|
return -sgn(mu - x) * std::exp(-std::abs(x - mu) / sigma) / (2 * sigma * sigma); |
|
} |
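
// The Laplace (L1) variants above, together with sgn(), are provided as an
// alternative distance model but are not referenced by the kernels below,
// which use the Gaussian helpers.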
|
|
|
#if 0 |
|
|
|
template<typename scalar_t> |
|
__device__ inline scalar_t custom_distance_backward_log(scalar_t x, scalar_t mu) { |
|
return x - mu; |
|
} |
|
|
|
|
|
template<typename scalar_t> |
|
__device__ inline scalar_t custom_distance_forward(scalar_t x, scalar_t mu) { |
|
return 0; |
|
} |
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Map an index into the augmented CTC target sequence l' (blanks interleaved
// with the labels) to the corresponding label id: even positions are BLANK,
// odd position idx maps to target[idx / 2].
template <typename target_t>
|
__device__ static inline int64_t get_target_prime( |
|
const target_t* __restrict__ target, |
|
int64_t offset, |
|
int64_t stride, |
|
int64_t idx, |
|
int64_t BLANK) { |
|
if (idx % 2 == 0) { |
|
return BLANK; |
|
} else { |
|
return target[offset + stride * (idx / 2)]; |
|
} |
|
} |
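
// Example: for a target sequence [a, b] the augmented sequence l' is
// [BLANK, a, BLANK, b, BLANK], so get_target_prime(..., idx=3, BLANK) yields b
// and every even idx yields BLANK.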
|
|
|
// Precompute, for every (batch b, frame t, target position s), the pooled
// real-value score
//   log_realvalues[b][t][s] = sum_i log N(targets_realval[b][s][i]; realval[b][t][i], sigma),
// with a score of 0 for blank targets (BLANK / BLANK_1).
template<typename scalar_t, typename target_t>
|
__global__ void |
|
#if defined (__HIP_PLATFORM_HCC__) |
|
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1) |
|
#endif |
|
ctc_loss_collect_log_realvalues_gpu_kernel(scalar_t* __restrict__ log_realvalues_data, |
|
const int64_t* __restrict__ input_lengths, |
|
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, |
|
const scalar_t* __restrict__ realval_data, int64_t num_realval, |
|
const scalar_t* __restrict__ targets_realval_data, |
|
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, |
|
int64_t rv_batch_stride, int64_t rv_input_stride, int64_t rv_label_stride, |
|
int64_t rvt_batch_stride, int64_t rvt_input_stride, int64_t rvt_label_stride, |
|
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, |
|
int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) { |
|
int64_t b = threadIdx.y + blockIdx.y * blockDim.y; |
|
int64_t s = threadIdx.x + blockIdx.x * blockDim.x; |
|
|
|
if (b >= batch_size) |
|
return; |
|
|
|
int64_t input_length = input_lengths[b]; |
|
int64_t target_length = target_lengths[b]; |
|
int64_t rv_batch_offset = b*rv_batch_stride; |
|
int64_t rvt_batch_offset = b*rvt_batch_stride; |
|
int64_t lr_batch_offset = b*lr_batch_stride; |
|
int64_t tg_batch_offset = tg_batch_offsets[b]; |
|
|
|
if (s >= target_length) |
|
return; |
|
|
|
int64_t target = targets_data[tg_batch_offset + s * tg_target_stride]; |
|
|
|
for (int64_t t = 0; t < input_length; t++) { |
|
scalar_t log_prod_n = 0; |
|
if (target != BLANK && target != BLANK_1) { |
|
for (int64_t i = 0; i < num_realval; ++i) { |
|
log_prod_n += custom_distance_forward_log( |
|
targets_realval_data[rvt_batch_offset + rvt_input_stride * s + rvt_label_stride * i], |
|
realval_data[rv_batch_offset + rv_input_stride * t + rv_label_stride * i], |
|
sigma |
|
); |
|
} |
|
} |
|
log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * s] = log_prod_n; |
|
} |
|
} |
|
|
|
// Phase 1: initialize log_alpha at t = 0 (s = 0 starts with a blank, s = 1
// with the first target label, everything else is -inf). This kernel is not
// launched by custom_ctc_loss_gpu_template below; phase 2 repeats the same
// initialization before running the recursion.
template<typename scalar_t, typename target_t>
|
__global__ void |
|
#if defined (__HIP_PLATFORM_HCC__) |
|
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1) |
|
#endif |
|
ctc_loss_log_alpha_gpu_kernel_phase1(scalar_t* __restrict__ log_alpha_data, |
|
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, |
|
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, |
|
const scalar_t* __restrict__ log_realvalues_data, |
|
scalar_t* __restrict__ neg_log_likelihood_data, |
|
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, |
|
int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride, |
|
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, |
|
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, |
|
int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) { |
|
|
|
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity(); |
|
|
|
|
|
int64_t b = threadIdx.y + blockIdx.y * blockDim.y; |
|
|
|
if (b >= batch_size) |
|
return; |
|
|
|
int64_t input_length = input_lengths[b]; |
|
int64_t target_length = target_lengths[b]; |
|
int64_t lp_batch_offset = b*lp_batch_stride; |
|
int64_t la_batch_offset = b*la_batch_stride; |
|
int64_t lr_batch_offset = b*lr_batch_stride; |
|
int64_t tg_batch_offset = tg_batch_offsets[b]; |
|
|
|
|
|
for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) { |
|
int64_t s = threadIdx.x + block_s; |
|
scalar_t la; |
|
switch (s) { |
|
case 0: |
|
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK]; |
|
break; |
|
case 1: |
|
{ |
|
if (target_length != 0) { |
|
int64_t tgt = get_target_prime( |
|
targets_data, |
|
tg_batch_offset, |
|
tg_target_stride, |
|
1, |
|
BLANK); |
|
scalar_t cur_logprob = log_probs_data[lp_batch_offset + lp_char_stride * tgt]; |
|
|
|
cur_logprob += log_realvalues_data[lr_batch_offset + lr_input_stride * 0 + lr_target_stride * 0]; |
|
|
|
la = cur_logprob; |
|
} else { |
|
la = neginf; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
break; |
|
default: |
|
la = neginf; |
|
} |
|
if (s < 2*max_target_length+1) |
|
      log_alpha_data[la_batch_offset + /* la_input_stride * 0 + */ la_target_stride * s] = la;
|
} |
|
} |
|
|
|
// Phase 2: full forward pass. Initializes log_alpha at t = 0, runs the
// log-space alpha recursion over the remaining frames, and writes the
// per-sample negative log-likelihood.
template<typename scalar_t, typename target_t>
|
__global__ void |
|
#if defined (__HIP_PLATFORM_HCC__) |
|
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1) |
|
#endif |
|
ctc_loss_log_alpha_gpu_kernel_phase2(scalar_t* __restrict__ log_alpha_data, |
|
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, |
|
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, |
|
const scalar_t* __restrict__ log_realvalues_data, |
|
scalar_t* __restrict__ neg_log_likelihood_data, |
|
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, |
|
int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride, |
|
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, |
|
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, |
|
int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) { |
|
|
|
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity(); |
|
|
|
|
|
int64_t b = threadIdx.y + blockIdx.y * blockDim.y; |
|
|
|
if (b >= batch_size) |
|
return; |
|
|
|
int64_t input_length = input_lengths[b]; |
|
int64_t target_length = target_lengths[b]; |
|
int64_t lp_batch_offset = b*lp_batch_stride; |
|
int64_t la_batch_offset = b*la_batch_stride; |
|
int64_t lr_batch_offset = b*lr_batch_stride; |
|
int64_t tg_batch_offset = tg_batch_offsets[b]; |
|
|
|
|
|
for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) { |
|
int64_t s = threadIdx.x + block_s; |
|
scalar_t la; |
|
switch (s) { |
|
case 0: |
|
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK]; |
|
break; |
|
case 1: |
|
{ |
|
if (target_length != 0) { |
|
int64_t tgt = get_target_prime( |
|
targets_data, |
|
tg_batch_offset, |
|
tg_target_stride, |
|
1, |
|
BLANK); |
|
scalar_t cur_logprob = log_probs_data[lp_batch_offset + lp_char_stride * tgt]; |
|
|
|
cur_logprob += log_realvalues_data[lr_batch_offset + lr_input_stride * 0 + lr_target_stride * 0]; |
|
|
|
la = cur_logprob; |
|
} else { |
|
la = neginf; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
break; |
|
default: |
|
la = neginf; |
|
} |
|
if (s < 2*max_target_length+1) |
|
      log_alpha_data[la_batch_offset + /* la_input_stride * 0 + */ la_target_stride * s] = la;
|
} |
|
|
|
for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) { |
|
int64_t s = threadIdx.x + block_s; |
|
|
|
|
|
int64_t current_char; |
|
bool have_three; |
|
if (s < 2 * target_length + 1 && target_length > 0) { |
|
current_char = get_target_prime( |
|
targets_data, |
|
tg_batch_offset, |
|
tg_target_stride, |
|
s, |
|
BLANK); |
|
have_three = |
|
((s > 1) && |
|
(get_target_prime( |
|
targets_data, |
|
tg_batch_offset, |
|
tg_target_stride, |
|
s - 2, |
|
BLANK) != current_char)); |
|
} else { |
|
current_char = BLANK; |
|
have_three = false; |
|
} |
|
for (int64_t t=1; t < max_input_length; t++) { |
|
__syncthreads(); |
|
if ((t < input_length) && (s < 2 * target_length + 1)) { |
|
scalar_t cur_logprob = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_char]; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
cur_logprob += (s % 2 == 1) ? log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * (s / 2)] : 0; |
|
|
|
|
|
scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s]; |
|
scalar_t lamax = la1; |
|
scalar_t la2, la3; |
|
if (s > 0) { |
|
la2 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * (s-1)]; |
|
if (la2 > lamax) |
|
lamax = la2; |
|
} else { |
|
la2 = neginf; |
|
} |
|
if (have_three) { |
|
la3 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * (s-2)]; |
|
if (la3 > lamax) |
|
lamax = la3; |
|
} else { |
|
la3 = neginf; |
|
} |
|
if (lamax == neginf) |
|
lamax = 0; |
|
|
|
log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = std::log(std::exp(la1-lamax)+std::exp(la2-lamax)+std::exp(la3-lamax))+lamax |
|
+ cur_logprob; |
|
} else { |
|
|
|
if (s < 2*max_target_length+1) |
|
log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = neginf; |
|
} |
|
} |
|
} |
|
__syncthreads(); |
|
|
|
|
|
if (threadIdx.x == 0) { |
|
scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)]; |
|
scalar_t l2 = target_length > 0 |
|
? log_alpha_data |
|
[la_batch_offset + la_input_stride * (input_length - 1) + |
|
la_target_stride * (target_length * 2 - 1)] |
|
: neginf; |
|
scalar_t m = ((l1 > l2) ? l1 : l2); |
|
m = ((m == neginf) ? 0 : m); |
|
scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m; |
|
neg_log_likelihood_data[b] = -log_likelihood; |
|
} |
|
} |
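
// The recursion implemented by the kernel above, in log space:
//   alpha_t(s) = logsumexp(alpha_{t-1}(s), alpha_{t-1}(s-1),
//                          alpha_{t-1}(s-2) if l'_s != l'_{s-2})
//                + log p_t(l'_s) + [s odd] * log_realvalues[t][s/2]
// and the loss is -logsumexp(alpha_{T-1}(2L), alpha_{T-1}(2L-1)) for a target
// of length L.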
|
|
|
// Phase 3: recompute only the final negative log-likelihood reduction from an
// already-filled log_alpha. Not launched by the forward template in this file.
template<typename scalar_t, typename target_t>
|
__global__ void |
|
#if defined (__HIP_PLATFORM_HCC__) |
|
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1) |
|
#endif |
|
ctc_loss_log_alpha_gpu_kernel_phase3(scalar_t* __restrict__ log_alpha_data, |
|
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, |
|
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, |
|
const scalar_t* __restrict__ realval_data, int64_t num_realval, |
|
const scalar_t* __restrict__ targets_realval_data, |
|
scalar_t* __restrict__ neg_log_likelihood_data, |
|
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, |
|
int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride, |
|
int64_t rv_batch_stride, int64_t rv_input_stride, int64_t rv_label_stride, |
|
int64_t rvt_batch_stride, int64_t rvt_input_stride, int64_t rvt_label_stride, |
|
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, |
|
int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) { |
|
|
|
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity(); |
|
|
|
|
|
int64_t b = threadIdx.y + blockIdx.y * blockDim.y; |
|
|
|
if (b >= batch_size) |
|
return; |
|
|
|
int64_t input_length = input_lengths[b]; |
|
int64_t target_length = target_lengths[b]; |
|
int64_t lp_batch_offset = b*lp_batch_stride; |
|
int64_t la_batch_offset = b*la_batch_stride; |
|
int64_t rv_batch_offset = b*rv_batch_stride; |
|
int64_t rvt_batch_offset = b*rvt_batch_stride; |
|
int64_t tg_batch_offset = tg_batch_offsets[b]; |
|
|
|
|
|
|
|
|
|
if (threadIdx.x == 0) { |
|
scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)]; |
|
scalar_t l2 = target_length > 0 |
|
? log_alpha_data |
|
[la_batch_offset + la_input_stride * (input_length - 1) + |
|
la_target_stride * (target_length * 2 - 1)] |
|
: neginf; |
|
scalar_t m = ((l1 > l2) ? l1 : l2); |
|
m = ((m == neginf) ? 0 : m); |
|
scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m; |
|
neg_log_likelihood_data[b] = -log_likelihood; |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Forward entry point (per floating-point / target integer type).
// Shapes as used below: log_probs (batch, time, labels), realval
// (batch, time, num_realval), targets_realval (batch, max_target_length, num_realval),
// targets either 1-D (concatenated) or 2-D (batch, max_target_length).
// Returns (neg_log_likelihood [batch], log_alpha [batch, time, 2*max_target_length+1]).
template<typename scalar_t, ScalarType target_scalar_type>
|
std::tuple<Tensor, Tensor> custom_ctc_loss_gpu_template( |
|
const Tensor& log_probs, |
|
const Tensor& targets, |
|
const Tensor& realval, |
|
const Tensor& targets_realval, |
|
IntArrayRef input_lengths, |
|
IntArrayRef target_lengths, |
|
scalar_t const sigma, |
|
int64_t BLANK, |
|
int64_t BLANK_1 |
|
) { |
|
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity(); |
|
|
|
|
|
|
|
|
|
|
|
CheckedFrom c = "custom_ctc_loss_gpu"; |
|
using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type; |
|
auto log_probs_arg = TensorArg(log_probs, "log_probs", 1); |
|
auto targets_arg = TensorArg(targets, "targets", 2); |
|
auto realval_arg = TensorArg(realval, "realval", 3); |
|
auto targets_realval_arg = TensorArg(targets_realval, "targets_realval", 4); |
|
checkAllSameGPU(c, {log_probs_arg, targets_arg, realval_arg, targets_realval_arg}); |
|
|
|
checkScalarType(c, targets_arg, target_scalar_type); |
|
checkDim(c, log_probs_arg, 3); |
|
checkDim(c, realval_arg, 3); |
|
checkDim(c, targets_realval_arg, 3); |
|
checkDimRange(c, targets_arg, 1, 3); |
|
|
|
int64_t batch_size = log_probs.size(0); |
|
int64_t num_realvals = realval.size(2); |
|
int64_t num_labels = log_probs.size(2); |
|
TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range"); |
|
TORCH_CHECK((0 <= BLANK_1) && (BLANK_1 < num_labels), "blank1 must be in label range"); |
|
TORCH_CHECK(input_lengths.size() == batch_size, "input_lengths must be of size batch_size"); |
|
TORCH_CHECK(realval.size(2) == targets_realval.size(2), "number of real values must be the same for both realval and targets_realval"); |
|
TORCH_CHECK(log_probs.size(1) == realval.size(1), "input_lengths must be the same for both log_probs and realval"); |
|
TORCH_CHECK(target_lengths.size() == batch_size, "target_lengths must be of size batch_size"); |
|
|
|
int64_t lp_input_stride = log_probs.stride(1); |
|
int64_t lp_char_stride = log_probs.stride(2); |
|
int64_t tg_target_stride; |
|
|
|
int64_t max_target_length = 0; |
|
auto tg_batch_offsets = at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong)); |
|
auto tg_batch_offsets_data = tg_batch_offsets.data_ptr<int64_t>(); |
|
if (targets.dim() == 1) { |
|
int64_t pos = 0; |
|
for (int64_t i = 0; i < batch_size; i++) { |
|
tg_batch_offsets_data[i] = pos; |
|
pos += target_lengths[i]; |
|
if (max_target_length < target_lengths[i]) |
|
max_target_length = target_lengths[i]; |
|
} |
|
tg_target_stride = targets.stride(0); |
|
checkSize(c, targets_arg, 0, pos); |
|
} |
|
else { |
|
|
|
int64_t tg_batch_stride = targets.stride(0); |
|
for (int64_t i = 0; i < batch_size; i++) { |
|
tg_batch_offsets_data[i] = i * tg_batch_stride; |
|
if (max_target_length < target_lengths[i]) |
|
max_target_length = target_lengths[i]; |
|
} |
|
tg_target_stride = targets.stride(1); |
|
checkSize(c, targets_arg, 0, batch_size); |
|
TORCH_CHECK(targets.size(1) >= max_target_length, |
|
"Expected tensor to have size at least ", max_target_length, " at dimension 1, but got size ", targets.size(1), " for ", targets_arg, |
|
" (while checking arguments for ", c, ")"); |
|
} |
|
int64_t max_input_length = log_probs.size(1); |
|
for (int64_t b = 0; b < batch_size; b++) { |
|
TORCH_CHECK(input_lengths[b] <= max_input_length, |
|
"Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b], |
|
" (while checking arguments for ", c, ")"); |
|
} |
|
|
|
auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong)); |
|
auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong)); |
|
tg_batch_offsets = tg_batch_offsets.cuda(); |
|
|
|
Tensor log_realvalues = at::zeros({batch_size, log_probs.size(1), std::max(max_target_length, int64_t(1))}, log_probs.options()); |
|
Tensor log_alpha = at::empty({batch_size, log_probs.size(1), 2*max_target_length+1}, log_probs.options()); |
|
Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options()); |
|
log_alpha.fill_(neginf); |
|
|
|
constexpr int max_threads = std::is_same<scalar_t, float>::value ? 1024 : 896; |
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
|
{ |
|
int threads_target = max_threads; |
|
while (threads_target / 2 >= max_target_length && threads_target > 1) { |
|
threads_target /= 2; |
|
} |
|
int threads_batch = std::min(max_threads / threads_target, (int) batch_size); |
|
dim3 block(threads_target, threads_batch); |
|
dim3 grid( |
|
std::max<int>( |
|
(max_target_length + threads_target - 1) / threads_target, 1), |
|
(batch_size + threads_batch - 1) / threads_batch, |
|
1); |
|
ctc_loss_collect_log_realvalues_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>> |
|
(log_realvalues.data_ptr<scalar_t>(), |
|
input_lengths_t.data_ptr<int64_t>(), |
|
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), |
|
realval.data_ptr<scalar_t>(), num_realvals, |
|
targets_realval.data_ptr<scalar_t>(), |
|
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), |
|
realval.stride(0), realval.stride(1), realval.stride(2), |
|
targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2), |
|
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, |
|
batch_size, num_labels, sigma, BLANK, BLANK_1); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} |
|
|
|
|
|
int threads_target = max_threads; |
|
while (threads_target / 2 >= 2*max_target_length+1) { |
|
threads_target /= 2; |
|
} |
|
int threads_batch = std::min(max_threads / threads_target, (int) batch_size); |
|
dim3 block(threads_target, threads_batch); |
|
dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ctc_loss_log_alpha_gpu_kernel_phase2<scalar_t, target_t><<<grid, block, 0, stream>>>( |
|
log_alpha.data_ptr<scalar_t>(), |
|
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1), |
|
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length, |
|
log_realvalues.data_ptr<scalar_t>(), |
|
neg_log_likelihood.data_ptr<scalar_t>(), |
|
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), |
|
log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), |
|
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), |
|
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, |
|
batch_size, sigma, BLANK, BLANK_1); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return std::make_tuple(neg_log_likelihood, log_alpha); |
|
} |
|
|
|
|
|
|
|
// Backward (beta) kernel: initializes log_beta at the last valid frame and
// runs the log-space beta recursion.
template<typename scalar_t, typename target_t>
|
__global__ void |
|
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1) |
|
ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data, |
|
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, |
|
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, |
|
const scalar_t* __restrict__ log_realvalues_data, |
|
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, |
|
int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride, |
|
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, |
|
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, |
|
int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) { |
|
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity(); |
|
|
|
int64_t b = threadIdx.y + blockIdx.y * blockDim.y; |
|
|
|
if (b >= batch_size) |
|
return; |
|
|
|
int64_t input_length = input_lengths[b]; |
|
int64_t target_length = target_lengths[b]; |
|
int64_t lp_batch_offset = b*lp_batch_stride; |
|
int64_t lb_batch_offset = b*lb_batch_stride; |
|
int64_t lr_batch_offset = b*lr_batch_stride; |
|
int64_t tg_batch_offset = tg_batch_offsets[b]; |
|
|
|
|
|
|
|
for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) { |
|
int64_t s = threadIdx.x + block_s; |
|
scalar_t lb; |
|
if (s == 2*target_length) { |
|
lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK]; |
|
} else if (s == 2 * target_length - 1) { |
|
int64_t current_target_prime = get_target_prime( |
|
targets_data, |
|
tg_batch_offset, |
|
tg_target_stride, |
|
s, |
|
BLANK); |
|
scalar_t cur_logprob = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime]; |
|
lb = cur_logprob + log_realvalues_data[lr_batch_offset + lr_input_stride * (input_length - 1) + lr_target_stride * (target_length - 1)]; |
|
} else { |
|
lb = neginf; |
|
} |
|
if (s < 2*max_target_length+1) { |
|
log_beta_data[lb_batch_offset + (input_length-1) * lb_input_stride + lb_target_stride * s] = lb; |
|
} |
|
} |
|
|
|
|
|
for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) { |
|
int64_t s = threadIdx.x + block_s; |
|
int64_t current_target_prime; |
|
bool have_three; |
|
if (s < 2 * target_length + 1 && target_length > 0) { |
|
current_target_prime = get_target_prime( |
|
targets_data, |
|
tg_batch_offset, |
|
tg_target_stride, |
|
s, |
|
BLANK); |
|
have_three = |
|
((s < 2 * target_length - 1) && |
|
(get_target_prime( |
|
targets_data, |
|
tg_batch_offset, |
|
tg_target_stride, |
|
s + 2, |
|
BLANK) != current_target_prime)); |
|
} else { |
|
current_target_prime = BLANK; |
|
have_three = false; |
|
} |
|
|
|
for (int64_t t=max_input_length-2; t>=0; t--) { |
|
__syncthreads(); |
|
if ((t < input_length - 1) && (s < 2 * target_length + 1)) { |
|
scalar_t cur_logprob = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime]; |
|
cur_logprob += (s % 2 == 1) ? log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * (s / 2)] : 0; |
|
scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s]; |
|
scalar_t lbmax = lb1; |
|
scalar_t lb2, lb3; |
|
|
|
if (s < 2*target_length) { |
|
lb2 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * (s+1)]; |
|
if (lb2 > lbmax) |
|
lbmax = lb2; |
|
} else { |
|
lb2 = neginf; |
|
} |
|
if (have_three) { |
|
lb3 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * (s+2)]; |
|
if (lb3 > lbmax) |
|
lbmax = lb3; |
|
} else { |
|
lb3 = neginf; |
|
} |
|
if (lbmax == neginf) |
|
lbmax = 0; |
|
|
|
scalar_t lb = std::log(std::exp(lb1-lbmax)+std::exp(lb2-lbmax)+std::exp(lb3-lbmax))+lbmax |
|
+ cur_logprob; |
|
|
|
log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb; |
|
} |
|
else if ( |
|
(s < 2 * max_target_length + 1) && |
|
(((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) || |
|
(t >= input_length))) { |
|
log_beta_data |
|
[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = |
|
neginf; |
|
} |
|
} |
|
} |
|
} |
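
// The beta recursion above mirrors the alpha recursion but runs from the last
// valid frame towards t = 0 and gathers from s, s+1 and (when allowed) s+2.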
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// "Large"-path backward kernel: accumulates the log_probs gradient for the
// non-blank target labels only (the blank column and the exp(log_probs) term
// are handled on the host with exp_out / logsumexp before this kernel runs).
template<typename scalar_t, typename target_t>
|
__global__ void |
|
#if defined (__HIP_PLATFORM_HCC__) |
|
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1) |
|
#endif |
|
ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data, |
|
const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride, |
|
const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data, |
|
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, |
|
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, |
|
const scalar_t* __restrict__ log_realvalues_data, |
|
const scalar_t* __restrict__ neg_log_likelihood_data, |
|
int64_t gr_batch_stride, int64_t gr_input_stride, int64_t gr_char_stride, |
|
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, |
|
int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride, |
|
int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride, |
|
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, |
|
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, |
|
int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1, bool zero_infinity) { |
|
int64_t b = threadIdx.y + blockIdx.y * blockDim.y; |
|
int64_t s = threadIdx.x + blockIdx.x * blockDim.x; |
|
|
|
if (b >= batch_size) |
|
return; |
|
|
|
int64_t input_length = input_lengths[b]; |
|
int64_t target_length = target_lengths[b]; |
|
int64_t gr_batch_offset = b*gr_batch_stride; |
|
int64_t lp_batch_offset = b*lp_batch_stride; |
|
int64_t la_batch_offset = b*la_batch_stride; |
|
int64_t lb_batch_offset = b*lb_batch_stride; |
|
int64_t lr_batch_offset = b*lr_batch_stride; |
|
int64_t tg_batch_offset = tg_batch_offsets[b]; |
|
|
|
if (s >= target_length) |
|
return; |
|
|
|
int64_t target = targets_data[tg_batch_offset + s * tg_target_stride]; |
|
scalar_t nll = neg_log_likelihood_data[b]; |
|
scalar_t gr = grad_out_data[b * grad_out_batch_stride]; |
|
|
|
if (zero_infinity && nll == std::numeric_limits<scalar_t>::infinity()) |
|
return; |
|
|
|
for (int64_t t = 0; t < input_length; t++) { |
|
scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * target]; |
|
scalar_t log_alpha_beta = log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * (s*2+1)] + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * (s*2+1)]; |
|
scalar_t log_prod_n = log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * s]; |
|
scalar_t log_alpha_beta_div_pr = log_alpha_beta - log_prod_n; |
|
gpuAtomicAddNoReturn(&gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * target], |
|
-std::exp(log_alpha_beta_div_pr + nll - lp) * gr); |
|
} |
|
} |
|
|
|
// Gradient with respect to the predicted real values (realval): for every
// frame t and non-blank target position s, combines the alpha*beta posterior
// with the derivative of the Gaussian score.
template<typename scalar_t, typename target_t>
|
__global__ void |
|
#if defined (__HIP_PLATFORM_HCC__) |
|
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1) |
|
#endif |
|
ctc_loss_backward_collect_realvalue_gpu_kernel(scalar_t* __restrict__ gradient_realval_data, |
|
const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride, |
|
const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data, |
|
const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, |
|
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, |
|
const scalar_t* __restrict__ realval_data, int64_t num_realval, |
|
const scalar_t* __restrict__ targets_realval_data, |
|
const scalar_t* __restrict__ log_realvalues_data, |
|
const scalar_t* __restrict__ neg_log_likelihood_data, |
|
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, |
|
int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride, |
|
int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride, |
|
int64_t rv_batch_stride, int64_t rv_input_stride, int64_t rv_label_stride, |
|
int64_t rvt_batch_stride, int64_t rvt_input_stride, int64_t rvt_label_stride, |
|
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, |
|
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, |
|
int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) { |
|
|
|
|
|
int64_t b = threadIdx.y + blockIdx.y * blockDim.y; |
|
int64_t t = threadIdx.x + blockIdx.x * blockDim.x; |
|
|
|
if ((t >= max_input_length) || (b >= batch_size)) |
|
return; |
|
|
|
|
|
int64_t target_length = target_lengths[b]; |
|
int64_t lp_batch_offset = b*lp_batch_stride; |
|
int64_t la_batch_offset = b*la_batch_stride; |
|
int64_t lb_batch_offset = b*lb_batch_stride; |
|
int64_t rv_batch_offset = b*rv_batch_stride; |
|
int64_t rvt_batch_offset = b*rvt_batch_stride; |
|
int64_t lr_batch_offset = b*lr_batch_stride; |
|
int64_t tg_batch_offset = tg_batch_offsets[b]; |
|
|
|
scalar_t nll = neg_log_likelihood_data[b]; |
|
scalar_t gr = grad_out_data[b * grad_out_batch_stride]; |
|
|
|
|
|
for (int s = 0; s < max_target_length; s++) { |
|
if (s < target_length) { |
|
int64_t current_target_prime = get_target_prime( |
|
targets_data, |
|
tg_batch_offset, |
|
tg_target_stride, |
|
s * 2 + 1, |
|
BLANK); |
|
scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime]; |
|
scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * (s * 2 + 1)] |
|
+ log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * (s * 2 + 1)]); |
|
scalar_t log_prod_n = log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * s]; |
|
if (current_target_prime != BLANK && current_target_prime != BLANK_1) { |
|
scalar_t log_term1 = log_alpha_beta - lp - 2 * log_prod_n; |
|
for (int64_t i = 0; i != num_realval; ++i) { |
|
scalar_t log_constant_factors = log_prod_n - custom_distance_forward_log( |
|
targets_realval_data[rvt_batch_offset + rvt_input_stride * s + rvt_label_stride * i], |
|
realval_data[rv_batch_offset + rv_input_stride * t + rv_label_stride * i], |
|
static_cast<scalar_t>(sigma) |
|
); |
|
scalar_t grad_dp_dmu = std::exp(log_term1 + log_constant_factors + nll) * custom_distance_backward( |
|
targets_realval_data[rvt_batch_offset + rvt_input_stride * s + rvt_label_stride * i], |
|
realval_data[rv_batch_offset + rv_input_stride * t + rv_label_stride * i], |
|
static_cast<scalar_t>(sigma) |
|
); |
|
gradient_realval_data[rv_batch_offset + rv_input_stride * t + rv_label_stride * i] += -grad_dp_dmu * gr; |
|
} |
|
} |
|
} |
|
} |
|
} |
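
// Note on the accumulation above: log_constant_factors subtracts the single
// Gaussian term log N_i from the pooled score, and custom_distance_backward
// contributes a factor N_i * (x_i - mu_i) / sigma^2, so the N_i factors cancel
// and the update reduces to
//   -grad_out * exp(alpha_t(s) + beta_t(s) - log p_t(l'_s) - log_prod_n + nll)
//             * (x_i - mu_i) / sigma^2.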
|
|
|
|
|
|
|
// "Small"-path backward kernel: computes the full log_probs gradient for all
// labels of one (batch, frame) pair, accumulating log(alpha*beta) per label in
// log space before converting to the gradient.
template<typename scalar_t, typename target_t>
|
__global__ void |
|
#if defined (__HIP_PLATFORM_HCC__) |
|
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1) |
|
#endif |
|
ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data, |
|
const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride, |
|
const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data, |
|
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length, |
|
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length, |
|
const scalar_t* __restrict__ log_realvalues_data, |
|
const scalar_t* __restrict__ neg_log_likelihood_data, |
|
int64_t gr_batch_stride, int64_t gr_input_stride, int64_t gr_char_stride, |
|
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride, |
|
int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride, |
|
int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride, |
|
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride, |
|
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride, |
|
int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1, bool zero_infinity) { |
|
|
|
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity(); |
|
int64_t b = threadIdx.y + blockIdx.y * blockDim.y; |
|
int64_t t = threadIdx.x + blockIdx.x * blockDim.x; |
|
|
|
if ((t >= max_input_length) || (b >= batch_size)) |
|
return; |
|
|
|
int64_t input_length = input_lengths[b]; |
|
int64_t target_length = target_lengths[b]; |
|
int64_t gr_batch_offset = b*gr_batch_stride; |
|
int64_t lp_batch_offset = b*lp_batch_stride; |
|
int64_t la_batch_offset = b*la_batch_stride; |
|
int64_t lb_batch_offset = b*lb_batch_stride; |
|
int64_t lr_batch_offset = b*lr_batch_stride; |
|
int64_t tg_batch_offset = tg_batch_offsets[b]; |
|
|
|
|
|
|
|
for (int s = 0; s < 2*max_target_length+1; s++) { |
|
if (s < 2 * target_length + 1) { |
|
int64_t current_target_prime = get_target_prime( |
|
targets_data, |
|
tg_batch_offset, |
|
tg_target_stride, |
|
s, |
|
BLANK); |
|
scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] |
|
+ log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s]); |
|
scalar_t log_prod_n = s % 2 == 1 ? log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * (s / 2)] : 0; |
|
scalar_t log_alpha_beta_div_pr = log_alpha_beta - log_prod_n; |
|
scalar_t& lcab = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * current_target_prime]; |
|
if (lcab == neginf) { |
|
lcab = log_alpha_beta_div_pr; |
|
} else { |
|
scalar_t max = ((lcab > log_alpha_beta_div_pr) ? lcab : log_alpha_beta_div_pr); |
|
lcab = std::log(std::exp(lcab-max)+std::exp(log_alpha_beta_div_pr-max))+max; |
|
} |
|
} |
|
} |
|
|
|
scalar_t nll = neg_log_likelihood_data[b]; |
|
scalar_t gr = grad_out_data[b * grad_out_batch_stride]; |
|
|
|
for (int64_t c = 0; c < num_labels; c++) { |
|
scalar_t& res = gradient_data[gr_batch_offset + t * gr_input_stride + gr_char_stride * c]; |
|
if (t < input_length && (! zero_infinity || nll != std::numeric_limits<scalar_t>::infinity())) { |
|
scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * c]; |
|
res = (std::exp(lp)-std::exp(res + nll - lp)) * gr; |
|
} |
|
else { |
|
res = 0.; |
|
} |
|
} |
|
} |
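
// The final loop above converts the accumulated log(alpha*beta / prod_n) into
// the gradient of the loss with respect to log_probs:
//   d loss / d log_probs[t][c] = exp(log_probs[t][c]) - exp(lcab + nll - log_probs[t][c]),
// scaled by grad_out (and zeroed for padded frames or infinite losses).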
|
|
|
|
|
|
|
|
|
// Zero the log_probs gradient for frames beyond each sample's input length.
template<typename scalar_t>
|
__global__ void |
|
#if defined (__HIP_PLATFORM_HCC__) |
|
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1) |
|
#endif |
|
ctc_loss_zero_padded_gradients( |
|
scalar_t* __restrict__ gradient_data, |
|
const int64_t* __restrict__ input_lengths, |
|
int64_t gr_batch_stride, |
|
int64_t gr_timestep_stride, |
|
int64_t gr_label_stride, |
|
int64_t batch_size, |
|
int64_t max_input_length, |
|
int64_t num_labels |
|
) { |
|
|
|
int64_t b = threadIdx.y + blockIdx.y * blockDim.y; |
|
int64_t t = threadIdx.x + blockIdx.x * blockDim.x; |
|
|
|
if (b >= batch_size || t >= max_input_length) { |
|
return; |
|
} |
|
|
|
  int64_t input_length = input_lengths[b];
|
if (t >= input_length) { |
|
for (int l = 0; l < num_labels; l++) |
|
gradient_data[ |
|
b * gr_batch_stride + t * gr_timestep_stride + l * gr_label_stride] |
|
= 0.0f; |
|
} |
|
} |
|
|
|
|
|
|
|
// Backward entry point: recomputes log_realvalues and log_beta, then produces
// the gradient w.r.t. log_probs (choosing the "large" or "small" collection
// path) and the gradient w.r.t. realval.
template<typename scalar_t, ScalarType target_scalar_type>
|
std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu_template( |
|
const Tensor& grad_out, |
|
const Tensor& log_probs, |
|
const Tensor& targets, |
|
const Tensor& realval, |
|
const Tensor& targets_realval, |
|
IntArrayRef input_lengths, |
|
IntArrayRef target_lengths, |
|
const Tensor& neg_log_likelihood, |
|
const Tensor& log_alpha, |
|
scalar_t const sigma, |
|
int64_t BLANK, |
|
int64_t BLANK_1, |
|
bool zero_infinity |
|
) { |
|
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity(); |
|
using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type; |
|
int64_t batch_size = log_probs.size(0); |
|
int64_t num_realvals = realval.size(2); |
|
int64_t num_labels = log_probs.size(2); |
|
int64_t lp_input_stride = log_probs.stride(1); |
|
int64_t lp_char_stride = log_probs.stride(2); |
|
int64_t tg_target_stride; |
|
|
|
int64_t max_target_length; |
|
  auto tg_batch_offsets = at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong));
|
auto tg_batch_offsets_data = tg_batch_offsets.data_ptr<int64_t>(); |
|
if (targets.dim() == 1) { |
|
int64_t pos = 0; |
|
max_target_length = 0; |
|
for (int64_t i = 0; i < batch_size; i++) { |
|
tg_batch_offsets_data[i] = pos; |
|
pos += target_lengths[i]; |
|
if (max_target_length < target_lengths[i]) |
|
max_target_length = target_lengths[i]; |
|
} |
|
tg_target_stride = targets.stride(0); |
|
} |
|
else { |
|
|
|
int64_t tg_batch_stride = targets.stride(0); |
|
for (int64_t i = 0; i < batch_size; i++) { |
|
tg_batch_offsets_data[i] = i * tg_batch_stride; |
|
} |
|
tg_target_stride = targets.stride(1); |
|
max_target_length = log_alpha.size(2)/2; |
|
} |
|
auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong)); |
|
auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong)); |
|
tg_batch_offsets = tg_batch_offsets.cuda(); |
|
|
|
Tensor log_realvalues = at::zeros({batch_size, log_probs.size(1), std::max(max_target_length, int64_t(1))}, log_alpha.options()); |
|
Tensor log_beta = at::empty_like(log_alpha, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
|
log_beta.fill_(neginf); |
|
|
|
Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
|
Tensor grad_realval = at::full_like(realval, 0, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
|
|
|
|
|
constexpr int max_threads = std::is_same<scalar_t, float>::value ? 1024 : 896; |
|
int threads_target = max_threads; |
|
while (threads_target / 2 >= 2*max_target_length+1) { |
|
threads_target /= 2; |
|
} |
|
int threads_batch = std::min(max_threads / threads_target, (int) batch_size); |
|
|
|
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
|
|
|
{ |
|
int threads_target = max_threads; |
|
while (threads_target / 2 >= max_target_length && threads_target > 1) { |
|
threads_target /= 2; |
|
} |
|
int threads_batch = std::min(max_threads / threads_target, (int) batch_size); |
|
dim3 block(threads_target, threads_batch); |
|
dim3 grid( |
|
std::max<int>( |
|
(max_target_length + threads_target - 1) / threads_target, 1), |
|
(batch_size + threads_batch - 1) / threads_batch, |
|
1); |
|
ctc_loss_collect_log_realvalues_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>> |
|
(log_realvalues.data_ptr<scalar_t>(), |
|
input_lengths_t.data_ptr<int64_t>(), |
|
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), |
|
realval.data_ptr<scalar_t>(), num_realvals, |
|
targets_realval.data_ptr<scalar_t>(), |
|
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), |
|
realval.stride(0), realval.stride(1), realval.stride(2), |
|
targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2), |
|
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, |
|
batch_size, num_labels, sigma, BLANK, BLANK_1); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} |
|
|
|
{ |
|
dim3 block(threads_target, threads_batch); |
|
dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch); |
|
ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>> |
|
(log_beta.data_ptr<scalar_t>(), |
|
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1), |
|
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length, |
|
log_realvalues.data_ptr<scalar_t>(), |
|
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), |
|
log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), |
|
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), |
|
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, |
|
batch_size, sigma, BLANK, BLANK_1); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} |
|
|
|
|
|
|
|
|
|
|
|
  // Heuristic taken over from PyTorch's native CTC backward: for "large"
  // problems compute the log_probs gradient with a dense exp/logsumexp for the
  // blank column plus a kernel over non-blank targets; otherwise use a single
  // collection kernel over all labels.
  bool is_large = (2*log_probs.size(1)+(24*batch_size)/10+(2*num_labels)/10) > 450;
|
if (is_large) { |
|
|
|
at::exp_out(grad, log_probs); |
|
|
|
|
|
|
|
auto grad_blank = grad.narrow(2, BLANK, 1); |
|
grad_blank -= (at::logsumexp(log_alpha.as_strided({batch_size, log_alpha.size(1), max_target_length+1}, |
|
{log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2)*2}) |
|
+ log_beta.as_strided({batch_size, log_beta.size(1), max_target_length+1}, |
|
{log_beta.stride(0), log_beta.stride(1), log_beta.stride(2)*2}), |
|
2, true) |
|
.add_(neg_log_likelihood.view({batch_size, 1, 1})) |
|
.sub_(log_probs.narrow(2, BLANK, 1)) |
|
.exp_() |
|
); |
|
|
|
grad *= grad_out.view({batch_size, 1, 1}); |
|
if (zero_infinity) { |
|
grad = at::where(neg_log_likelihood.view({batch_size, 1, 1}) == Scalar(std::numeric_limits<scalar_t>::infinity()), at::zeros({}, grad.options()), grad); |
|
} |
|
|
|
|
|
|
|
int threads_target = max_threads; |
|
while (threads_target / 2 >= max_target_length && threads_target > 1) { |
|
threads_target /= 2; |
|
} |
|
int threads_batch = std::min(max_threads / threads_target, (int) batch_size); |
|
dim3 block(threads_target, threads_batch); |
|
dim3 grid( |
|
std::max<int>( |
|
(max_target_length + threads_target - 1) / threads_target, 1), |
|
(batch_size + threads_batch - 1) / threads_batch, |
|
1); |
|
ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>> |
|
(grad.data_ptr<scalar_t>(), |
|
grad_out.data_ptr<scalar_t>(), grad_out.stride(0), |
|
log_alpha.data_ptr<scalar_t>(), log_beta.data_ptr<scalar_t>(), |
|
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1), |
|
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length, |
|
log_realvalues.data_ptr<scalar_t>(), |
|
neg_log_likelihood.data_ptr<scalar_t>(), |
|
grad.stride(0), grad.stride(1), grad.stride(2), |
|
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), |
|
log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), |
|
log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), |
|
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), |
|
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, |
|
batch_size, num_labels, sigma, BLANK, BLANK_1, zero_infinity); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} else { |
|
|
|
int threads_input = max_threads; |
|
while (threads_input / 2 >= log_probs.size(1) && threads_input > 1) { |
|
threads_input /= 2; |
|
} |
|
threads_batch = std::min(max_threads / threads_input, (int) batch_size); |
|
dim3 block(threads_input, threads_batch); |
|
dim3 grid((log_probs.size(1) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch); |
|
ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>> |
|
(grad.data_ptr<scalar_t>(), |
|
grad_out.data_ptr<scalar_t>(), grad_out.stride(0), |
|
log_alpha.data_ptr<scalar_t>(), log_beta.data_ptr<scalar_t>(), |
|
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1), |
|
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length, |
|
log_realvalues.data_ptr<scalar_t>(), |
|
neg_log_likelihood.data_ptr<scalar_t>(), |
|
grad.stride(0), grad.stride(1), grad.stride(2), |
|
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), |
|
log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), |
|
log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), |
|
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), |
|
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, |
|
batch_size, num_labels, sigma, BLANK, BLANK_1, zero_infinity); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} |
|
|
|
|
|
{ |
|
int threads_input = max_threads; |
|
while (threads_input / 2 >= log_probs.size(1) && threads_input > 1) { |
|
threads_input /= 2; |
|
} |
|
    // NOTE: the block shape computed above is immediately overridden with
    // fixed values (512 time-threads, a single batch row per block).
    threads_input = 512;

    threads_batch = std::min(max_threads / threads_input, (int) batch_size);

    threads_batch = 1;
|
|
|
|
|
dim3 block(threads_input, threads_batch); |
|
dim3 grid((log_probs.size(1) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch); |
|
ctc_loss_backward_collect_realvalue_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>> |
|
(grad_realval.data_ptr<scalar_t>(), |
|
grad_out.data_ptr<scalar_t>(), grad_out.stride(0), |
|
log_alpha.data_ptr<scalar_t>(), log_beta.data_ptr<scalar_t>(), |
|
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1), |
|
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length, |
|
realval.data_ptr<scalar_t>(), num_realvals, |
|
targets_realval.data_ptr<scalar_t>(), |
|
log_realvalues.data_ptr<scalar_t>(), |
|
neg_log_likelihood.data_ptr<scalar_t>(), |
|
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2), |
|
log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2), |
|
log_beta.stride(0), log_beta.stride(1), log_beta.stride(2), |
|
realval.stride(0), realval.stride(1), realval.stride(2), |
|
targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2), |
|
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2), |
|
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride, |
|
batch_size, num_labels, sigma, BLANK, BLANK_1); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} |
|
|
|
|
|
{ |
|
int threads_input = max_threads; |
|
    while (threads_input / 2 >= log_probs.size(1) && threads_input > 1) {
|
threads_input /= 2; |
|
} |
|
threads_batch = std::min(max_threads / threads_input, (int) batch_size); |
|
dim3 block(threads_input, threads_batch); |
|
dim3 grid( |
|
(log_probs.size(1) + threads_input-1)/threads_input, |
|
(batch_size+threads_batch-1)/threads_batch); |
|
ctc_loss_zero_padded_gradients<scalar_t><<<grid, block, 0, stream>>>( |
|
grad.data_ptr<scalar_t>(), |
|
input_lengths_t.data_ptr<int64_t>(), |
|
grad.stride(0), |
|
grad.stride(1), |
|
grad.stride(2), |
|
grad.size(0), |
|
grad.size(1), |
|
grad.size(2) |
|
); |
|
C10_CUDA_KERNEL_LAUNCH_CHECK(); |
|
} |
|
|
|
return std::make_tuple(grad, grad_realval); |
|
} |
|
|
|
std::tuple<Tensor, Tensor> custom_ctc_loss_gpu( |
|
const Tensor& log_probs, |
|
const Tensor& targets, |
|
const Tensor& realval, |
|
const Tensor& targets_realval, |
|
IntArrayRef input_lengths, |
|
IntArrayRef target_lengths, |
|
double const sigma, |
|
int64_t BLANK, |
|
int64_t BLANK_1 |
|
) { |
|
return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "custom_ctc_loss_cuda", [&] { |
|
if (targets.scalar_type() == kLong) { |
|
return custom_ctc_loss_gpu_template<scalar_t, kLong>(log_probs, targets, realval, targets_realval, input_lengths, target_lengths, static_cast<scalar_t>(sigma), BLANK, BLANK_1); |
|
} else { |
|
return custom_ctc_loss_gpu_template<scalar_t, kInt>(log_probs, targets, realval, targets_realval, input_lengths, target_lengths, static_cast<scalar_t>(sigma), BLANK, BLANK_1); |
|
} |
|
}); |
|
} |
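
// Host-side usage sketch (illustrative only; the tensor names and the sigma /
// blank values are placeholders, not defined in this file):
//
//   auto log_probs = torch::log_softmax(logits, /*dim=*/2);   // (batch, time, labels)
//   Tensor nll, log_alpha;
//   std::tie(nll, log_alpha) = custom_ctc_loss_gpu(
//       log_probs, targets, realval, targets_realval,
//       input_lengths, target_lengths, /*sigma=*/1.0, /*BLANK=*/0, /*BLANK_1=*/1);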
|
|
|
std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu( |
|
const Tensor& grad, |
|
const Tensor& log_probs, |
|
const Tensor& targets, |
|
const Tensor& realval, |
|
const Tensor& targets_realval, |
|
IntArrayRef input_lengths, |
|
IntArrayRef target_lengths, |
|
const Tensor& neg_log_likelihood, |
|
const Tensor& log_alpha, |
|
double const sigma, |
|
int64_t BLANK, |
|
int64_t BLANK_1, |
|
bool zero_infinity |
|
) { |
|
|
|
|
|
globalContext().alertNotDeterministic("ctc_loss_backward_gpu"); |
|
return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "custom_ctc_loss_backward_cuda", [&] { |
|
if (targets.scalar_type() == kLong) { |
|
return custom_ctc_loss_backward_gpu_template<scalar_t, kLong>(grad, log_probs, targets, realval, targets_realval, input_lengths, target_lengths, neg_log_likelihood, log_alpha, static_cast<scalar_t>(sigma), BLANK, BLANK_1, zero_infinity); |
|
} else { |
|
return custom_ctc_loss_backward_gpu_template<scalar_t, kInt>(grad, log_probs, targets, realval, targets_realval, input_lengths, target_lengths, neg_log_likelihood, log_alpha, static_cast<scalar_t>(sigma), BLANK, BLANK_1, zero_infinity); |
|
} |
|
}); |
|
} |
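
// Minimal binding sketch, assuming this file is compiled as a standalone torch
// extension (the module and function names here are placeholders, not
// confirmed by this file):
//
// PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//   m.def("custom_ctc_loss", &custom_ctc_loss_gpu,
//         "Custom CTC loss with real-valued attributes, forward (CUDA)");
//   m.def("custom_ctc_loss_backward", &custom_ctc_loss_backward_gpu,
//         "Custom CTC loss with real-valued attributes, backward (CUDA)");
// }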
|
|