| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| #include <ATen/core/Tensor.h> |
| #include <ATen/AccumulateType.h> |
| #include <ATen/Dispatch.h> |
| #include <ATen/TensorUtils.h> |
| #include <ATen/cuda/CUDAContext.h> |
| #include <ATen/cuda/CUDAApplyUtils.cuh> |
| #include <c10/macros/Macros.h> |
| |
| #ifndef AT_PER_OPERATOR_HEADERS |
| #include <ATen/Functions.h> |
| #include <ATen/NativeFunctions.h> |
| #else |
| #include <ATen/ops/empty.h> |
| #include <ATen/ops/empty_like.h> |
| #include <ATen/ops/_thnn_fused_lstm_cell_native.h> |
| #include <ATen/ops/_thnn_fused_lstm_cell_backward_impl_native.h> |
| #include <ATen/ops/_thnn_fused_gru_cell_native.h> |
| #include <ATen/ops/_thnn_fused_gru_cell_backward_native.h> |
| #endif |
| |
| namespace at::native { |
| |
| namespace { |
| |
| using at::cuda::detail::TensorInfo; |
| using at::cuda::detail::getTensorInfo; |
| using at::cuda::detail::IndexToOffset; |
| using at::cuda::detail::canUse32BitIndexMath; |
| |
| // Factor will be 3 for GRU and 4 for LSTM: the gate tensors are (batch, factor * hidden_size), |
| // while prev_hidden is (batch, hidden_size). |
| void checkSizes(CheckedFrom c, |
| const TensorArg& input_gates, const TensorArg& hidden_gates, |
| const TensorArg& input_bias, const TensorArg& hidden_bias, |
| int64_t factor, const TensorArg& prev_hidden) { |
| checkDim(c, input_gates, 2); |
| checkSameSize(c, input_gates, hidden_gates); |
| int64_t gates_size = input_gates->size(1); |
| |
| if (input_bias->defined()) { |
| checkDim(c, input_bias, 1); |
| checkNumel(c, input_bias, gates_size); |
| checkSameSize(c, input_bias, hidden_bias); |
| } |
| |
| checkDim(c, prev_hidden, 2); |
| checkNumel(c, prev_hidden, input_gates->size(0) * gates_size / factor); |
| |
| checkAllSameGPU(c, {input_gates, hidden_gates, input_bias, hidden_bias, prev_hidden}); |
| } |
| |
| bool allContiguous(at::TensorList tensors) { |
| return std::all_of(tensors.begin(), tensors.end(), |
| [](const at::Tensor& t) { return !t.defined() || t.is_contiguous(); }); |
| } |
| |
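| // Computes the launch configuration shared by the pointwise cell kernels below. |
| // The kernels use grid-stride loops, so the grid returned here does not need to |
| // cover one thread per element. |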
| void getLaunchConfig(dim3* block, dim3* grid, int64_t numel) { |
| c10::DeviceIndex curDevice = -1; |
| c10::cuda::GetDevice(&curDevice); |
| *block = cuda::getApplyBlock(); |
| TORCH_INTERNAL_ASSERT(cuda::getApplyGrid(numel, *grid, curDevice), |
| "Could not get grid size for pointwise apply."); |
| } |
| |
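| // Returns an empty TensorInfo (null data pointer) for undefined tensors, so the |
| // kernels can detect missing optional biases / gradients via `data != nullptr`. |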
| template<typename T, typename T2> |
| TensorInfo<T, T2> tryGetTensorInfo(const at::Tensor& t) { |
| return t.defined() ? getTensorInfo<T, T2>(t) : TensorInfo<T, T2>{}; |
| } |
| |
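| // Collapses the dims of each TensorInfo; on the all-contiguous fast path this |
| // reduces every tensor to a single dimension, matching the indexing_kind == 1 |
| // IndexToOffset specialization used below. |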
| void collapseDims() {} |
| template<typename T, typename T2, typename... Args> |
| void collapseDims(TensorInfo<T, T2>& info, Args&... infos) { |
| info.collapseDims(); |
| collapseDims(infos...); |
| } |
| |
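| // Maps a linear element index to a memory offset via the TensorInfo strides. |
| // indexing_kind is the IndexToOffset Dims parameter: 1 after collapseDims on the |
| // contiguous fast path, 2 for strided 2-D tensors. |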
| #define DEVICE_LINEAR_GET(D_TENSOR, INDEX) \ |
| D_TENSOR.data[IndexToOffset<scalar_t, index_type, indexing_kind>::get(INDEX, D_TENSOR)] |
| |
| // Biases are always 1D |
| #define DEVICE_BIAS_GET(D_TENSOR, INDEX) \ |
| D_TENSOR.data[IndexToOffset<scalar_t, index_type, 1>::get(INDEX, D_TENSOR)] |
| |
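| // Convert between the storage scalar type (possibly half/bfloat16) and the |
| // accumulation type used for the arithmetic. |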
| #define H2F(input) static_cast<accscalar_t>(input) |
| #define F2H(input) static_cast<scalar_t>(input) |
| |
| template<typename T> |
| __device__ __forceinline__ |
| T sigmoid(T in) { |
| T one = static_cast<T>(1.0); |
| return one / (one + ::exp(-in)); |
| } |
| |
| namespace kernel { |
| |
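| // Fused pointwise LSTM cell step. `input` and `hidden` hold the precomputed gate |
| // pre-activations (the caller has already applied the weight matrices), laid out |
| // per row as [i, f, g, o] with 4 * hsz columns. With optional biases b1 and b2: |
| //   i  = sigmoid(input_i + hidden_i + b1_i + b2_i) |
| //   f  = sigmoid(input_f + hidden_f + b1_f + b2_f) |
| //   g  = tanh(input_g + hidden_g + b1_g + b2_g) |
| //   o  = sigmoid(input_o + hidden_o + b1_o + b2_o) |
| //   cy = f * cx + i * g |
| //   hy = o * tanh(cy) |
| // The activated gates are written to `workspace` for the backward pass. |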
| template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind> |
| #if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) |
| C10_LAUNCH_BOUNDS_2(512, 4) |
| #endif |
| __global__ void lstm_cell_forward( |
| TensorInfo<scalar_t, index_type> input, |
| TensorInfo<scalar_t, index_type> hidden, |
| TensorInfo<scalar_t, index_type> bias1, |
| TensorInfo<scalar_t, index_type> bias2, |
| TensorInfo<scalar_t, index_type> _cx, |
| TensorInfo<scalar_t, index_type> _hy, |
| TensorInfo<scalar_t, index_type> _cy, |
| TensorInfo<scalar_t, index_type> workspace, |
| index_type hsz, |
| index_type totalElements) { |
| bool has_bias = bias1.data != nullptr; |
| for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x; |
| linearIndex < totalElements; |
| linearIndex += gridDim.x * blockDim.x) { |
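| // Each row of the gate tensors has 4 * hsz columns laid out as [i, f, g, o]; |
| // `offset` addresses this element's input-gate entry, and the remaining gates |
| // live at +1*hsz, +2*hsz and +3*hsz. |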
| index_type offset = (linearIndex/hsz)*4*hsz+linearIndex%hsz; |
| |
| scalar_t iig = DEVICE_LINEAR_GET(input, offset+0*hsz); |
| scalar_t ifg = DEVICE_LINEAR_GET(input, offset+1*hsz); |
| scalar_t icg = DEVICE_LINEAR_GET(input, offset+2*hsz); |
| scalar_t iog = DEVICE_LINEAR_GET(input, offset+3*hsz); |
| |
| scalar_t hig = DEVICE_LINEAR_GET(hidden, offset+0*hsz); |
| scalar_t hfg = DEVICE_LINEAR_GET(hidden, offset+1*hsz); |
| scalar_t hcg = DEVICE_LINEAR_GET(hidden, offset+2*hsz); |
| scalar_t hog = DEVICE_LINEAR_GET(hidden, offset+3*hsz); |
| |
| scalar_t* wig = &DEVICE_LINEAR_GET(workspace, offset+0*hsz); |
| scalar_t* wfg = &DEVICE_LINEAR_GET(workspace, offset+1*hsz); |
| scalar_t* wcg = &DEVICE_LINEAR_GET(workspace, offset+2*hsz); |
| scalar_t* wog = &DEVICE_LINEAR_GET(workspace, offset+3*hsz); |
| |
| scalar_t cx = DEVICE_LINEAR_GET(_cx, linearIndex); |
| |
| scalar_t* hy = &DEVICE_LINEAR_GET(_hy, linearIndex); |
| scalar_t* cy = &DEVICE_LINEAR_GET(_cy, linearIndex); |
| |
| scalar_t b1i, b1f, b1c, b1o; |
| scalar_t b2i, b2f, b2c, b2o; |
| |
| if (has_bias) { |
| b1i = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 0 * hsz); |
| b1f = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 1 * hsz); |
| b1c = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 2 * hsz); |
| b1o = DEVICE_BIAS_GET(bias1, linearIndex % hsz + 3 * hsz); |
| |
| b2i = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 0 * hsz); |
| b2f = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 1 * hsz); |
| b2c = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 2 * hsz); |
| b2o = DEVICE_BIAS_GET(bias2, linearIndex % hsz + 3 * hsz); |
| } else { |
| // No biases were provided: zero-initialize. F2H converts to the storage type, |
| // which also covers half and bfloat16. |
| b1i = F2H(0.0); b1f = F2H(0.0); b1c = F2H(0.0); b1o = F2H(0.0); |
| b2i = F2H(0.0); b2f = F2H(0.0); b2c = F2H(0.0); b2o = F2H(0.0); |
| } |
| |
| accscalar_t ig, fg, cg, og; |
| accscalar_t f_hy, f_cy; |
| |
| ig = sigmoid(H2F(iig) + H2F(hig) + H2F(b1i) + H2F(b2i)); |
| fg = sigmoid(H2F(ifg) + H2F(hfg) + H2F(b1f) + H2F(b2f)); |
| cg = ::tanh(H2F(icg) + H2F(hcg) + H2F(b1c) + H2F(b2c)); |
| og = sigmoid(H2F(iog) + H2F(hog) + H2F(b1o) + H2F(b2o)); |
| |
| f_cy = (fg * H2F(cx)) + (ig * cg); |
| f_hy = og * ::tanh(f_cy); |
| |
| *hy = F2H(f_hy); |
| *cy = F2H(f_cy); |
| |
| // Save the activated gates for the backward pass. |
| // cy and cx are also needed, but the caller can save those directly. |
| *wig = F2H(ig); |
| *wfg = F2H(fg); |
| *wcg = F2H(cg); |
| *wog = F2H(og); |
| } |
| } |
| |
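| // Backward of the pointwise LSTM step, using the activated gates i, f, g, o saved |
| // in `storage` by the forward kernel. With d_hy = gradoutput and d_cy_in = |
| // gradoutputcell (either may be absent): |
| //   d_c  = d_hy * o * (1 - tanh(cy)^2) + d_cy_in |
| //   d_i  = d_c * g * i * (1 - i) |
| //   d_f  = d_c * cx * f * (1 - f) |
| //   d_g  = d_c * i * (1 - g^2) |
| //   d_o  = d_hy * tanh(cy) * o * (1 - o) |
| //   d_cx = d_c * f |
| // The gate gradients go to gradInGates; the gradient w.r.t. the previous cell |
| // state goes to gradInputCx. The matmul/bias backward is handled by the caller. |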
| template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind> |
| #if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) |
| C10_LAUNCH_BOUNDS_2(512, 4) |
| #endif |
| __global__ void lstm_cell_backward( |
| TensorInfo<scalar_t, index_type> storage, |
| TensorInfo<scalar_t, index_type> gradInGates, |
| TensorInfo<scalar_t, index_type> _cx, |
| TensorInfo<scalar_t, index_type> _cy, |
| TensorInfo<scalar_t, index_type> gradoutput, |
| TensorInfo<scalar_t, index_type> gradoutputcell, |
| TensorInfo<scalar_t, index_type> gradInputCx, |
| index_type hsz, |
| index_type totalElements) { |
| bool has_gradoutput = gradoutput.data != nullptr; |
| bool has_gradoutputcell = gradoutputcell.data != nullptr; |
| for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x; |
| linearIndex < totalElements; |
| linearIndex += gridDim.x * blockDim.x) { |
| index_type offset = (linearIndex/hsz)*4*hsz+linearIndex%hsz; |
| |
| scalar_t ig = DEVICE_LINEAR_GET(storage, offset+0*hsz); |
| scalar_t fg = DEVICE_LINEAR_GET(storage, offset+1*hsz); |
| scalar_t cg = DEVICE_LINEAR_GET(storage, offset+2*hsz); |
| scalar_t og = DEVICE_LINEAR_GET(storage, offset+3*hsz); |
| |
| scalar_t* ih = &DEVICE_LINEAR_GET(gradInGates, offset+0*hsz); |
| scalar_t* fh = &DEVICE_LINEAR_GET(gradInGates, offset+1*hsz); |
| scalar_t* ch = &DEVICE_LINEAR_GET(gradInGates, offset+2*hsz); |
| scalar_t* oh = &DEVICE_LINEAR_GET(gradInGates, offset+3*hsz); |
| |
| scalar_t cx = DEVICE_LINEAR_GET(_cx, linearIndex); |
| scalar_t cy = DEVICE_LINEAR_GET(_cy, linearIndex); |
| |
| // gradInputCx receives the gradient w.r.t. the previous cell state. |
| scalar_t* gi = &DEVICE_LINEAR_GET(gradInputCx, linearIndex); |
| |
| accscalar_t go = has_gradoutput ? H2F(DEVICE_LINEAR_GET(gradoutput, linearIndex)) : 0.f; |
| accscalar_t goc = has_gradoutputcell ? H2F(DEVICE_LINEAR_GET(gradoutputcell, linearIndex)) : 0.f; |
| |
| accscalar_t gcx = ::tanh(H2F(cy)); |
| |
| accscalar_t gog = go * gcx; |
| gcx = go * H2F(og) * (1 - gcx*gcx) + goc; |
| |
| accscalar_t gig = gcx * H2F(cg); |
| accscalar_t gfg = gcx * H2F(cx); |
| accscalar_t gcg = gcx * H2F(ig); |
| |
| gcx = gcx * H2F(fg); |
| |
| gig = gig * (1-H2F(ig)) * H2F(ig); |
| gfg = gfg * (1-H2F(fg)) * H2F(fg); |
| gcg = gcg * (1-H2F(cg)*H2F(cg)); |
| gog = gog * (1-H2F(og)) * H2F(og); |
| |
| *ih = F2H(gig); |
| *fh = F2H(gfg); |
| *ch = F2H(gcg); |
| *oh = F2H(gog); |
| |
| *gi = F2H(gcx); |
| } |
| } |
| |
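| // Fused pointwise GRU cell step. `Input` and `Hidden` hold the precomputed gate |
| // pre-activations, laid out per row as [r, i, n] with 3 * hsz columns, where i |
| // is the update gate (z in the usual GRU notation). With optional biases b1, b2: |
| //   r  = sigmoid(Input_r + Hidden_r + b1_r + b2_r) |
| //   i  = sigmoid(Input_i + Hidden_i + b1_i + b2_i) |
| //   n  = tanh(Input_n + b1_n + r * (Hidden_n + b2_n)) |
| //   hy = n + i * (hx - n) |
| // `storage` keeps r, i, n, hx and (Hidden_n + b2_n), i.e. five values per hidden |
| // unit, so the backward kernel can recompute all gate gradients. |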
| template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind> |
| #if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) |
| C10_LAUNCH_BOUNDS_2(512, 4) |
| #endif |
| __global__ void gru_cell_forward( |
| TensorInfo<scalar_t, index_type> Input, |
| TensorInfo<scalar_t, index_type> Hidden, |
| TensorInfo<scalar_t, index_type> Bias1, |
| TensorInfo<scalar_t, index_type> Bias2, |
| TensorInfo<scalar_t, index_type> _hx, |
| TensorInfo<scalar_t, index_type> _hy, |
| TensorInfo<scalar_t, index_type> storage, |
| index_type hsz, |
| index_type totalElements) { |
| bool has_bias = Bias1.data != nullptr; |
| for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x; |
| linearIndex < totalElements; |
| linearIndex += gridDim.x * blockDim.x) { |
| index_type offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz; |
| |
| scalar_t ir = DEVICE_LINEAR_GET(Input, offset+0*hsz); |
| scalar_t ii = DEVICE_LINEAR_GET(Input, offset+1*hsz); |
| scalar_t in = DEVICE_LINEAR_GET(Input, offset+2*hsz); |
| scalar_t hr = DEVICE_LINEAR_GET(Hidden,offset+0*hsz); |
| scalar_t hi = DEVICE_LINEAR_GET(Hidden,offset+1*hsz); |
| scalar_t hn = DEVICE_LINEAR_GET(Hidden, offset+2*hsz); |
| |
| scalar_t hx = DEVICE_LINEAR_GET(_hx, linearIndex); |
| scalar_t* hy = &DEVICE_LINEAR_GET(_hy, linearIndex); |
| |
| scalar_t b1r, b1i, b1n, b2r, b2i, b2n; |
| |
| if (has_bias) { |
| b1r = DEVICE_BIAS_GET(Bias1, linearIndex%hsz+0*hsz); |
| b1i = DEVICE_BIAS_GET(Bias1, linearIndex%hsz+1*hsz); |
| b1n = DEVICE_BIAS_GET(Bias1, linearIndex%hsz+2*hsz); |
| |
| b2r = DEVICE_BIAS_GET(Bias2, linearIndex%hsz+0*hsz); |
| b2i = DEVICE_BIAS_GET(Bias2, linearIndex%hsz+1*hsz); |
| b2n = DEVICE_BIAS_GET(Bias2, linearIndex%hsz+2*hsz); |
| } else { |
| // No biases were provided: zero-initialize. F2H converts to the storage type. |
| b1r = F2H(0.0); b1i = F2H(0.0); b1n = F2H(0.0); |
| b2r = F2H(0.0); b2i = F2H(0.0); b2n = F2H(0.0); |
| } |
| |
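| // Re-index into `storage`, which is 5 * hsz wide per row. |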
| offset = (linearIndex/hsz)*5*hsz+linearIndex%hsz; |
| |
| accscalar_t rg, ig, ng; |
| |
| rg = sigmoid(H2F(ir) + H2F(hr) + H2F(b1r) + H2F(b2r)); |
| ig = sigmoid(H2F(ii) + H2F(hi) + H2F(b1i) + H2F(b2i)); |
| |
| ng = H2F(in) + H2F(b1n) + rg*( H2F(hn)+H2F(b2n) ); |
| ng = ::tanh(ng); |
| *hy = F2H( ng + ig * ( H2F(hx)-ng ) ); |
| |
| // Save for the backward pass. |
| DEVICE_LINEAR_GET(storage, offset+0*hsz) = F2H(rg); |
| DEVICE_LINEAR_GET(storage, offset+1*hsz) = F2H(ig); |
| DEVICE_LINEAR_GET(storage, offset+2*hsz) = F2H(ng); |
| DEVICE_LINEAR_GET(storage, offset+3*hsz) = hx; |
| DEVICE_LINEAR_GET(storage, offset+4*hsz) = F2H(H2F(hn) + H2F(b2n)); |
| } |
| } |
| |
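| // Backward of the pointwise GRU step, using the five values saved per hidden |
| // unit by the forward kernel (r, i, n, hx, hn = Hidden_n + b2_n): |
| //   d_i  = d_hy * (hx - n) * i * (1 - i) |
| //   d_hx = d_hy * i |
| //   d_n  = d_hy * (1 - i) * (1 - n^2) |
| //   d_hn = d_n * r |
| //   d_r  = d_n * hn * r * (1 - r) |
| // gradInInput gets [d_r, d_i, d_n]; gradInHidden gets [d_r, d_i, d_hn]. |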
| template <typename scalar_t, typename accscalar_t, typename index_type, int indexing_kind> |
| #if __CUDA_ARCH__ >= 350 || defined(USE_ROCM) |
| C10_LAUNCH_BOUNDS_2(512, 4) |
| #endif |
| __global__ void gru_cell_backward( |
| TensorInfo<scalar_t, index_type> gradInInput, |
| TensorInfo<scalar_t, index_type> gradInHidden, |
| TensorInfo<scalar_t, index_type> gradOutput, |
| TensorInfo<scalar_t, index_type> gradInputHx, |
| TensorInfo<scalar_t, index_type> storage, |
| index_type hsz, |
| index_type totalElements) { |
| for (index_type linearIndex = blockIdx.x * blockDim.x + threadIdx.x; |
| linearIndex < totalElements; |
| linearIndex += gridDim.x * blockDim.x) { |
| index_type offset = (linearIndex/hsz)*5*hsz+linearIndex%hsz; |
| |
| scalar_t rg = DEVICE_LINEAR_GET(storage, offset+0*hsz); |
| scalar_t ig = DEVICE_LINEAR_GET(storage, offset+1*hsz); |
| scalar_t ng = DEVICE_LINEAR_GET(storage, offset+2*hsz); |
| scalar_t hx = DEVICE_LINEAR_GET(storage, offset+3*hsz); |
| scalar_t hn = DEVICE_LINEAR_GET(storage, offset+4*hsz); |
| |
| scalar_t go = DEVICE_LINEAR_GET(gradOutput, linearIndex); |
| |
| offset = (linearIndex/hsz)*3*hsz+linearIndex%hsz; |
| |
| accscalar_t gig = H2F(go)*( H2F(hx)-H2F(ng) )*( 1-H2F(ig) )*H2F(ig); |
| accscalar_t ghx = H2F(go)*H2F(ig); |
| accscalar_t gin = H2F(go)*( 1-H2F(ig) )*( 1-H2F(ng)*H2F(ng) ); |
| accscalar_t ghn = gin * H2F(rg); |
| accscalar_t grg = gin *H2F(hn)*( 1-H2F(rg) )*H2F(rg); |
| |
| DEVICE_LINEAR_GET(gradInInput, offset+0*hsz) = F2H(grg); |
| DEVICE_LINEAR_GET(gradInInput, offset+1*hsz) = F2H(gig); |
| DEVICE_LINEAR_GET(gradInInput, offset+2*hsz) = F2H(gin); |
| |
| DEVICE_LINEAR_GET(gradInHidden, offset+0*hsz) = F2H(grg); |
| DEVICE_LINEAR_GET(gradInHidden, offset+1*hsz) = F2H(gig); |
| DEVICE_LINEAR_GET(gradInHidden, offset+2*hsz) = F2H(ghn); |
| DEVICE_LINEAR_GET(gradInputHx, linearIndex) = F2H(ghx); |
| } |
| } |
| |
| #undef DEVICE_LINEAR_GET |
| #undef DEVICE_BIAS_GET |
| #undef H2F |
| #undef F2H |
| |
| } // namespace kernel |
| |
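| // Host-side launchers. Each one builds TensorInfo views of its operands and takes |
| // the collapsed fast path (indexing_kind == 1) when everything is contiguous; the |
| // index_type is chosen by the callers based on canUse32BitIndexMath. |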
| template<typename scalar_t, typename index_type> |
| void lstm_forward_impl(const Tensor& input_gates, const Tensor& hidden_gates, |
| const Tensor& input_bias, const Tensor& hidden_bias, |
| const Tensor& cx, |
| const Tensor& hy, const Tensor& cy, const Tensor& workspace) { |
| using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>; |
| |
| dim3 block, grid; |
| int64_t numel = cx.numel(); |
| if (numel == 0) return; |
| getLaunchConfig(&block, &grid, numel); |
| |
| auto input_gatesI = getTensorInfo<scalar_t, index_type>(input_gates); |
| auto hidden_gatesI = getTensorInfo<scalar_t, index_type>(hidden_gates); |
| auto input_biasI = tryGetTensorInfo<scalar_t, index_type>(input_bias); |
| auto hidden_biasI = tryGetTensorInfo<scalar_t, index_type>(hidden_bias); |
| auto cxI = getTensorInfo<scalar_t, index_type>(cx); |
| auto hyI = getTensorInfo<scalar_t, index_type>(hy); |
| auto cyI = getTensorInfo<scalar_t, index_type>(cy); |
| auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace); |
| index_type hidden_size = cxI.sizes[cxI.dims-1]; |
| |
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| if (allContiguous({input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace})) { |
| collapseDims(input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, cxI, hyI, cyI, workspaceI); |
| kernel::lstm_cell_forward<scalar_t, accscalar_t, index_type, 1> |
| <<<grid, block, 0, stream>>> |
| (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, cxI, hyI, cyI, workspaceI, hidden_size, numel); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } else { |
| kernel::lstm_cell_forward<scalar_t, accscalar_t, index_type, 2> |
| <<<grid, block, 0, stream>>> |
| (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, cxI, hyI, cyI, workspaceI, hidden_size, numel); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| } |
| |
| template<typename scalar_t, typename index_type> |
| void lstm_backward_impl(const Tensor& grad_hy, const Tensor& grad_cy, |
| const Tensor& cx, const Tensor& cy, |
| const Tensor& workspace, |
| const Tensor& grad_gates, const Tensor& grad_cx) { |
| using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>; |
| |
| dim3 block, grid; |
| int64_t numel = cx.numel(); |
| if (numel == 0) return; |
| getLaunchConfig(&block, &grid, numel); |
| |
| auto grad_hyI = tryGetTensorInfo<scalar_t, index_type>(grad_hy); |
| auto grad_cyI = tryGetTensorInfo<scalar_t, index_type>(grad_cy); |
| auto cxI = getTensorInfo<scalar_t, index_type>(cx); |
| auto cyI = getTensorInfo<scalar_t, index_type>(cy); |
| auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace); |
| auto grad_gatesI = getTensorInfo<scalar_t, index_type>(grad_gates); |
| auto grad_cxI = getTensorInfo<scalar_t, index_type>(grad_cx); |
| index_type hidden_size = cxI.sizes[cxI.dims-1]; |
| |
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| if (allContiguous({grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx})) { |
| collapseDims(grad_hyI, grad_cyI, cxI, cyI, workspaceI, grad_gatesI, grad_cxI); |
| kernel::lstm_cell_backward<scalar_t, accscalar_t, index_type, 1> |
| <<<grid, block, 0, stream>>> |
| (workspaceI, grad_gatesI, cxI, cyI, grad_hyI, grad_cyI, grad_cxI, hidden_size, numel); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } else { |
| kernel::lstm_cell_backward<scalar_t, accscalar_t, index_type, 2> |
| <<<grid, block, 0, stream>>> |
| (workspaceI, grad_gatesI, cxI, cyI, grad_hyI, grad_cyI, grad_cxI, hidden_size, numel); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| } |
| |
| template<typename scalar_t, typename index_type> |
| void gru_forward_impl(const Tensor& input_gates, const Tensor& hidden_gates, |
| const Tensor& input_bias, const Tensor& hidden_bias, |
| const Tensor& hx, |
| const Tensor& hy, const Tensor& workspace) { |
| using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>; |
| |
| dim3 block, grid; |
| int64_t numel = hx.numel(); |
| if (numel == 0) return; |
| getLaunchConfig(&block, &grid, numel); |
| |
| auto input_gatesI = getTensorInfo<scalar_t, index_type>(input_gates); |
| auto hidden_gatesI = getTensorInfo<scalar_t, index_type>(hidden_gates); |
| auto input_biasI = tryGetTensorInfo<scalar_t, index_type>(input_bias); |
| auto hidden_biasI = tryGetTensorInfo<scalar_t, index_type>(hidden_bias); |
| auto hxI = getTensorInfo<scalar_t, index_type>(hx); |
| auto hyI = getTensorInfo<scalar_t, index_type>(hy); |
| auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace); |
| index_type hidden_size = hxI.sizes[hxI.dims-1]; |
| |
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| if (allContiguous({input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace})) { |
| collapseDims(input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, hxI, hyI, workspaceI); |
| kernel::gru_cell_forward<scalar_t, accscalar_t, index_type, 1> |
| <<<grid, block, 0, stream>>> |
| (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, hxI, hyI, workspaceI, hidden_size, numel); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } else { |
| kernel::gru_cell_forward<scalar_t, accscalar_t, index_type, 2> |
| <<<grid, block, 0, stream>>> |
| (input_gatesI, hidden_gatesI, input_biasI, hidden_biasI, hxI, hyI, workspaceI, hidden_size, numel); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| } |
| |
| template<typename scalar_t, typename index_type> |
| void gru_backward_impl(const Tensor& grad_hy, const Tensor& workspace, |
| const Tensor& grad_input_gates, const Tensor& grad_hidden_gates, const Tensor& grad_hx) { |
| using accscalar_t = acc_type<scalar_t, /*is_cuda=*/true>; |
| |
| dim3 block, grid; |
| int64_t numel = grad_hy.numel(); |
| if (numel == 0) return; |
| getLaunchConfig(&block, &grid, numel); |
| |
| auto grad_hyI = getTensorInfo<scalar_t, index_type>(grad_hy); |
| auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace); |
| auto grad_input_gatesI = getTensorInfo<scalar_t, index_type>(grad_input_gates); |
| auto grad_hidden_gatesI = getTensorInfo<scalar_t, index_type>(grad_hidden_gates); |
| auto grad_hxI = getTensorInfo<scalar_t, index_type>(grad_hx); |
| index_type hidden_size = grad_hyI.sizes[grad_hyI.dims-1]; |
| |
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| if (allContiguous({grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx})) { |
| collapseDims(grad_hyI, workspaceI, grad_input_gatesI, grad_hidden_gatesI, grad_hxI); |
| kernel::gru_cell_backward<scalar_t, accscalar_t, index_type, 1> |
| <<<grid, block, 0, stream>>> |
| (grad_input_gatesI, grad_hidden_gatesI, grad_hyI, grad_hxI, workspaceI, hidden_size, numel); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } else { |
| kernel::gru_cell_backward<scalar_t, accscalar_t, index_type, 2> |
| <<<grid, block, 0, stream>>> |
| (grad_input_gatesI, grad_hidden_gatesI, grad_hyI, grad_hxI, workspaceI, hidden_size, numel); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| } |
| |
| } // anonymous namespace |
| |
| // Note [64-bit index math check elision] |
| // It is enough to run the 32-bit-indexing check (canUse32BitIndexMath) on the |
| // largest tensor only, which here is the workspace. If 32-bit math suffices for |
| // it, it suffices for all the other tensors too, and we can save some work with |
| // this trick. |
| |
| std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_cuda( |
| const Tensor& input_gates, const Tensor& hidden_gates, |
| const Tensor& cx, const c10::optional<Tensor>& input_bias_opt, const c10::optional<Tensor>& hidden_bias_opt) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> input_bias_maybe_owned = at::borrow_from_optional_tensor(input_bias_opt); |
| const Tensor& input_bias = *input_bias_maybe_owned; |
| const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();}); |
| |
| checkSizes("_thnn_fused_lstm_cell_cuda", |
| {input_gates, "input_gates", 1}, {hidden_gates, "hidden_gates", 2}, |
| {input_bias, "input_bias", 3}, {hidden_bias, "hidden_bias", 4}, |
| /*factor=*/4, {cx, "prev_hidden", 5}); |
| |
| auto workspace = at::empty_like(input_gates, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| auto hy = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| auto cy = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| input_gates.scalar_type(), |
| "_thnn_fused_lstm_cell_cuda", |
| [&] { |
| if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision] |
| lstm_forward_impl<scalar_t, int32_t>(input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace); |
| } else { |
| lstm_forward_impl<scalar_t, int64_t>(input_gates, hidden_gates, input_bias, hidden_bias, cx, hy, cy, workspace); |
| } |
| }); |
| return std::make_tuple(std::move(hy), std::move(cy), std::move(workspace)); |
| } |
| |
| void checkLSTMBackwardSizes(const TensorArg& grad_hy, const TensorArg& grad_cy, |
| const TensorArg& cx, const TensorArg& cy, |
| const TensorArg& workspace) { |
| CheckedFrom c = "fused_lstm_cell_backward"; |
| const TensorArg& defined_grad = grad_hy->defined() ? grad_hy : grad_cy; |
| checkDim(c, defined_grad, 2); |
| auto exp_size = defined_grad->sizes(); |
| if (grad_hy->defined()) { |
| checkSize(c, grad_hy, exp_size); |
| } |
| if (grad_cy->defined()) { |
| checkSize(c, grad_cy, exp_size); |
| } |
| checkSize(c, cx, exp_size); |
| checkSize(c, cy, exp_size); |
| checkDim(c, workspace, 2); |
| checkNumel(c, workspace, exp_size[0] * exp_size[1] * 4); |
| } |
| |
| std::tuple<Tensor, Tensor, Tensor> _thnn_fused_lstm_cell_backward_impl_cuda( const c10::optional<Tensor>& grad_hy_opt, const c10::optional<Tensor>& grad_cy_opt, |
| const Tensor& cx, const Tensor& cy, |
| const Tensor& workspace, bool has_bias) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> grad_hy_maybe_owned = at::borrow_from_optional_tensor(grad_hy_opt); |
| const Tensor& grad_hy = *grad_hy_maybe_owned; |
| const Tensor& grad_cy = c10::value_or_else(grad_cy_opt, [] {return Tensor();}); |
| |
| if (!grad_hy.defined() && !grad_cy.defined()) { |
| return std::tuple<Tensor, Tensor, Tensor>(); |
| } |
| checkLSTMBackwardSizes({grad_hy, "grad_hy", 1}, {grad_cy, "grad_cy", 2}, |
| {cx, "cx", 3}, {cy, "cy", 4}, |
| {workspace, "workspace", 5}); |
| |
| auto grad_gates = at::empty_like(workspace, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| auto grad_cx = at::empty_like(cx, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| workspace.scalar_type(), |
| "_thnn_fused_lstm_cell_cuda_backward", |
| [&] { |
| if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision] |
| lstm_backward_impl<scalar_t, int32_t>(grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx); |
| } else { |
| lstm_backward_impl<scalar_t, int64_t>(grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx); |
| } |
| }); |
| |
| auto grad_bias = has_bias ? grad_gates.sum(0, /*keepdim=*/false) : at::Tensor{}; |
| return std::make_tuple(std::move(grad_gates), std::move(grad_cx), std::move(grad_bias)); |
| } |
| |
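| // The GRU forward kernel stashes five values per hidden unit: r, i, n, hx and |
| // the pre-activation of the hidden gate (Hidden_n + b2_n). |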
| static constexpr int64_t GRU_WORKSPACE_MULTIPLIER = 5; |
| |
| std::tuple<Tensor, Tensor> _thnn_fused_gru_cell_cuda( |
| const Tensor& input_gates, const Tensor& hidden_gates, |
| const Tensor& hx, const c10::optional<Tensor>& input_bias_opt, const c10::optional<Tensor>& hidden_bias_opt) { |
| // See [Note: hacky wrapper removal for optional tensor] |
| c10::MaybeOwned<Tensor> input_bias_maybe_owned = at::borrow_from_optional_tensor(input_bias_opt); |
| const Tensor& input_bias = *input_bias_maybe_owned; |
| const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();}); |
| |
| checkSizes("_thnn_fused_gru_cell_cuda", |
| {input_gates, "input_gates", 1}, {hidden_gates, "hidden_gates", 2}, |
| {input_bias, "input_bias", 3}, {hidden_bias, "hidden_bias", 4}, |
| /*factor=*/3, {hx, "prev_hidden", 5}); |
| |
| auto workspace = at::empty({hx.size(0), hx.size(1) * GRU_WORKSPACE_MULTIPLIER}, hx.options()); |
| auto hy = at::empty_like(hx, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| input_gates.scalar_type(), |
| "_thnn_fused_gru_cell_cuda", |
| [&] { |
| if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision] |
| gru_forward_impl<scalar_t, int32_t>(input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace); |
| } else { |
| gru_forward_impl<scalar_t, int64_t>(input_gates, hidden_gates, input_bias, hidden_bias, hx, hy, workspace); |
| } |
| }); |
| return std::make_tuple(std::move(hy), std::move(workspace)); |
| } |
| |
| void checkGRUBackwardSizes(const TensorArg& grad_hy, const TensorArg& workspace) { |
| CheckedFrom c = "fused_gru_cell_backward"; |
| checkDim(c, grad_hy, 2); |
| checkSize(c, workspace, {grad_hy->size(0), grad_hy->size(1) * GRU_WORKSPACE_MULTIPLIER}); |
| } |
| |
| std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _thnn_fused_gru_cell_backward_cuda( |
| const Tensor& grad_hy, const Tensor& workspace, bool has_bias) { |
| checkGRUBackwardSizes({grad_hy, "grad_hy", 1}, {workspace, "workspace", 2}); |
| |
| int64_t hidden_size = workspace.size(1) / GRU_WORKSPACE_MULTIPLIER; |
| auto grad_input_gates = at::empty({workspace.size(0), hidden_size * 3}, workspace.options()); |
| auto grad_hidden_gates = at::empty({workspace.size(0), hidden_size * 3}, workspace.options()); |
| auto grad_hx = at::empty_like(grad_hy, LEGACY_CONTIGUOUS_MEMORY_FORMAT); |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| at::ScalarType::Half, |
| at::ScalarType::BFloat16, |
| grad_hy.scalar_type(), |
| "_thnn_fused_gru_cell_cuda_backward", |
| [&] { |
| if (canUse32BitIndexMath(workspace)) { // See Note [64-bit index math check elision] |
| gru_backward_impl<scalar_t, int32_t>(grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx); |
| } else { |
| gru_backward_impl<scalar_t, int64_t>(grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx); |
| } |
| }); |
| |
| at::Tensor grad_input_bias, grad_hidden_bias; |
| if (has_bias) { |
| grad_input_bias = grad_input_gates.sum(0, /*keepdim=*/false); |
| grad_hidden_bias = grad_hidden_gates.sum(0, /*keepdim=*/false); |
| } |
| |
| return std::make_tuple( |
| std::move(grad_input_gates), |
| std::move(grad_hidden_gates), |
| std::move(grad_hx), |
| std::move(grad_input_bias), |
| std::move(grad_hidden_bias) |
| ); |
| } |
| |
| } // namespace at::native |