File size: 64,417 Bytes
9dce458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 
1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 |
// Copyright (c) 2018 MathInf GmbH, Thomas Viehmann
// Modified by zyddnys
// Licensed under the BSD-3-Clause license
// This is the GPU implementation of the Connectionist Temporal Classification (CTC) loss.
// We mostly follow Graves.
// 1. Graves et al: http://www.cs.toronto.edu/~graves/icml_2006.pdf
// Note from zyddnys:
// Added regression capability to the CTC loss. Currently we use L2 (Gaussian) regression; L1 regression may be added in the future.
// There are two blank labels: BLANK is the usual CTC blank, and BLANK_1 marks a target whose regression part is ignored.
// Many kernels are split into multiple kernels to avoid CUDA "too many resources requested for launch" errors.
// We use the equations from above link, but note that [1] has 1-based indexing and we (of course) use 0-based.
// Graves et al call the probabilities y, we use log_probs (also calling them inputs)
// A few optimizations (similar to those here, but also some I didn't take) are described in
// 2. Minmin Sun: http://on-demand.gputechconf.com/gtc/2016/presentation/s6383-minmin-sun-speech-recognition.pdf
#include <torch/extension.h>
#include <ATen/TensorUtils.h>
#include <c10/util/Exception.h>
#include <c10/util/MathConstants.h>
#include <c10/macros/Macros.h>
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <THC/THCAtomics.cuh>
#include <type_traits>
#include <numeric>
using namespace c10;
using namespace at;
using namespace at::native;
// log P(x|mu): log-density of a Gaussian N(x; mu, sigma^2),
//   log P = -0.5 * log(2*pi) - log(sigma) - 0.5 * (x - mu)^2 / sigma^2
// All constants are cast to scalar_t. Unsuffixed literals such as 0.5 or 2.0 are doubles and
// would silently promote a float instantiation to double-precision arithmetic, which is much
// slower on most GPUs. This function is evaluated in the inner loop of the realvalue kernel.
template<typename scalar_t>
__device__ inline scalar_t custom_distance_forward_log(scalar_t x, scalar_t mu, scalar_t sigma) noexcept {
const scalar_t half = static_cast<scalar_t>(0.5);
const scalar_t two_pi = static_cast<scalar_t>(2) * c10::pi<scalar_t>;
const scalar_t diff = x - mu;
return -half * std::log(two_pi) - std::log(sigma) - half * diff * diff / (sigma * sigma);
}
// d(P(x|mu))/dmu for the Gaussian density P(x|mu) = N(x; mu, sigma^2):
//   dP/dmu = P(x|mu) * (x - mu) / sigma^2
// Constants are cast to scalar_t so a float instantiation is not silently promoted to
// double precision (unsuffixed literals like 1.0 and 0.5 are doubles).
template<typename scalar_t>
__device__ inline scalar_t custom_distance_backward(scalar_t x, scalar_t mu, scalar_t sigma) noexcept {
const scalar_t diff = x - mu;
const scalar_t variance = sigma * sigma;
const scalar_t norm = sigma * std::sqrt(static_cast<scalar_t>(2) * c10::pi<scalar_t>);
const scalar_t density = std::exp(static_cast<scalar_t>(-0.5) * diff * diff / variance) / norm;
return density * diff / variance;
}
// log P(x|mu) for a Laplace density with location mu and scale sigma:
//   log P = -log(2 * sigma) - |x - mu| / sigma
template<typename scalar_t>
__device__ inline scalar_t custom_distance_forward_log_l1(scalar_t x, scalar_t mu, scalar_t sigma) noexcept {
const scalar_t log_normalizer = std::log(2 * sigma);
return -log_normalizer - std::abs(x - mu) / sigma;
}
// Sign of v. Magnitudes below machine epsilon count as zero (returns 0); otherwise returns
// v / |v|, which is -1 or +1 for finite inputs.
template<typename scalar_t>
__device__ inline scalar_t sgn(scalar_t v) noexcept {
const scalar_t magnitude = std::abs(v);
return (magnitude < std::numeric_limits<scalar_t>::epsilon()) ? scalar_t(0) : v / magnitude;
}
// d(P(x|mu))/dmu for the Laplace density P(x|mu) = exp(-|x - mu| / sigma) / (2 * sigma):
//   dP/dmu = sgn(x - mu) * exp(-|x - mu| / sigma) / (2 * sigma^2)
// (written as -sgn(mu - x), using the epsilon-thresholded sgn above).
template<typename scalar_t>
__device__ inline scalar_t custom_distance_backward_l1(scalar_t x, scalar_t mu, scalar_t sigma) noexcept {
const scalar_t direction = -sgn(mu - x);
return direction * std::exp(-std::abs(x - mu) / sigma) / (2 * sigma * sigma);
}
// Disabled helpers (compiled out by the #if 0 below), kept for reference only:
// - custom_distance_backward_log returns x - mu. It takes no sigma, so it equals the Gaussian
//   d(log P(x|mu))/dmu = (x - mu) / sigma^2 only for sigma == 1.
// - custom_distance_forward is a placeholder that always returns 0.
#if 0
// d(log P(x|mu))/dmu
template<typename scalar_t>
__device__ inline scalar_t custom_distance_backward_log(scalar_t x, scalar_t mu) {
return x - mu;
}
// P(x|mu)
template<typename scalar_t>
__device__ inline scalar_t custom_distance_forward(scalar_t x, scalar_t mu) {
return 0;
}
#endif
// Looks up position idx of the augmented target sequence l' (notation of [1]) without
// materialising it. For a target l = l_0 l_1 ... l_(tl-1) the augmented sequence is
//   l' = BLANK l_0 BLANK l_1 BLANK ... BLANK l_(tl-1) BLANK
// so every even idx is BLANK and an odd idx maps to l_(idx/2), read from
// target[offset + stride * (idx / 2)].
// - no bounds checking is done
// - only call it with idx == 0 when the target length is 0
// - __restrict__ impact still to be measured, see
//   https://devblogs.nvidia.com/cuda-pro-tip-optimize-pointer-aliasing/
template <typename target_t>
__device__ static inline int64_t get_target_prime(
const target_t* __restrict__ target,
int64_t offset,
int64_t stride,
int64_t idx,
int64_t BLANK) {
const bool is_blank_slot = (idx % 2 == 0);
return is_blank_slot ? BLANK : static_cast<int64_t>(target[offset + stride * (idx / 2)]);
}
// Precomputes the regression log-likelihood term used by the alpha/beta recursions.
// Each thread handles one (s, b) pair: b comes from the y dimension; s from the x dimension
// indexes the plain target sequence (not the augmented l'). For every t < input_lengths[b] it writes
//   log_realvalues[b, t, s] = sum_i custom_distance_forward_log(targets_realval[b, s, i],
//                                                              realval[b, t, i], sigma)
// or 0 when the target label is BLANK or BLANK_1 (no regression term for those targets).
// Threads with b >= batch_size or s >= target_lengths[b] write nothing. Entries with
// t >= input_lengths[b] are left untouched, so they keep whatever the caller initialised.
// num_labels is accepted but not used.
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_collect_log_realvalues_gpu_kernel(scalar_t* __restrict__ log_realvalues_data,
const int64_t* __restrict__ input_lengths,
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths,
const scalar_t* __restrict__ realval_data, int64_t num_realval,
const scalar_t* __restrict__ targets_realval_data,
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride,
int64_t rv_batch_stride, int64_t rv_input_stride, int64_t rv_label_stride,
int64_t rvt_batch_stride, int64_t rvt_input_stride, int64_t rvt_label_stride,
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) {
const int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
const int64_t s = threadIdx.x + blockIdx.x * blockDim.x;
// short-circuit: target_lengths[b] is only read for a valid batch index
if (b >= batch_size || s >= target_lengths[b])
return;
const int64_t label = targets_data[tg_batch_offsets[b] + s * tg_target_stride];
// the regression term only applies to real labels; it is independent of t, so decide once
const bool has_regression = (label != BLANK) && (label != BLANK_1);
const scalar_t* target_rv = targets_realval_data + b * rvt_batch_stride + s * rvt_input_stride;
const scalar_t* input_rv = realval_data + b * rv_batch_stride;
scalar_t* out = log_realvalues_data + b * lr_batch_stride + s * lr_target_stride;
const int64_t input_length = input_lengths[b];
for (int64_t t = 0; t < input_length; ++t) {
scalar_t log_prob_sum = 0;
if (has_regression) {
for (int64_t i = 0; i < num_realval; ++i) {
log_prob_sum += custom_distance_forward_log(
target_rv[rvt_label_stride * i],
input_rv[rv_input_stride * t + rv_label_stride * i],
sigma);
}
}
out[lr_input_stride * t] = log_prob_sum;
}
}
// Fills only row t = 0 of log_alpha for each batch item b (the initialisation above eq (6) in [1]):
//   s == 0: log_probs[b, 0, BLANK]
//   s == 1: log_probs[b, 0, l_0] + log_realvalues[b, 0, 0]  (or -inf if target_length == 0)
//   every other s < 2*max_target_length+1: -inf
// b comes from the y dimension. The x threads loop over s in steps of blockDim.x, and
// blockIdx.x is not used.
// Identical to the first loop of ctc_loss_log_alpha_gpu_kernel_phase2. Its launch in the host
// code is commented out, so this kernel is currently unused.
// Unused parameters/locals here: input_lengths/input_length, max_input_length,
// neg_log_likelihood_data, lp_input_stride, la_input_stride, sigma, BLANK_1.
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_log_alpha_gpu_kernel_phase1(scalar_t* __restrict__ log_alpha_data,
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
const scalar_t* __restrict__ log_realvalues_data,
scalar_t* __restrict__ neg_log_likelihood_data,
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride,
int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride,
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) {
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
// bookkeeping
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
if (b >= batch_size)
return;
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t lp_batch_offset = b*lp_batch_stride;
int64_t la_batch_offset = b*la_batch_stride;
int64_t lr_batch_offset = b*lr_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
// first row (t=0), the three equations for alpha_1 above eq (6)
for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
scalar_t la;
switch (s) {
case 0:
// alpha_1(BLANK): probability of emitting blank at t = 0
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
break;
case 1:
{
if (target_length != 0) {
int64_t tgt = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
1,
BLANK);
scalar_t cur_logprob = log_probs_data[lp_batch_offset + lp_char_stride * tgt];
// Regression term of target 0 at t = 0, added unconditionally. The collect kernel
// stores 0 there when target 0 is BLANK or BLANK_1, so no extra check is needed.
//if (tgt != BLANK_1) {
cur_logprob += log_realvalues_data[lr_batch_offset + lr_input_stride * 0 + lr_target_stride * 0];
//}
la = cur_logprob;
} else {
la = neginf;
}
// la = target_length == 0 ? neginf
// : log_probs_data
// [lp_batch_offset +
// lp_char_stride *
// get_target_prime(
// targets_data,
// tg_batch_offset,
// tg_target_stride,
// 1,
// BLANK)];
}
break;
default:
// alpha_1(s) = 0 (log: -inf) for s >= 2
la = neginf;
}
// the last chunk of block_s may have threads past the end of the row
if (s < 2*max_target_length+1)
log_alpha_data[la_batch_offset + /* la_input_stride * 0 */ + la_target_stride * s] = la;
}
}
// Forward (alpha) pass of CTC extended with the regression term, plus the loss.
// b comes from the y dimension. The x threads cover the 2*max_target_length+1 states s of
// the augmented target l' by looping in steps of blockDim.x; blockIdx.x is not used.
// Per batch item b it:
//   1. fills row t = 0 of log_alpha (initialisation above eq (6) in [1]),
//   2. runs the recursion of eqs (6)/(7) for t = 1 .. max_input_length-1; odd states s
//      additionally add log_realvalues[b, t, s/2], the regression log-likelihood of target s/2,
//   3. has the x == 0 thread write neg_log_likelihood[b] (eq (8)).
// Entries outside the valid region (t >= input_length or s >= 2*target_length+1) are set to -inf.
// sigma and BLANK_1 are not read here; the regression term comes precomputed in log_realvalues.
// NOTE(review): threads with b >= batch_size return before the __syncthreads() calls below.
// This relies on exited threads not blocking the barrier; confirm on the target hardware.
// NOTE(review): input_length == 0 would make the loss read row -1; callers presumably
// guarantee input_length >= 1. Confirm.
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_log_alpha_gpu_kernel_phase2(scalar_t* __restrict__ log_alpha_data,
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
const scalar_t* __restrict__ log_realvalues_data,
scalar_t* __restrict__ neg_log_likelihood_data,
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride,
int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride,
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) {
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
// bookkeeping
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
if (b >= batch_size)
return;
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t lp_batch_offset = b*lp_batch_stride;
int64_t la_batch_offset = b*la_batch_stride;
int64_t lr_batch_offset = b*lr_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
// first row (t=0), the three equations for alpha_1 above eq (6)
for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
scalar_t la;
switch (s) {
case 0:
la = log_probs_data[lp_batch_offset + lp_char_stride * BLANK];
break;
case 1:
{
if (target_length != 0) {
int64_t tgt = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
1,
BLANK);
scalar_t cur_logprob = log_probs_data[lp_batch_offset + lp_char_stride * tgt];
// Regression term of target 0 at t = 0, added unconditionally. The collect kernel
// stores 0 there for BLANK/BLANK_1 targets.
//if (tgt != BLANK_1) {
cur_logprob += log_realvalues_data[lr_batch_offset + lr_input_stride * 0 + lr_target_stride * 0];
//}
la = cur_logprob;
} else {
la = neginf;
}
// la = target_length == 0 ? neginf
// : log_probs_data
// [lp_batch_offset +
// lp_char_stride *
// get_target_prime(
// targets_data,
// tg_batch_offset,
// tg_target_stride,
// 1,
// BLANK)];
}
break;
default:
la = neginf;
}
if (s < 2*max_target_length+1)
log_alpha_data[la_batch_offset + /* la_input_stride * 0 */ + la_target_stride * s] = la;
}
// Recursion over t for each chunk of states. A later chunk (larger s) reads states s-1 and
// s-2 of the previous time step, which belong to an earlier chunk that has already been
// completed for every t.
for (int64_t block_s = 0; block_s < 2*max_target_length+1; block_s += blockDim.x) {
int64_t s = threadIdx.x + block_s;
// These two only depend on s, so we can cache them.
int64_t current_char; // l_s in eq (6)
bool have_three; // flag which of the two cases in eq (6) we have
if (s < 2 * target_length + 1 && target_length > 0) {
current_char = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
have_three =
((s > 1) &&
(get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s - 2,
BLANK) != current_char));
} else {
current_char = BLANK;
have_three = false;
}
for (int64_t t=1; t < max_input_length; t++) {
// row t-1 must be fully written by the other threads of this batch row before we read it
__syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch
if ((t < input_length) && (s < 2 * target_length + 1)) {
scalar_t cur_logprob = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_char];
// if (current_char != BLANK_1 && current_char != BLANK) {
// for (int64_t i = 0; i < num_realval; ++i) {
// cur_logprob += custom_distance_forward_log(
// targets_realval_data[rvt_batch_offset + rvt_input_stride * (s / 2) + rvt_label_stride * i],
// realval_data[rv_batch_offset + rv_input_stride * t + rv_label_stride * i],
// sigma
// );
// }
// }
// Odd s are label states of l'. Add the precomputed regression log-likelihood of target s/2
// at time t (the collect kernel stores 0 there for BLANK/BLANK_1 targets).
cur_logprob += (s % 2 == 1) ? log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * (s / 2)] : 0;
// only for valid t, s. This is equation (6) and (7), la1, la2, la3 are the three summands,
// lamax is the maximum for the logsumexp trick.
scalar_t la1 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * s];
scalar_t lamax = la1;
scalar_t la2, la3;
if (s > 0) {
la2 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * (s-1)];
if (la2 > lamax)
lamax = la2;
} else {
la2 = neginf;
}
if (have_three) {
la3 = log_alpha_data[la_batch_offset + la_input_stride * (t-1) + la_target_stride * (s-2)];
if (la3 > lamax)
lamax = la3;
} else {
la3 = neginf;
}
if (lamax == neginf) // when all are neginf. (then the whole thing is neginf, but we can pretend)
lamax = 0;
log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = std::log(std::exp(la1-lamax)+std::exp(la2-lamax)+std::exp(la3-lamax))+lamax
+ cur_logprob;
} else {
// otherwise we just set to neginf
if (s < 2*max_target_length+1)
log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * s] = neginf;
}
}
}
__syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch
// compute the loss (eq (8))
// log-likelihood = logsumexp of the final blank state (2L) and the last label state (2L-1)
// at t = input_length-1; the label state is absent when target_length == 0.
if (threadIdx.x == 0) {
scalar_t l1 = log_alpha_data[la_batch_offset + la_input_stride * (input_length-1) + la_target_stride * (target_length*2)];
scalar_t l2 = target_length > 0
? log_alpha_data
[la_batch_offset + la_input_stride * (input_length - 1) +
la_target_stride * (target_length * 2 - 1)]
: neginf;
scalar_t m = ((l1 > l2) ? l1 : l2);
m = ((m == neginf) ? 0 : m);
scalar_t log_likelihood = std::log(std::exp(l1-m)+std::exp(l2-m))+m;
neg_log_likelihood_data[b] = -log_likelihood;
}
}
// Computes only the per-sample loss (eq (8) in [1]) from an already filled log_alpha table:
//   neg_log_likelihood[b] = -log( exp(alpha[T-1, 2L]) + exp(alpha[T-1, 2L-1]) )
// with T = input_lengths[b] and L = target_lengths[b]. The second term is dropped when L == 0.
// Only the x == 0 thread of each batch row (b from the y dimension) writes.
// Most parameters (log_probs, targets, realval data, their strides, sigma, blanks) are accepted
// for signature parity with the other phases but are not read. The host code's launch of this
// kernel is currently commented out.
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_log_alpha_gpu_kernel_phase3(scalar_t* __restrict__ log_alpha_data,
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
const scalar_t* __restrict__ realval_data, int64_t num_realval,
const scalar_t* __restrict__ targets_realval_data,
scalar_t* __restrict__ neg_log_likelihood_data,
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride,
int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
int64_t rv_batch_stride, int64_t rv_input_stride, int64_t rv_label_stride,
int64_t rvt_batch_stride, int64_t rvt_input_stride, int64_t rvt_label_stride,
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) {
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
const int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
// out-of-range batch rows and all but the first x thread have nothing to do
if (b >= batch_size || threadIdx.x != 0)
return;
const int64_t last_t = input_lengths[b] - 1;
const int64_t target_length = target_lengths[b];
const scalar_t* last_row = log_alpha_data + b * la_batch_stride + la_input_stride * last_t;
// final blank state 2L and (if any) final label state 2L-1 of l'
const scalar_t end_blank = last_row[la_target_stride * (target_length * 2)];
scalar_t end_label;
if (target_length > 0) {
end_label = last_row[la_target_stride * (target_length * 2 - 1)];
} else {
end_label = neginf;
}
// logsumexp with a -inf guard so that exp(-inf - -inf) never produces NaN
scalar_t shift = (end_blank > end_label) ? end_blank : end_label;
if (shift == neginf)
shift = 0;
neg_log_likelihood_data[b] = -(std::log(std::exp(end_blank - shift) + std::exp(end_label - shift)) + shift);
}
// The forward computation. Lot's of admin and a call to the alpha kernel.
// Note: we do not check that the labels are in the valid range. As we use
// them for indexing in the kernels, you'll see memory errors when you
// pass corrupt labels.
// We support both a 2-dimensional tensor as targets (one set of targets in each row) and
// a 1-dimensional tensor where all targets are concatenated (and we use target_lengths
// to figure out where they begin).
// We return log_alpha (currently, might change to (log_alpha+log_beta) to be passed to the
// backward. The dispatch function will only return the loss.
template<typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor> custom_ctc_loss_gpu_template(
const Tensor& log_probs,
const Tensor& targets,
const Tensor& realval,
const Tensor& targets_realval,
IntArrayRef input_lengths,
IntArrayRef target_lengths,
scalar_t const sigma,
int64_t BLANK,
int64_t BLANK_1
) {
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
// log_probs: input_len x batch_size x num_labels
// targets [int64]: batch_size x target_length OR sum(target_lengths)
// realval [float]: batch_size x input_len x num_realval
// targets_realval [float]: batch_size x max_target_length x num_realval
CheckedFrom c = "custom_ctc_loss_gpu";
using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
auto log_probs_arg = TensorArg(log_probs, "log_probs", 1);
auto targets_arg = TensorArg(targets, "targets", 2);
auto realval_arg = TensorArg(realval, "realval", 3);
auto targets_realval_arg = TensorArg(targets_realval, "targets_realval", 4);
checkAllSameGPU(c, {log_probs_arg, targets_arg, realval_arg, targets_realval_arg});
checkScalarType(c, targets_arg, target_scalar_type);
checkDim(c, log_probs_arg, 3);
checkDim(c, realval_arg, 3);
checkDim(c, targets_realval_arg, 3);
checkDimRange(c, targets_arg, 1, 3);
int64_t batch_size = log_probs.size(0);
int64_t num_realvals = realval.size(2);
int64_t num_labels = log_probs.size(2);
TORCH_CHECK((0 <= BLANK) && (BLANK < num_labels), "blank must be in label range");
TORCH_CHECK((0 <= BLANK_1) && (BLANK_1 < num_labels), "blank1 must be in label range");
TORCH_CHECK(input_lengths.size() == batch_size, "input_lengths must be of size batch_size");
TORCH_CHECK(realval.size(2) == targets_realval.size(2), "number of real values must be the same for both realval and targets_realval");
TORCH_CHECK(log_probs.size(1) == realval.size(1), "input_lengths must be the same for both log_probs and realval");
TORCH_CHECK(target_lengths.size() == batch_size, "target_lengths must be of size batch_size");
int64_t lp_input_stride = log_probs.stride(1);
int64_t lp_char_stride = log_probs.stride(2);
int64_t tg_target_stride;
int64_t max_target_length = 0;
auto tg_batch_offsets = at::empty({batch_size}, at::device(at::kCPU).dtype(at::kLong));
auto tg_batch_offsets_data = tg_batch_offsets.data_ptr<int64_t>();
if (targets.dim() == 1) { // concatenated targets
int64_t pos = 0;
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = pos;
pos += target_lengths[i];
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
}
tg_target_stride = targets.stride(0);
checkSize(c, targets_arg, 0, pos);
}
else { // batch x max_target_length
// dim is 2
int64_t tg_batch_stride = targets.stride(0);
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = i * tg_batch_stride;
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
}
tg_target_stride = targets.stride(1);
checkSize(c, targets_arg, 0, batch_size);
TORCH_CHECK(targets.size(1) >= max_target_length,
"Expected tensor to have size at least ", max_target_length, " at dimension 1, but got size ", targets.size(1), " for ", targets_arg,
" (while checking arguments for ", c, ")");
}
int64_t max_input_length = log_probs.size(1);
for (int64_t b = 0; b < batch_size; b++) {
TORCH_CHECK(input_lengths[b] <= max_input_length,
"Expected input_lengths to have value at most ", max_input_length, ", but got value ", input_lengths[b],
" (while checking arguments for ", c, ")");
}
auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong));
auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong));
tg_batch_offsets = tg_batch_offsets.cuda();
Tensor log_realvalues = at::zeros({batch_size, log_probs.size(1), std::max(max_target_length, int64_t(1))}, log_probs.options());
Tensor log_alpha = at::empty({batch_size, log_probs.size(1), 2*max_target_length+1}, log_probs.options());
Tensor neg_log_likelihood = at::empty({batch_size}, log_probs.options());
log_alpha.fill_(neginf);
constexpr int max_threads = std::is_same<scalar_t, float>::value ? 1024 : 896; // we need 72 or so 32 bit registers for double
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
{
int threads_target = max_threads;
while (threads_target / 2 >= max_target_length && threads_target > 1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid(
std::max<int>(
(max_target_length + threads_target - 1) / threads_target, 1),
(batch_size + threads_batch - 1) / threads_batch,
1);
ctc_loss_collect_log_realvalues_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(log_realvalues.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(),
realval.data_ptr<scalar_t>(), num_realvals,
targets_realval.data_ptr<scalar_t>(),
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2),
realval.stride(0), realval.stride(1), realval.stride(2),
targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride,
batch_size, num_labels, sigma, BLANK, BLANK_1);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Very likely, we could be more clever here, e.g. learning (or genralizing and reusing) from SoftMax.cu...
int threads_target = max_threads;
while (threads_target / 2 >= 2*max_target_length+1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
// ctc_loss_log_alpha_gpu_kernel_phase1<scalar_t, target_t><<<grid, block, 0, stream>>>(
// log_alpha.data_ptr<scalar_t>(),
// log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1),
// targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length,
// log_realvalues.data_ptr<scalar_t>(),
// neg_log_likelihood.data_ptr<scalar_t>(),
// log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
// log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
// log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2),
// tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride,
// batch_size, sigma, BLANK, BLANK_1);
// C10_CUDA_KERNEL_LAUNCH_CHECK();
ctc_loss_log_alpha_gpu_kernel_phase2<scalar_t, target_t><<<grid, block, 0, stream>>>(
log_alpha.data_ptr<scalar_t>(),
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length,
log_realvalues.data_ptr<scalar_t>(),
neg_log_likelihood.data_ptr<scalar_t>(),
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride,
batch_size, sigma, BLANK, BLANK_1);
C10_CUDA_KERNEL_LAUNCH_CHECK();
// ctc_loss_log_alpha_gpu_kernel_phase3<scalar_t, target_t><<<grid, block, 0, stream>>>(
// log_alpha.data_ptr<scalar_t>(),
// log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1),
// targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length,
// realval.data_ptr<scalar_t>(), num_realvals,
// targets_realval.data_ptr<scalar_t>(),
// neg_log_likelihood.data_ptr<scalar_t>(),
// log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
// log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
// realval.stride(0), realval.stride(1), realval.stride(2),
// targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2),
// tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride,
// batch_size, sigma, BLANK, BLANK_1);
// C10_CUDA_KERNEL_LAUNCH_CHECK();
return std::make_tuple(neg_log_likelihood, log_alpha);
}
// The second (backward) half of the forward-backward algorithm, equations (10) and (11). This parallels the
// alpha kernel above. (It might make sense to do this calculation in the alpha kernel.)
template<typename scalar_t, typename target_t>
__global__ void
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
ctc_loss_backward_log_beta_gpu_kernel(scalar_t* __restrict__ log_beta_data,
const scalar_t*log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
const scalar_t* __restrict__ log_realvalues_data,
int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride,
int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride,
const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
int64_t batch_size, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) {
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
if (b >= batch_size)
return;
int64_t input_length = input_lengths[b];
int64_t target_length = target_lengths[b];
int64_t lp_batch_offset = b*lp_batch_stride;
int64_t lb_batch_offset = b*lb_batch_stride;
int64_t lr_batch_offset = b*lr_batch_stride;
int64_t tg_batch_offset = tg_batch_offsets[b];
// "first" row, the beta initiaization before eq (10) (t=target_length - differs per batch)
for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) {
int64_t s = threadIdx.x + block_s;
scalar_t lb;
if (s == 2*target_length) {
lb = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * BLANK];
} else if (s == 2 * target_length - 1) { // false for target_length == 0
int64_t current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
scalar_t cur_logprob = log_probs_data[lp_batch_offset + (input_length-1) * lp_input_stride + lp_char_stride * current_target_prime];
lb = cur_logprob + log_realvalues_data[lr_batch_offset + lr_input_stride * (input_length - 1) + lr_target_stride * (target_length - 1)];
} else {
lb = neginf;
}
if (s < 2*max_target_length+1) {
log_beta_data[lb_batch_offset + (input_length-1) * lb_input_stride + lb_target_stride * s] = lb;
}
}
// go backward in s
for (int64_t block_s = 2*max_target_length - (2*max_target_length % blockDim.x); block_s >= 0; block_s -= blockDim.x) {
int64_t s = threadIdx.x + block_s;
int64_t current_target_prime;
bool have_three;
if (s < 2 * target_length + 1 && target_length > 0) {
current_target_prime = get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s,
BLANK);
have_three =
((s < 2 * target_length - 1) &&
(get_target_prime(
targets_data,
tg_batch_offset,
tg_target_stride,
s + 2,
BLANK) != current_target_prime));
} else {
current_target_prime = BLANK;
have_three = false;
}
// now go backward in t. Note that we need to skip the last timestep that we did above.
for (int64_t t=max_input_length-2; t>=0; t--) {
__syncthreads(); // on cuda 9 we might use partial synchronization of only the threads within the same batch item
if ((t < input_length - 1) && (s < 2 * target_length + 1)) {
scalar_t cur_logprob = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];
cur_logprob += (s % 2 == 1) ? log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * (s / 2)] : 0;
scalar_t lb1 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * s];
scalar_t lbmax = lb1;
scalar_t lb2, lb3;
if (s < 2*target_length) {
lb2 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * (s+1)];
if (lb2 > lbmax)
lbmax = lb2;
} else {
lb2 = neginf;
}
if (have_three) {
lb3 = log_beta_data[lb_batch_offset + lb_input_stride * (t+1) + lb_target_stride * (s+2)];
if (lb3 > lbmax)
lbmax = lb3;
} else {
lb3 = neginf;
}
if (lbmax == neginf)
lbmax = 0;
scalar_t lb = std::log(std::exp(lb1-lbmax)+std::exp(lb2-lbmax)+std::exp(lb3-lbmax))+lbmax
+ cur_logprob;
log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] = lb;
}
else if (
(s < 2 * max_target_length + 1) &&
(((target_length == 0) && (s > 0)) || (s >= 2 * target_length + 1) ||
(t >= input_length))) {
log_beta_data
[lb_batch_offset + lb_input_stride * t + lb_target_stride * s] =
neginf;
}
}
}
}
// Subtrahend of equation (16), restricted to the *nonblank* characters.
// On entry gradient_data is expected to hold the probabilities (minuend);
// this kernel adds the negative nonblank contribution to it in place, so on
// exit it holds the gradient. Because the update happens in place, the
// computation is carried out in linear space rather than in log space. (The
// other implemented variant, ctc_loss_backward_collect_gpu_kernel, works in
// log space; the difference did not seem problematic, at least with unit
// normal distributed test activations.)
// Several threads can hit the same gradient element (one thread per target
// position s, and the same character may occur at several positions), hence
// the atomic add.
// Work is split over the batch (y dimension) and the target position s
// (x dimension).
// The Z of eqn (16) is the same for every t: it is the likelihood, which is why
// the negative log likelihood is added below. The result is also multiplied by
// the incoming gradient, in keeping with the usual autograd style.
// The trick is taken from [2]: for moderate alphabet sizes a log-space
// calculation (with an atomic log-add) performs similarly, but for large
// alphabets the in-place update is a considerable advantage.
// Unlike plain CTC, log_alpha + log_beta at each nonblank position is reduced
// by log_realvalues[b, t, s] before being exponentiated.
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_nonblank_gpu_kernel(scalar_t* __restrict__ gradient_data,
    const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
    const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
    const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
    const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
    const scalar_t* __restrict__ log_realvalues_data,
    const scalar_t* __restrict__ neg_log_likelihood_data,
    int64_t gr_batch_stride, int64_t gr_input_stride, int64_t gr_char_stride,
    int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride,
    int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
    int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
    int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride,
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1, bool zero_infinity) {
  const int64_t batch_idx = threadIdx.y + blockIdx.y * blockDim.y;
  // Indexes the plain target sequence, not the blank-augmented one.
  const int64_t tgt_pos = threadIdx.x + blockIdx.x * blockDim.x;
  if (batch_idx >= batch_size) {
    return;
  }
  const int64_t tgt_len = target_lengths[batch_idx];
  if (tgt_pos >= tgt_len) {
    return;
  }
  const scalar_t nll = neg_log_likelihood_data[batch_idx];
  if (zero_infinity && nll == std::numeric_limits<scalar_t>::infinity()) {
    // With zero_infinity, samples with infinite loss contribute nothing here.
    return;
  }
  const int64_t seq_len = input_lengths[batch_idx];
  const scalar_t grad_scale = grad_out_data[batch_idx * grad_out_batch_stride];
  const int64_t label = targets_data[tg_batch_offsets[batch_idx] + tgt_pos * tg_target_stride];
  // The character's slot in the blank-augmented sequence (odd positions).
  const int64_t prime_pos = 2 * tgt_pos + 1;

  // Per-item base offsets; only the timestep term changes inside the loop.
  const int64_t gr_base = batch_idx * gr_batch_stride + gr_char_stride * label;
  const int64_t lp_base = batch_idx * lp_batch_stride + lp_char_stride * label;
  const int64_t la_base = batch_idx * la_batch_stride + la_target_stride * prime_pos;
  const int64_t lb_base = batch_idx * lb_batch_stride + lb_target_stride * prime_pos;
  const int64_t lr_base = batch_idx * lr_batch_stride + lr_target_stride * tgt_pos;

  for (int64_t t = 0; t < seq_len; ++t) {
    const scalar_t log_prob = log_probs_data[lp_base + t * lp_input_stride];
    const scalar_t log_ab = log_alpha_data[la_base + la_input_stride * t]
                          + log_beta_data[lb_base + lb_input_stride * t];
    const scalar_t log_rv = log_realvalues_data[lr_base + lr_input_stride * t];
    const scalar_t log_ab_over_rv = log_ab - log_rv;
    gpuAtomicAddNoReturn(&gradient_data[gr_base + t * gr_input_stride],
                         -std::exp(log_ab_over_rv + nll - log_prob) * grad_scale);
  }
}
// Gradient of the loss with respect to the real-valued inputs (realval).
// Parallelised over batch (y dimension) and input timestep t (x dimension).
// Each thread writes only gradient_realval_data[b, t, :] and accumulates over
// all nonblank target positions s and all real-value channels i, so a plain
// (non-atomic) += is sufficient.
// gradient_realval_data is indexed with realval's strides (rv_*), so it must
// share realval's memory layout.
// Per (s, i) the contribution is
//   -exp(log_term1 + log_constant_factors + nll) * custom_distance_backward(...) * grad_out
// with log_term1 = log_alpha_beta - lp - 2*log_prod_n and
// log_constant_factors = log_prod_n - custom_distance_forward_log(...).
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_realvalue_gpu_kernel(scalar_t* __restrict__ gradient_realval_data,
    const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
    const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
    const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
    const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
    const scalar_t* __restrict__ realval_data, int64_t num_realval,
    const scalar_t* __restrict__ targets_realval_data,
    const scalar_t* __restrict__ log_realvalues_data,
    const scalar_t* __restrict__ neg_log_likelihood_data,
    int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride,
    int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
    int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
    int64_t rv_batch_stride, int64_t rv_input_stride, int64_t rv_label_stride,
    int64_t rvt_batch_stride, int64_t rvt_input_stride, int64_t rvt_label_stride,
    int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride,
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1) {
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t t = threadIdx.x + blockIdx.x * blockDim.x;
  if ((t >= max_input_length) || (b >= batch_size))
    return;
  // Timesteps at or past this item's input length are padding. log_beta is
  // -inf there, and evaluating the exp(...) terms below would yield 0 at best
  // and NaN when log_prod_n is -inf (-inf - 2*(-inf)). Skip them so their
  // gradient keeps the buffer's initial value (the caller in this file
  // zero-initialises it), matching how the log_probs gradient is zeroed for
  // padded timesteps.
  int64_t input_length = input_lengths[b];
  if (t >= input_length)
    return;
  int64_t target_length = target_lengths[b];
  int64_t lp_batch_offset = b*lp_batch_stride;
  int64_t la_batch_offset = b*la_batch_stride;
  int64_t lb_batch_offset = b*lb_batch_stride;
  int64_t rv_batch_offset = b*rv_batch_stride;
  int64_t rvt_batch_offset = b*rvt_batch_stride;
  int64_t lr_batch_offset = b*lr_batch_stride;
  int64_t tg_batch_offset = tg_batch_offsets[b];
  scalar_t nll = neg_log_likelihood_data[b];
  scalar_t gr = grad_out_data[b * grad_out_batch_stride];
  // Offset of realval[b, t, 0] (also used for the gradient, see header).
  int64_t rv_row = rv_batch_offset + rv_input_stride * t;
  for (int64_t s = 0; s < max_target_length && s < target_length; s++) {
    // Odd positions of the blank-augmented target hold the real characters.
    int64_t current_target_prime = get_target_prime(
        targets_data,
        tg_batch_offset,
        tg_target_stride,
        s * 2 + 1,
        BLANK);
    if (current_target_prime == BLANK || current_target_prime == BLANK_1)
      continue;
    scalar_t lp = log_probs_data[lp_batch_offset + t * lp_input_stride + lp_char_stride * current_target_prime];
    scalar_t log_alpha_beta = (log_alpha_data[la_batch_offset + la_input_stride * t + la_target_stride * (s * 2 + 1)]
                               + log_beta_data[lb_batch_offset + lb_input_stride * t + lb_target_stride * (s * 2 + 1)]);
    scalar_t log_prod_n = log_realvalues_data[lr_batch_offset + lr_input_stride * t + lr_target_stride * s];
    scalar_t log_term1 = log_alpha_beta - lp - 2 * log_prod_n;
    int64_t rvt_row = rvt_batch_offset + rvt_input_stride * s;
    for (int64_t i = 0; i != num_realval; ++i) {
      // Load each operand once; both distance helpers take the same pair.
      const scalar_t target_rv = targets_realval_data[rvt_row + rvt_label_stride * i];
      const scalar_t input_rv = realval_data[rv_row + rv_label_stride * i];
      scalar_t log_constant_factors = log_prod_n - custom_distance_forward_log(
          target_rv,
          input_rv,
          static_cast<scalar_t>(sigma)
          );
      scalar_t grad_dp_dmu = std::exp(log_term1 + log_constant_factors + nll) * custom_distance_backward(
          target_rv,
          input_rv,
          static_cast<scalar_t>(sigma)
          );
      gradient_realval_data[rv_row + rv_label_stride * i] += -grad_dp_dmu * gr;
    }
  }
}
// Naive implementation of equation (16), parallelised over batch (y dimension)
// and input timestep t (x dimension). The caller selects it for small
// problems, where it tends to be faster than the in-place approach of
// ctc_loss_backward_collect_nonblank_gpu_kernel.
// Two phases per (b, t):
//  1. For every position s of the blank-augmented target, log-add
//     log_alpha[t, s] + log_beta[t, s] - log_realvalue into
//     gradient_data[b, t, target'[s]]. The real-value term is only subtracted
//     at odd (character) positions. gradient_data is expected to start at -inf
//     (the first contribution is then stored directly).
//  2. Replace each collected value by (exp(lp) - exp(collected + nll - lp)) * grad_out,
//     or by 0 for timesteps past the input length and, with zero_infinity,
//     for samples whose loss is infinite.
template<typename scalar_t, typename target_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_backward_collect_gpu_kernel(scalar_t* __restrict__ gradient_data,
    const scalar_t* __restrict__ grad_out_data, int64_t grad_out_batch_stride,
    const scalar_t* __restrict__ log_alpha_data, const scalar_t* __restrict__ log_beta_data,
    const scalar_t* log_probs_data, const int64_t* __restrict__ input_lengths, int64_t max_input_length,
    const target_t* __restrict__ targets_data, const int64_t* __restrict__ target_lengths, int64_t max_target_length,
    const scalar_t* __restrict__ log_realvalues_data,
    const scalar_t* __restrict__ neg_log_likelihood_data,
    int64_t gr_batch_stride, int64_t gr_input_stride, int64_t gr_char_stride,
    int64_t lp_batch_stride, int64_t lp_input_stride, int64_t lp_char_stride,
    int64_t la_batch_stride, int64_t la_input_stride, int64_t la_target_stride,
    int64_t lb_batch_stride, int64_t lb_input_stride, int64_t lb_target_stride,
    int64_t lr_batch_stride, int64_t lr_input_stride, int64_t lr_target_stride,
    const int64_t* __restrict__ tg_batch_offsets, int64_t tg_target_stride,
    int64_t batch_size, int64_t num_labels, scalar_t sigma, int64_t BLANK, int64_t BLANK_1, bool zero_infinity) {
  constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
  const int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  const int64_t t = threadIdx.x + blockIdx.x * blockDim.x;
  if (t >= max_input_length || b >= batch_size) {
    return;
  }
  const int64_t input_length = input_lengths[b];
  const int64_t target_length = target_lengths[b];
  const int64_t tg_batch_offset = tg_batch_offsets[b];
  // Offsets of element [b, t, 0] in each tensor.
  const int64_t gr_row = b * gr_batch_stride + t * gr_input_stride;
  const int64_t lp_row = b * lp_batch_stride + t * lp_input_stride;
  const int64_t la_row = b * la_batch_stride + la_input_stride * t;
  const int64_t lb_row = b * lb_batch_stride + lb_input_stride * t;
  const int64_t lr_row = b * lr_batch_stride + lr_input_stride * t;

  // Phase 1: collected[b, t, target'[s]] "log+=" log_alpha[t, s] + log_beta[t, s] - log_realvalue.
  // For target_length == 0 only s == 0 (a blank) is visited.
  const int64_t num_primed = 2 * target_length + 1;
  for (int64_t s = 0; s < 2 * max_target_length + 1; ++s) {
    if (s >= num_primed) {
      continue;
    }
    const int64_t label = get_target_prime(
        targets_data,
        tg_batch_offset,
        tg_target_stride,
        s,
        BLANK);
    const scalar_t log_ab = log_alpha_data[la_row + la_target_stride * s]
                          + log_beta_data[lb_row + lb_target_stride * s];
    const bool character_slot = (s % 2 == 1);
    const scalar_t log_rv = character_slot ? log_realvalues_data[lr_row + lr_target_stride * (s / 2)] : 0;
    const scalar_t contribution = log_ab - log_rv;
    scalar_t& acc = gradient_data[gr_row + gr_char_stride * label];
    if (acc == neginf) {
      acc = contribution;
      continue;
    }
    // Ternary (not std::max) on purpose: keeps the original NaN propagation.
    const scalar_t pivot = (acc > contribution) ? acc : contribution;
    acc = std::log(std::exp(acc - pivot) + std::exp(contribution - pivot)) + pivot;
  }

  // Phase 2: turn the collected log sums into the gradient.
  const scalar_t nll = neg_log_likelihood_data[b];
  const scalar_t gr = grad_out_data[b * grad_out_batch_stride];
  const bool active = t < input_length
      && !(zero_infinity && nll == std::numeric_limits<scalar_t>::infinity());
  for (int64_t c = 0; c < num_labels; ++c) {
    scalar_t& res = gradient_data[gr_row + gr_char_stride * c];
    if (!active) {
      res = 0.;
      continue;
    }
    const scalar_t lp = log_probs_data[lp_row + lp_char_stride * c];
    res = (std::exp(lp) - std::exp(res + nll - lp)) * gr;
  }
}
// Zeroes the gradients at timesteps past each sequence's input length.
// Those positions correspond to padded input elements, so their gradients must
// not be used in any model update.
// The memory layout is given entirely by the three stride arguments; the
// caller in this file passes a (batch, time, label)-ordered tensor's sizes and
// strides.
// One thread per (b, t); each thread clears all num_labels entries of its row.
template<typename scalar_t>
__global__ void
#if defined (__HIP_PLATFORM_HCC__)
C10_LAUNCH_BOUNDS_2((std::is_same<scalar_t, float>::value ? 1024 : 896), 1)
#endif
ctc_loss_zero_padded_gradients(
    scalar_t* __restrict__ gradient_data,      /* addressed via the strides below */
    const int64_t* __restrict__ input_lengths, /* (B, ) layout */
    int64_t gr_batch_stride,
    int64_t gr_timestep_stride,
    int64_t gr_label_stride,
    int64_t batch_size,                        /* B */
    int64_t max_input_length,                  /* T */
    int64_t num_labels                         /* D */
    ) {
  int64_t b = threadIdx.y + blockIdx.y * blockDim.y;
  int64_t t = threadIdx.x + blockIdx.x * blockDim.x;
  if (b >= batch_size || t >= max_input_length) {
    return;
  }
  // Keep the length integral: storing it in a floating-point scalar_t would
  // lose precision above 2^24 when scalar_t is float.
  const int64_t input_length = input_lengths[b];
  if (t < input_length) {
    return;
  }
  scalar_t* row = gradient_data + b * gr_batch_stride + t * gr_timestep_stride;
  for (int64_t l = 0; l < num_labels; l++) {
    row[l * gr_label_stride] = scalar_t(0);
  }
}
// The backward. It essentially computes eq 16 by using the above kernels.
// We don't do a lot of checking as we envision this to be called only when backpropagating through a (well-checked) forward.
template<typename scalar_t, ScalarType target_scalar_type>
std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu_template(
const Tensor& grad_out,
const Tensor& log_probs,
const Tensor& targets,
const Tensor& realval,
const Tensor& targets_realval,
IntArrayRef input_lengths,
IntArrayRef target_lengths,
const Tensor& neg_log_likelihood,
const Tensor& log_alpha,
scalar_t const sigma,
int64_t BLANK,
int64_t BLANK_1,
bool zero_infinity
) {
constexpr scalar_t neginf = -std::numeric_limits<scalar_t>::infinity();
using target_t = typename std::conditional<target_scalar_type == kInt, int, int64_t>::type;
int64_t batch_size = log_probs.size(0);
int64_t num_realvals = realval.size(2);
int64_t num_labels = log_probs.size(2);
int64_t lp_input_stride = log_probs.stride(1);
int64_t lp_char_stride = log_probs.stride(2);
int64_t tg_target_stride;
int64_t max_target_length;
auto tg_batch_offsets = at::empty({batch_size}, TensorOptions(at::CPU(kLong)));
auto tg_batch_offsets_data = tg_batch_offsets.data_ptr<int64_t>();
if (targets.dim() == 1) { // concatenated targets
int64_t pos = 0;
max_target_length = 0;
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = pos;
pos += target_lengths[i];
if (max_target_length < target_lengths[i])
max_target_length = target_lengths[i];
}
tg_target_stride = targets.stride(0);
}
else { // batch x max_target_length
// dim is 2
int64_t tg_batch_stride = targets.stride(0);
for (int64_t i = 0; i < batch_size; i++) {
tg_batch_offsets_data[i] = i * tg_batch_stride;
}
tg_target_stride = targets.stride(1);
max_target_length = log_alpha.size(2)/2; // targets.size(1) might be larger
}
auto target_lengths_t = at::tensor(target_lengths, targets.options().dtype(kLong));
auto input_lengths_t = at::tensor(input_lengths, targets.options().dtype(kLong));
tg_batch_offsets = tg_batch_offsets.cuda();
Tensor log_realvalues = at::zeros({batch_size, log_probs.size(1), std::max(max_target_length, int64_t(1))}, log_alpha.options());
Tensor log_beta = at::empty_like(log_alpha, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
log_beta.fill_(neginf);
Tensor grad = at::full_like(log_probs, neginf, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for log(sum (alpha beta))
Tensor grad_realval = at::full_like(realval, 0, LEGACY_CONTIGUOUS_MEMORY_FORMAT); // initialization for sum (d realvalue)
// As above, there may be better configurations to use.
constexpr int max_threads = std::is_same<scalar_t, float>::value ? 1024 : 896; // we need 72 or so 32 bit registers for double
int threads_target = max_threads;
while (threads_target / 2 >= 2*max_target_length+1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
{
int threads_target = max_threads;
while (threads_target / 2 >= max_target_length && threads_target > 1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid(
std::max<int>(
(max_target_length + threads_target - 1) / threads_target, 1),
(batch_size + threads_batch - 1) / threads_batch,
1);
ctc_loss_collect_log_realvalues_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(log_realvalues.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(),
realval.data_ptr<scalar_t>(), num_realvals,
targets_realval.data_ptr<scalar_t>(),
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2),
realval.stride(0), realval.stride(1), realval.stride(2),
targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride,
batch_size, num_labels, sigma, BLANK, BLANK_1);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
{
dim3 block(threads_target, threads_batch);
dim3 grid((2*max_target_length+1 + threads_target-1)/threads_target, (batch_size+threads_batch-1)/threads_batch);
ctc_loss_backward_log_beta_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(log_beta.data_ptr<scalar_t>(),
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length,
log_realvalues.data_ptr<scalar_t>(),
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride,
batch_size, sigma, BLANK, BLANK_1);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
// Very crude heuristic for what is a small problem., based on linearly regressing problem dimensions on
// the (capped) difference of timings.
// Note that for OK problems target length <= input length, so we
// only consider input length.
bool is_large = (2*log_probs.size(1)+(24*batch_size)/10+(2*num_labels)/10) > 450;
if (is_large) { // large alphabet, large batch
// this computes the probs, minuend in (16)
at::exp_out(grad, log_probs);
// now we compute the subtrahend for the blanks. It is a straightforward reduction because we know that
// blanks are in every other position.
// maybe we should kernelize this, too.
auto grad_blank = grad.narrow(2, BLANK, 1);
grad_blank -= (at::logsumexp(log_alpha.as_strided({batch_size, log_alpha.size(1), max_target_length+1},
{log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2)*2})
+ log_beta.as_strided({batch_size, log_beta.size(1), max_target_length+1},
{log_beta.stride(0), log_beta.stride(1), log_beta.stride(2)*2}),
2, true)
.add_(neg_log_likelihood.view({batch_size, 1, 1}))
.sub_(log_probs.narrow(2, BLANK, 1))
.exp_()
);
// scale by output gradient (blanks and first summand of non-blanks)
grad *= grad_out.view({batch_size, 1, 1});
if (zero_infinity) {
grad = at::where(neg_log_likelihood.view({batch_size, 1, 1}) == Scalar(std::numeric_limits<scalar_t>::infinity()), at::zeros({}, grad.options()), grad);
}
// For the non-blank characters, we use a kernel to compute the subtrahend.
// Again we might configure block and grid in a better way.
int threads_target = max_threads;
while (threads_target / 2 >= max_target_length && threads_target > 1) {
threads_target /= 2;
}
int threads_batch = std::min(max_threads / threads_target, (int) batch_size);
dim3 block(threads_target, threads_batch);
dim3 grid(
std::max<int>(
(max_target_length + threads_target - 1) / threads_target, 1),
(batch_size + threads_batch - 1) / threads_batch,
1);
ctc_loss_backward_collect_nonblank_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(grad.data_ptr<scalar_t>(),
grad_out.data_ptr<scalar_t>(), grad_out.stride(0),
log_alpha.data_ptr<scalar_t>(), log_beta.data_ptr<scalar_t>(),
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length,
log_realvalues.data_ptr<scalar_t>(),
neg_log_likelihood.data_ptr<scalar_t>(),
grad.stride(0), grad.stride(1), grad.stride(2),
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride,
batch_size, num_labels, sigma, BLANK, BLANK_1, zero_infinity);
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else { // small problem, use naive algorithm
// Still no block/grid configuration guru...
int threads_input = max_threads;
while (threads_input / 2 >= log_probs.size(1) && threads_input > 1) {
threads_input /= 2;
}
threads_batch = std::min(max_threads / threads_input, (int) batch_size);
dim3 block(threads_input, threads_batch);
dim3 grid((log_probs.size(1) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);
ctc_loss_backward_collect_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(grad.data_ptr<scalar_t>(),
grad_out.data_ptr<scalar_t>(), grad_out.stride(0),
log_alpha.data_ptr<scalar_t>(), log_beta.data_ptr<scalar_t>(),
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length,
log_realvalues.data_ptr<scalar_t>(),
neg_log_likelihood.data_ptr<scalar_t>(),
grad.stride(0), grad.stride(1), grad.stride(2),
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride,
batch_size, num_labels, sigma, BLANK, BLANK_1, zero_infinity);
C10_CUDA_KERNEL_LAUNCH_CHECK(); // catch launch errors
}
// collect real value gradients
{
int threads_input = max_threads;
while (threads_input / 2 >= log_probs.size(1) && threads_input > 1) {
threads_input /= 2;
}
threads_input = 512;
threads_batch = std::min(max_threads / threads_input, (int) batch_size);
threads_batch = 1;
//threads_batch = threads_batch >> 4;
//std::cout << "threads_input=" << threads_input << ",threads_batch=" << threads_batch << "\n";
dim3 block(threads_input, threads_batch);
dim3 grid((log_probs.size(1) + threads_input-1)/threads_input, (batch_size+threads_batch-1)/threads_batch);
ctc_loss_backward_collect_realvalue_gpu_kernel<scalar_t, target_t><<<grid, block, 0, stream>>>
(grad_realval.data_ptr<scalar_t>(),
grad_out.data_ptr<scalar_t>(), grad_out.stride(0),
log_alpha.data_ptr<scalar_t>(), log_beta.data_ptr<scalar_t>(),
log_probs.data_ptr<scalar_t>(), input_lengths_t.data_ptr<int64_t>(), log_probs.size(1),
targets.data_ptr<target_t>(), target_lengths_t.data_ptr<int64_t>(), max_target_length,
realval.data_ptr<scalar_t>(), num_realvals,
targets_realval.data_ptr<scalar_t>(),
log_realvalues.data_ptr<scalar_t>(),
neg_log_likelihood.data_ptr<scalar_t>(),
log_probs.stride(0), log_probs.stride(1), log_probs.stride(2),
log_alpha.stride(0), log_alpha.stride(1), log_alpha.stride(2),
log_beta.stride(0), log_beta.stride(1), log_beta.stride(2),
realval.stride(0), realval.stride(1), realval.stride(2),
targets_realval.stride(0), targets_realval.stride(1), targets_realval.stride(2),
log_realvalues.stride(0), log_realvalues.stride(1), log_realvalues.stride(2),
tg_batch_offsets.data_ptr<int64_t>(), tg_target_stride,
batch_size, num_labels, sigma, BLANK, BLANK_1);
C10_CUDA_KERNEL_LAUNCH_CHECK(); // catch launch errors
}
// zero those invalid gradient elements due to padding
{
int threads_input = max_threads;
while (threads_input / 2 >= log_probs.size(1)) {
threads_input /= 2;
}
threads_batch = std::min(max_threads / threads_input, (int) batch_size);
dim3 block(threads_input, threads_batch);
dim3 grid(
(log_probs.size(1) + threads_input-1)/threads_input,
(batch_size+threads_batch-1)/threads_batch);
ctc_loss_zero_padded_gradients<scalar_t><<<grid, block, 0, stream>>>(
grad.data_ptr<scalar_t>(),
input_lengths_t.data_ptr<int64_t>(),
grad.stride(0),
grad.stride(1),
grad.stride(2),
grad.size(0),
grad.size(1),
grad.size(2)
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
return std::make_tuple(grad, grad_realval);
}
// Forward entry point. Dispatches on the floating dtype of log_probs and on the
// dtype of targets: kLong targets use the int64 instantiation, anything else
// goes to the kInt instantiation.
std::tuple<Tensor, Tensor> custom_ctc_loss_gpu(
  const Tensor& log_probs,
  const Tensor& targets,
  const Tensor& realval,
  const Tensor& targets_realval,
  IntArrayRef input_lengths,
  IntArrayRef target_lengths,
  double const sigma,
  int64_t BLANK,
  int64_t BLANK_1
  ) {
  const bool long_targets = (targets.scalar_type() == kLong);
  return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "custom_ctc_loss_cuda", [&] {
    if (!long_targets) {
      return custom_ctc_loss_gpu_template<scalar_t, kInt>(
          log_probs, targets, realval, targets_realval, input_lengths, target_lengths,
          static_cast<scalar_t>(sigma), BLANK, BLANK_1);
    }
    return custom_ctc_loss_gpu_template<scalar_t, kLong>(
        log_probs, targets, realval, targets_realval, input_lengths, target_lengths,
        static_cast<scalar_t>(sigma), BLANK, BLANK_1);
  });
}
std::tuple<Tensor, Tensor> custom_ctc_loss_backward_gpu(
const Tensor& grad,
const Tensor& log_probs,
const Tensor& targets,
const Tensor& realval,
const Tensor& targets_realval,
IntArrayRef input_lengths,
IntArrayRef target_lengths,
const Tensor& neg_log_likelihood,
const Tensor& log_alpha,
double const sigma,
int64_t BLANK,
int64_t BLANK_1,
bool zero_infinity
) {
// See Note [Writing Nondeterministic Operations]
// Nondeterministic because of atomicAdd usage
globalContext().alertNotDeterministic("ctc_loss_backward_gpu");
return AT_DISPATCH_FLOATING_TYPES(log_probs.scalar_type(), "custom_ctc_loss_backward_cuda", [&] {
if (targets.scalar_type() == kLong) {
return custom_ctc_loss_backward_gpu_template<scalar_t, kLong>(grad, log_probs, targets, realval, targets_realval, input_lengths, target_lengths, neg_log_likelihood, log_alpha, static_cast<scalar_t>(sigma), BLANK, BLANK_1, zero_infinity);
} else {
return custom_ctc_loss_backward_gpu_template<scalar_t, kInt>(grad, log_probs, targets, realval, targets_realval, input_lengths, target_lengths, neg_log_likelihood, log_alpha, static_cast<scalar_t>(sigma), BLANK, BLANK_1, zero_infinity);
}
});
}
|