#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::BFloat16* A,
    int lda,
    at::BFloat16* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16BF,
      lda,
      B,
      CUDA_R_16BF,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    at::Half* A,
    int lda,
    at::Half* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_16F,
      lda,
      B,
      CUDA_R_16F,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa,
    cublasOperation_t transb,
    int m,
    int n,
    int k,
    const float* alpha,
    float* A,
    int lda,
    float* B,
    int ldb,
    const float* beta,
    float* C,
    int ldc) {
  return cublasGemmEx(
      handle,
      transa,
      transb,
      m,
      n,
      k,
      alpha,
      A,
      CUDA_R_32F,
      lda,
      B,
      CUDA_R_32F,
      ldb,
      beta,
      C,
      CUDA_R_32F,
      ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// Weight-gradient GEMM with FP32 accumulation: the result is added into the
// existing FP32 d_weight buffer (alpha = beta = 1.0), while input and d_output
// may be FP16, BF16, or FP32 depending on the instantiation.
template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream;
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta = 1.0;
  int status = 1;

  status = gemmex_wrapper(
      handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      in_dim,
      out_dim,
      hidden_dim,
      &alpha,
      input,
      in_dim,
      d_output,
      out_dim,
      &beta,
      d_weight,
      in_dim);
  return status;
}

// Explicit instantiations for the supported input dtypes.
template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
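
// ---------------------------------------------------------------------------
// A minimal sketch of how this kernel could be exposed to Python through a
// PyTorch C++ extension. The binding name `wgrad_gemm_accum_fp32`, the tensor
// layout assumptions, and the dispatch below are illustrative assumptions, not
// part of the file above; only wgrad_gemm_accum_fp32_cuda is defined there.
// Assumed layout (consistent with the GEMM dimensions used above): input is
// [tokens, in_dim], d_output is [tokens, out_dim], and d_weight is an FP32
// [out_dim, in_dim] tensor that the GEMM accumulates into.
// ---------------------------------------------------------------------------
#include <torch/extension.h>

// Declaration of the kernel defined above (repeated here so the sketch is
// readable on its own).
template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight,
                               int in_dim, int hidden_dim, int out_dim);

void wgrad_gemm_accum_fp32(at::Tensor input, at::Tensor d_output, at::Tensor d_weight) {
  TORCH_CHECK(d_weight.scalar_type() == at::kFloat, "d_weight must be FP32");
  TORCH_CHECK(input.scalar_type() == d_output.scalar_type(),
              "input and d_output must share a dtype");

  // hidden_dim is the reduction dimension (rows shared by input and d_output).
  const int hidden_dim = input.size(0);
  const int in_dim = input.size(1);
  const int out_dim = d_weight.size(0);

  if (input.scalar_type() == at::kHalf) {
    wgrad_gemm_accum_fp32_cuda<at::Half>(
        input.data_ptr<at::Half>(), d_output.data_ptr<at::Half>(),
        d_weight.data_ptr<float>(), in_dim, hidden_dim, out_dim);
  } else if (input.scalar_type() == at::kBFloat16) {
    wgrad_gemm_accum_fp32_cuda<at::BFloat16>(
        input.data_ptr<at::BFloat16>(), d_output.data_ptr<at::BFloat16>(),
        d_weight.data_ptr<float>(), in_dim, hidden_dim, out_dim);
  } else {
    wgrad_gemm_accum_fp32_cuda<float>(
        input.data_ptr<float>(), d_output.data_ptr<float>(),
        d_weight.data_ptr<float>(), in_dim, hidden_dim, out_dim);
  }
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32,
        "Weight-gradient GEMM with FP32 accumulation (CUDA)");
}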