| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| #include <ATen/core/Tensor.h> |
| #include <ATen/cuda/CUDAContext.h> |
| #include <ATen/Dispatch.h> |
| #include <ATen/TensorUtils.h> |
| #include <ATen/TensorOperators.h> |
| #include <ATen/WrapDimUtils.h> |
| #include <c10/macros/Macros.h> |
| |
| #include <ATen/AccumulateType.h> |
| #include <ATen/cuda/NumericLimits.cuh> |
| #include <type_traits> |
| |
| #include <ATen/native/cuda/Loops.cuh> |
| #include <ATen/native/cuda/MemoryAccess.cuh> |
| #include <ATen/native/cuda/PersistentSoftmax.cuh> |
| |
| #ifndef AT_PER_OPERATOR_HEADERS |
| #include <ATen/Functions.h> |
| #include <ATen/NativeFunctions.h> |
| #else |
| #include <ATen/ops/_masked_softmax_native.h> |
| #include <ATen/ops/_log_softmax_native.h> |
| #include <ATen/ops/_log_softmax_backward_data_native.h> |
| #include <ATen/ops/_softmax_native.h> |
| #include <ATen/ops/_softmax_backward_data_native.h> |
| #include <ATen/ops/softmax.h> |
| #include <ATen/ops/_softmax_backward_data.h> |
| #endif |
| |
| namespace at { |
| namespace native { |
| |
| namespace { |
| |
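| // Alignment (in bytes) required by the vectorized memory paths below; it matches the widest |
| // vector load used (float4, i.e. 16 bytes / 128 bits). |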
| constexpr int ALIGN_BYTES = 16; |
| |
| template<typename T, typename AccumT, typename OutT> |
| struct LogSoftMaxForwardEpilogue { |
| __device__ __forceinline__ LogSoftMaxForwardEpilogue(AccumT max_input, AccumT sum) |
| : max_input(max_input), logsum(std::log(sum)) {} |
| |
| __device__ __forceinline__ OutT operator()(T input) const { |
| return static_cast<OutT>(input - max_input - logsum); |
| } |
| |
| const AccumT max_input; |
| const AccumT logsum; |
| }; |
| |
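| // Backward epilogue for log-softmax: with y_i = x_i - log(sum_j exp(x_j)) and upstream |
| // gradient g, dL/dx_i = g_i - exp(y_i) * sum_j g_j; `sum` below is sum_j g_j. |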
| template<typename T, typename AccumT, typename OutT> |
| struct LogSoftMaxBackwardEpilogue { |
| __device__ __forceinline__ LogSoftMaxBackwardEpilogue(AccumT sum) |
| : sum(sum) {} |
| |
| __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { |
| return static_cast<T>(gradOutput - std::exp(static_cast<AccumT>(output)) * sum); |
| } |
| |
| const AccumT sum; |
| }; |
| |
| template<typename T, typename AccumT, typename OutT> |
| struct SoftMaxForwardEpilogue { |
| __device__ __forceinline__ SoftMaxForwardEpilogue(AccumT max_input, AccumT sum) |
| : max_input(max_input) |
| , sum(sum) {} |
| |
| __device__ __forceinline__ OutT operator()(T input) const { |
| return static_cast<OutT>(std::exp(input - max_input) / sum); |
| } |
| |
| const AccumT max_input; |
| const AccumT sum; |
| }; |
| |
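| // Backward epilogue for softmax: dL/dx_i = y_i * (g_i - sum_j g_j * y_j). The caller passes |
| // gradOutput already multiplied by output (see the XXX note below), so the epilogue only |
| // computes (g*y)_i - y_i * sum, with sum = sum_j (g*y)_j. |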
| template<typename T, typename AccumT, typename OutT> |
| struct SoftMaxBackwardEpilogue { |
| __device__ __forceinline__ SoftMaxBackwardEpilogue(AccumT sum) |
| : sum(sum) {} |
| |
| // XXX: gradOutput that we get here is really gradOutput * output |
| // (see the `grad * output` in softmax_backward_cuda_out below) |
| __device__ __forceinline__ T operator()(OutT gradOutput, OutT output) const { |
| return static_cast<T>(gradOutput - output * sum); |
| } |
| |
| const AccumT sum; |
| }; |
| |
| |
| |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // Spatial kernel (fast with large inner_size and small dim_size) |
| //////////////////////////////////////////////////////////////////////////////// |
| // Let's assume that our input has been flattened to have only three dimensions: |
| //     outer x dim x inner |
| // The spatial algorithm tries to parallelize along all of them. |
| // Within a 2d block, threadIdx.x parallelizes the reduction over dim, while threadIdx.y |
| // parallelizes over inner indices, so the threads that share a threadIdx.y cooperate on the |
| // same dim slice and speed up its reduction (along axis x). |
| // The 2d grid parallelizes the outer dimension over its x axis and the inner one over y. |
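| // For example (illustrative shapes): an input flattened to outer=8 x dim=32 x inner=4096 |
| // gets a block of (dim_threads, inner_threads) = (1, 1024) and a grid that tiles the 4096 |
| // inner positions over gridDim.y and the 8 outer slices over gridDim.x; each thread then |
| // owns one (outer, inner) pair per grid-stride iteration and loops over the 32 dim entries. |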
| inline dim3 SpatialSoftMax_getGridSize( |
| dim3 block, uint32_t max_active_blocks, |
| uint64_t outer_size, uint64_t inner_size) { |
| // First, tile as many blocks as we can over the y axis |
| uint32_t inner_blocks = (inner_size + block.y - 1) / block.y; |
| if (inner_blocks > max_active_blocks) |
| inner_blocks = max_active_blocks; |
| // Fill the x axis with as many blocks as we can fit (a little more is ok too) |
| uint32_t outer_blocks = (max_active_blocks + inner_blocks - 1) / inner_blocks; |
| if (outer_blocks > outer_size) |
| outer_blocks = outer_size; |
| return dim3(outer_blocks, inner_blocks); |
| } |
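| // e.g. block = (1, 1024), inner_size = 4096, outer_size = 8, max_active_blocks = 64: |
| // inner_blocks = ceil(4096 / 1024) = 4, outer_blocks = ceil(64 / 4) = 16 but capped at |
| // outer_size, so the grid is dim3(8, 4). |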
| |
| const int max_threads = 1024; |
| |
| inline dim3 SpatialSoftMax_getBlockSize( |
| uint64_t dim_size, uint64_t inner_size) { |
| uint32_t inner_threads = inner_size; |
| inner_threads = std::min(inner_threads, static_cast<uint32_t>(max_threads)); |
| uint32_t dim_threads = 1; |
| if (inner_threads <= 64 && dim_size >= 64) { |
| while (inner_threads * dim_threads <= max_threads && dim_threads <= dim_size) |
| dim_threads *= 2; |
| dim_threads /= 2; |
| } |
| return dim3(dim_threads, inner_threads); |
| } |
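| // e.g. dim_size = 128, inner_size = 16: inner_threads = 16 <= 64 and dim_size >= 64, so |
| // dim_threads doubles up to 128 and is halved back to 64 (64 * 16 = 1024 threads), giving |
| // a block of dim3(64, 16); for the inner_size = 4096 case above the block is dim3(1, 1024). |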
| |
| |
| template<typename accscalar_t, typename Kernel> |
| void SpatialSoftMax_getLaunchSizes( |
| Kernel k, |
| uint64_t outer_size, uint64_t dim_size, uint64_t inner_size, |
| dim3& grid, dim3& block, uint32_t& smem_size) { |
| block = SpatialSoftMax_getBlockSize(dim_size, inner_size); |
| uint32_t block_threads = block.x * block.y; |
| smem_size = block.x == 1 ? 0 : block_threads * sizeof(accscalar_t); |
| int max_active_blocks; |
| #if defined(USE_ROCM) && TORCH_HIP_VERSION < 305 |
| // HIP function signature is not compatible yet. |
| uint32_t max_blocks; |
| cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks, |
| k, block_threads, smem_size); |
| max_active_blocks = max_blocks; |
| #else |
| cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, |
| k, block_threads, smem_size); |
| #endif |
| max_active_blocks *= at::cuda::getCurrentDeviceProperties()->multiProcessorCount; |
| grid = SpatialSoftMax_getGridSize(block, max_active_blocks, outer_size, inner_size); |
| } |
| |
| inline dim3 SoftMax_getBlockSize(int ILP, uint64_t dim_size) { |
| uint64_t block_size = 1; |
| uint64_t max_block_size = std::min(dim_size / ILP, static_cast<uint64_t>(max_threads)); |
| |
| // In the vectorized case we trade off letting more of the buffer be accessed with vectorized |
| // loads against using a larger block size for better utilization. In general, with ILP you |
| // can have up to (ILP-1)/ILP of the buffer accessed in a vectorized way, at the risk of a |
| // very small block size. Halving the limit keeps >= 1/2 of the buffer vectorized while still |
| // allowing a larger block size. |
| if (ILP > 1) { |
| max_block_size /= 2; |
| } |
| |
| while (block_size < (max_block_size)) block_size *= 2; |
| // Launch at least a single warp - the kernel assumes that. |
| block_size = std::max(block_size, static_cast<uint64_t>(at::cuda::warp_size())); |
| return dim3(block_size); |
| } |
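| // e.g. float input with dim_size = 2000: ILP = sizeof(float4) / sizeof(float) = 4, so |
| // max_block_size = min(2000 / 4, 1024) / 2 = 250, the doubling loop lands on 256, and the |
| // kernel is launched with a 256-thread block (never fewer than one warp). |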
| |
| template<typename T> |
| struct Add { |
| __device__ __forceinline__ T operator()(T a, T b) const { |
| return a + b; |
| } |
| }; |
| |
| template<typename T> |
| struct Max { |
| __device__ __forceinline__ T operator()(T a, T b) const { |
| return a < b ? b : a; |
| } |
| }; |
| |
| // Note that it's not a complete block-wide reduction. |
| // Only threads that share threadIdx.y reduce values. |
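| // e.g. with blockDim.x == 8 the loop below runs with offset = 4, 2, 1: lanes 0-3 first fold |
| // in lanes 4-7, then lanes 0-1 fold in 2-3, then lane 0 folds in lane 1, and every thread of |
| // the row reads the result back from the first slot of its row of `shared`. |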
| template<typename T, template<typename> class ReduceOp> |
| __forceinline__ __device__ |
| T spatialBlockReduceX(T *shared, T val) { |
| ReduceOp<T> r; |
| shared += threadIdx.y * blockDim.x; |
| |
| __syncthreads(); |
| |
| shared[threadIdx.x] = val; |
| |
| // NOTE: loop starts with __syncthreads() |
| int offset = blockDim.x / 2; |
| while (offset > 0) { |
| __syncthreads(); |
| if (threadIdx.x < offset) |
| shared[threadIdx.x] = r(shared[threadIdx.x], shared[threadIdx.x + offset]); |
| offset /= 2; |
| } |
| |
| __syncthreads(); |
| |
| return shared[0]; |
| } |
| |
| template <typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue> |
| __global__ void cunn_SpatialSoftMaxForward( |
| outscalar_t *output, scalar_t *input, |
| uint32_t outer_size, uint32_t dim_size, uint32_t inner_size) |
| { |
| extern __shared__ unsigned char smem[]; |
| auto sdata = reinterpret_cast<accscalar_t*>(smem); |
| const uint32_t outer_stride = inner_size * dim_size; |
| const uint32_t dim_stride = inner_size; |
| |
| for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) { |
| const uint32_t outer_offset = outer_index * outer_stride; |
| for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) { |
| const uint32_t data_offset = outer_offset + inner_index; |
| //////////////////////////////////////////////////////////// |
| // These two branches are equivalent, but specializing on blockDim.x == 1 makes the kernel |
| // faster when the x dimension of the block is unused. Rather than threading an extra |
| // template parameter through, we rely on nvcc being smart enough to hoist the `if` outside |
| // of the loops. |
| //////////////////////////////////////////////////////////// |
| |
| if (blockDim.x > 1) { |
| accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest(); |
| for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) { |
| const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]); |
| max_input = Max<accscalar_t>()(max_input, value); |
| } |
| max_input = spatialBlockReduceX<accscalar_t, Max>(sdata,max_input); |
| |
| accscalar_t sum = 0; |
| for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) |
| sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride]) |
| - max_input); |
| sum = spatialBlockReduceX<accscalar_t, Add>(sdata, sum); |
| |
| Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_input, sum); |
| for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) |
| output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]); |
| } else { |
| accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest(); |
| for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) { |
| const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]); |
| max_input = Max<accscalar_t>()(max_input, value); |
| } |
| accscalar_t sum = 0; |
| for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) |
| sum += std::exp(static_cast<accscalar_t>(input[data_offset + d * dim_stride]) |
| - max_input); |
| Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_input, sum); |
| for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) |
| output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]); |
| } |
| } |
| } |
| } |
| |
| |
| |
| template <typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue> |
| __global__ void cunn_SpatialSoftMaxBackward( |
| scalar_t *gradInput, outscalar_t *output, outscalar_t *gradOutput, |
| uint32_t outer_size, uint32_t dim_size, uint32_t inner_size) |
| { |
| extern __shared__ unsigned char smem[]; |
| auto sdata = reinterpret_cast<accscalar_t*>(smem); |
| const uint32_t outer_stride = inner_size * dim_size; |
| const uint32_t dim_stride = inner_size; |
| |
| for (uint32_t outer_index = blockIdx.x; outer_index < outer_size; outer_index += gridDim.x) { |
| const uint32_t outer_offset = outer_index * outer_stride; |
| for (uint32_t inner_index = blockIdx.y * blockDim.y + threadIdx.y; inner_index < inner_size; inner_index += blockDim.y * gridDim.y) { |
| const uint32_t data_offset = outer_offset + inner_index; |
| // See the comment in the forward kernel |
| if (blockDim.x > 1) { |
| accscalar_t sum = 0; |
| for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) |
| sum += gradOutput[data_offset + d * dim_stride]; |
| sum = spatialBlockReduceX<accscalar_t, Add>(sdata, sum); |
| |
| Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum); |
| for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) { |
| gradInput[data_offset + d * dim_stride] = |
| epilogue(gradOutput[data_offset + d * dim_stride], |
| output[data_offset + d * dim_stride]); |
| } |
| } else { |
| accscalar_t sum = 0; |
| for (uint32_t d = 0; d < dim_size; d++) |
| sum += gradOutput[data_offset + d * dim_stride]; |
| |
| Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum); |
| for (uint32_t d = 0; d < dim_size; d++) { |
| gradInput[data_offset + d * dim_stride] = |
| epilogue(gradOutput[data_offset + d * dim_stride], |
| output[data_offset + d * dim_stride]); |
| } |
| } |
| } |
| } |
| } |
| |
| |
| //////////////////////////////////////////////////////////////////////////////// |
| // Regular kernel (fast when dim_size is large; requires inner_size == 1) |
| //////////////////////////////////////////////////////////////////////////////// |
| |
| |
| template <typename T, typename AccumT> |
| struct MaxFloat |
| { |
| __device__ __forceinline__ AccumT operator()(AccumT max, T v) const { |
| return ::max(max, (AccumT)v); |
| } |
| }; |
| |
| template<typename T, typename AccumT> |
| struct AddFloat |
| { |
| __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { |
| return sum + v; |
| } |
| }; |
| |
| template<typename T, typename AccumT> |
| struct SumExpFloat |
| { |
| __device__ __forceinline__ SumExpFloat(AccumT v) |
| : max_k(v) {} |
| |
| __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const { |
| return sum + std::exp(v - max_k); |
| } |
| |
| const AccumT max_k; |
| }; |
| |
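| // Reduces one value per thread across the whole block: every thread stores its value in smem, |
| // the first warp then reduces one warp-sized chunk of smem per lane, and thread 0 combines |
| // those per-warp partials; the result is broadcast to all threads through smem[0]. |
| // Assumes blockDim.x is a power of two that is at least (and hence a multiple of) |
| // C10_WARP_SIZE, which SoftMax_getBlockSize guarantees. |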
| template <template<typename> class Reduction, typename AccumT> |
| __device__ __forceinline__ AccumT |
| blockReduce(AccumT* smem, AccumT val, |
| const Reduction<AccumT>& r, |
| AccumT defaultVal) |
| { |
| // To avoid RaW races from chaining blockReduce calls together, we need a sync here |
| __syncthreads(); |
| |
| smem[threadIdx.x] = val; |
| |
| __syncthreads(); |
| |
| AccumT warpVal = defaultVal; |
| |
| // First warp will perform per-warp reductions for the remaining warps |
| uint32_t mask = (((uint64_t)1) << (blockDim.x / C10_WARP_SIZE)) - 1; |
| if (threadIdx.x < C10_WARP_SIZE) { |
| int lane = threadIdx.x % C10_WARP_SIZE; |
| if (lane < blockDim.x / C10_WARP_SIZE) { |
| #pragma unroll |
| for (int i = 0; i < C10_WARP_SIZE; ++i) { |
| warpVal = r(warpVal, smem[lane * C10_WARP_SIZE + i]); |
| } |
| #if !defined(USE_ROCM) |
| __syncwarp(mask); |
| #endif |
| smem[lane] = warpVal; |
| } |
| } |
| |
| __syncthreads(); |
| |
| // First thread will perform a reduction of the above per-warp reductions |
| AccumT blockVal = defaultVal; |
| |
| if (threadIdx.x == 0) { |
| for (int i = 0; i < blockDim.x / C10_WARP_SIZE; ++i) { |
| blockVal = r(blockVal, smem[i]); |
| } |
| smem[0] = blockVal; |
| } |
| |
| // Sync and broadcast |
| __syncthreads(); |
| return smem[0]; |
| } |
| |
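| // Computes a per-thread partial reduction over `size` elements of `data` (it still needs a |
| // blockReduce afterwards). If `data` is misaligned by `shift` elements, the leading elements |
| // are consumed one per thread so the bulk of the loop can use aligned vector loads of ILP |
| // elements; the tail of size % (ILP * blockDim.x) elements is handled element-wise. |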
| template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT> |
| __device__ __forceinline__ AccumT |
| ilpReduce(int shift, |
| T* data, |
| int size, |
| const Reduction<T, AccumT>& r, |
| AccumT defaultVal) |
| { |
| using LoadT = at::native::memory::aligned_vector<T, ILP>; |
| AccumT threadVal = defaultVal; |
| int offset = threadIdx.x; |
| |
| // Handle the unaligned prefix first: each thread consumes at most one leading element so |
| // that the vectorized loop below reads from an aligned pointer. |
| if(shift > 0){ |
| data -= shift; |
| size += shift; |
| if(threadIdx.x >= shift){ |
| threadVal = r(threadVal, data[offset]); |
| } |
| size -= blockDim.x; |
| data += blockDim.x; |
| } |
| int last = size % (ILP * blockDim.x); |
| |
| T v[ILP]; |
| LoadT* value = reinterpret_cast<LoadT*>(&v); |
| |
| for (; offset * ILP < (size - last); offset += blockDim.x) { |
| *value = reinterpret_cast<LoadT*>(data)[offset]; |
| |
| #pragma unroll |
| for (int j = 0; j < ILP; ++j) { |
| threadVal = r(threadVal, v[j]); |
| } |
| } |
| |
| offset = size - last + threadIdx.x; |
| // Epilogue |
| for (; offset < size; offset += blockDim.x) |
| threadVal = r(threadVal, data[offset]); |
| |
| return threadVal; |
| } |
| |
| /** |
| * This will apply the Epilogue with vectorized reads & writes when input & output have the same shift |
| */ |
| template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue> |
| __device__ __forceinline__ void |
| WriteFpropResultsVectorized( |
| int size, |
| const int shift, |
| scalar_t *input, |
| outscalar_t *output, |
| Epilogue<scalar_t, accum_t, outscalar_t> epilogue) { |
| using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>; |
| using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>; |
| |
| int offset = threadIdx.x; |
| |
| // if unaligned, do one value / thread and move on, guaranteeing aligned reads/writes later |
| if (shift > 0) { |
| input -= shift; |
| output -= shift; |
| size += shift; |
| |
| if (threadIdx.x >= shift) { |
| output[offset] = epilogue(input[offset]); |
| } |
| size -= blockDim.x; |
| input += blockDim.x; |
| output += blockDim.x; |
| } |
| |
| const int last = size % (ILP * blockDim.x); |
| |
| scalar_t in_v[ILP]; |
| LoadT* in_value = reinterpret_cast<LoadT*>(&in_v); |
| |
| outscalar_t out_v[ILP]; |
| StoreT* out_value = reinterpret_cast<StoreT*>(&out_v); |
| |
| for (; offset * ILP < (size - last); offset += blockDim.x) { |
| *in_value = reinterpret_cast<LoadT*>(input)[offset]; |
| |
| #pragma unroll |
| for (int j = 0; j < ILP; ++j) { |
| out_v[j] = epilogue(in_v[j]); |
| } |
| |
| reinterpret_cast<StoreT*>(output)[offset] = *out_value; |
| } |
| |
| offset = size - last + threadIdx.x; |
| // handle the tail |
| for (; offset < size; offset += blockDim.x) { |
| output[offset] = epilogue(input[offset]); |
| } |
| } |
| |
| template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue> |
| __device__ __forceinline__ void |
| WriteBpropResultsVectorized( |
| int size, |
| const int shift, |
| scalar_t *gradInput, |
| outscalar_t *output, |
| outscalar_t *gradOutput, |
| Epilogue<scalar_t, accum_t, outscalar_t> epilogue) { |
| using gradInputT = at::native::memory::aligned_vector<scalar_t, ILP>; |
| using outputT = at::native::memory::aligned_vector<outscalar_t, ILP>; |
| |
| int offset = threadIdx.x; |
| |
| // if unaligned, do one value / thread and move on, guaranteeing aligned reads/writes later |
| if (shift > 0) { |
| gradInput -= shift; |
| output -= shift; |
| gradOutput -= shift; |
| size += shift; |
| |
| if (threadIdx.x >= shift) { |
| gradInput[offset] = epilogue(gradOutput[offset], output[offset]); |
| } |
| size -= blockDim.x; |
| gradInput += blockDim.x; |
| output += blockDim.x; |
| gradOutput += blockDim.x; |
| } |
| |
| const int last = size % (ILP * blockDim.x); |
| |
| scalar_t dX[ILP]; |
| gradInputT *dX_v = reinterpret_cast<gradInputT*>(&dX); |
| |
| outscalar_t Y[ILP]; |
| outputT *Y_v = reinterpret_cast<outputT*>(&Y); |
| |
| outscalar_t dY[ILP]; |
| outputT *dY_v = reinterpret_cast<outputT*>(&dY); |
| |
| for (; offset * ILP < (size - last); offset += blockDim.x) { |
| *Y_v = reinterpret_cast<outputT*>(output)[offset]; |
| *dY_v = reinterpret_cast<outputT*>(gradOutput)[offset]; |
| |
| #pragma unroll |
| for (int j = 0; j < ILP; ++j) { |
| dX[j] = epilogue(dY[j], Y[j]); |
| } |
| |
| reinterpret_cast<gradInputT*>(gradInput)[offset] = *dX_v; |
| } |
| |
| offset = size - last + threadIdx.x; |
| for (; offset < size; offset += blockDim.x) { |
| gradInput[offset] = epilogue(gradOutput[offset], output[offset]); |
| } |
| } |
| |
| /** |
| * This will apply the Epilogue with non-vectorized reads & writes for the general case |
| */ |
| template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue> |
| __device__ __forceinline__ void |
| WriteFpropResults( |
| int classes, |
| scalar_t *input, |
| outscalar_t *output, |
| Epilogue<scalar_t, accum_t, outscalar_t> epilogue) { |
| int offset = threadIdx.x; |
| |
| int last = classes % (ILP * blockDim.x); |
| |
| // Main bulk of loop with ILP |
| for (; offset < classes - last; offset += blockDim.x * ILP) { |
| scalar_t tmp[ILP]; |
| |
| #pragma unroll |
| for (int j = 0; j < ILP; ++j) { |
| tmp[j] = input[offset + j * blockDim.x]; |
| } |
| #pragma unroll |
| for (int j = 0; j < ILP; ++j) { |
| output[offset + j * blockDim.x] = epilogue(tmp[j]); |
| } |
| } |
| |
| // Remainder - no ILP |
| for (; offset < classes; offset += blockDim.x) { |
| output[offset] = epilogue(input[offset]); |
| } |
| } |
| |
| template <int ILP, typename scalar_t, typename accum_t, typename outscalar_t, template<typename, typename, typename> class Epilogue> |
| __device__ __forceinline__ void |
| WriteBpropResults( |
| int classes, |
| scalar_t *gradInput, |
| outscalar_t *output, |
| outscalar_t *gradOutput, |
| Epilogue<scalar_t, accum_t, outscalar_t> epilogue) { |
| |
| int offset = threadIdx.x; |
| |
| int last = classes % (ILP * blockDim.x); |
| |
| for (; offset < classes - last; offset += blockDim.x * ILP) { |
| outscalar_t tmpOutput[ILP]; |
| outscalar_t tmpGradOutput[ILP]; |
| |
| #pragma unroll |
| for (int j = 0; j < ILP; ++j) { |
| tmpOutput[j] = output[offset + j * blockDim.x]; |
| tmpGradOutput[j] = gradOutput[offset + j * blockDim.x]; |
| } |
| |
| #pragma unroll |
| for (int j = 0; j < ILP; ++j) { |
| gradInput[offset + j * blockDim.x] = epilogue(tmpGradOutput[j], tmpOutput[j]); |
| } |
| } |
| |
| // Remainder - no ILP |
| for (; offset < classes; offset += blockDim.x) { |
| gradInput[offset] = epilogue(gradOutput[offset], output[offset]); |
| } |
| } |
| |
| template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template <typename, typename, typename> class Epilogue> |
| __global__ void |
| cunn_SoftMaxForward(outscalar_t *output, scalar_t *input, int classes) |
| { |
| extern __shared__ unsigned char smem[]; |
| auto sdata = reinterpret_cast<accscalar_t*>(smem); |
| |
| using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>; |
| using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>; |
| |
| // forward pointers to batch[blockIdx.x] |
| // each block handles a sample in the mini-batch |
| input += static_cast<int64_t>(blockIdx.x) * classes; |
| output += static_cast<int64_t>(blockIdx.x) * classes; |
| |
| const int shift = ((uint64_t)input) % ALIGN_BYTES / sizeof(scalar_t); |
| const int output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(outscalar_t); |
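| // e.g. a half-precision (2-byte) row starting 8 bytes past a 16-byte boundary gives |
| // shift == 8 / sizeof(at::Half) == 4 elements; matching input/output shifts below let the |
| // vectorized write path run. |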
| |
| // find the max |
| accscalar_t threadMax = ilpReduce<MaxFloat, ILP, scalar_t, accscalar_t>( |
| shift, input, classes, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max()); |
| accscalar_t max_k = blockReduce<Max, accscalar_t>( |
| sdata, threadMax, Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max()); |
| |
| // reduce all values |
| accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>( |
| shift, input, classes, SumExpFloat<scalar_t, accscalar_t>(max_k), static_cast<accscalar_t>(0)); |
| accscalar_t sumAll = blockReduce<Add, accscalar_t>( |
| sdata, threadExp, Add<accscalar_t>(), static_cast<accscalar_t>(0)); |
| |
| Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll); |
| |
| if (shift == output_shift) { |
| WriteFpropResultsVectorized<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>(classes, shift, input, output, epilogue); |
| } else { |
| WriteFpropResults<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>(classes, input, output, epilogue); |
| } |
| } |
| |
| template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t, template<typename, typename, typename> class Epilogue> |
| __global__ void |
| cunn_SoftMaxBackward(scalar_t *gradInput, outscalar_t *output, outscalar_t *gradOutput, int classes) |
| { |
| using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>; |
| using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>; |
| |
| extern __shared__ unsigned char smem[]; |
| auto sdata = reinterpret_cast<accscalar_t*>(smem); |
| gradInput += static_cast<int64_t>(blockIdx.x) * classes; |
| output += static_cast<int64_t>(blockIdx.x) * classes; |
| gradOutput += static_cast<int64_t>(blockIdx.x) * classes; |
| |
| const int shift = ((uint64_t)gradInput) % ALIGN_BYTES / sizeof(scalar_t); |
| const int output_shift = ((uint64_t)output) % ALIGN_BYTES / sizeof(outscalar_t); |
| const int grad_output_shift = ((uint64_t)gradOutput) % ALIGN_BYTES / sizeof(outscalar_t); |
| |
| accscalar_t threadSum = ilpReduce<AddFloat, ILP, outscalar_t, accscalar_t>( |
| grad_output_shift, gradOutput, classes, AddFloat<outscalar_t, accscalar_t>(), accscalar_t(0)); |
| accscalar_t sum_k = blockReduce<Add, accscalar_t>( |
| sdata, threadSum, Add<accscalar_t>(), accscalar_t(0)); |
| |
| Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(sum_k); |
| |
| if (shift == output_shift && shift == grad_output_shift) { |
| WriteBpropResultsVectorized<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>(classes, shift, gradInput, output, gradOutput, epilogue); |
| } else { |
| WriteBpropResults<ILP, scalar_t, accscalar_t, outscalar_t, Epilogue>(classes, gradInput, output, gradOutput, epilogue); |
| } |
| } |
| |
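| // Host-side driver: picks between three implementations. With inner_size == 1, rows with |
| // dim_size <= 1024 that fit in 4096 bytes go to the warp-level persistent softmax |
| // (dispatch_softmax_forward), larger rows go to the block-per-row cunn_SoftMaxForward kernel, |
| // and anything with inner_size > 1 uses the spatial kernel. half_to_float writes float |
| // outputs for Half inputs. |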
| template<template<typename, typename, typename> class Epilogue, bool is_log_softmax> |
| Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float, const Tensor& output){ |
| if (half_to_float) { |
| TORCH_CHECK(input_.scalar_type() == ScalarType::Half, "conversion is supported for Half type only"); |
| } |
| auto input = input_.contiguous(); |
| static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float"); |
| if (input.dim() == 0) input = input.view(1); |
| int64_t dim = maybe_wrap_dim(dim_, input.dim()); |
| TORCH_CHECK(dim >=0 && dim < input.dim(), "dim must be non-negative and less than input dimensions"); |
| int64_t outer_size = 1; |
| int64_t dim_size = input.size(dim); |
| |
| if (input.numel() > 0) { |
| int64_t inner_size = 1; |
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| for (int64_t i = 0; i < dim; ++i) |
| outer_size *= input.size(i); |
| for (int64_t i = dim + 1; i < input.dim(); ++i) |
| inner_size *= input.size(i); |
| // This kernel spawns one block per batch element. |
| // XXX: it assumes that inner_size == 1 |
| |
| if (inner_size == 1) { |
| dim3 grid(outer_size); |
| AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] { |
| using accscalar_t = acc_type<scalar_t, true>; |
| if (!half_to_float) { |
| if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { |
| auto output_ptr = output.data_ptr<scalar_t>(); |
| auto input_ptr = input.data_ptr<scalar_t>(); |
| int64_t remaining = outer_size; |
| int64_t chunk_size = (1L << 30L) / dim_size; |
| while(remaining > 0) { |
| dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, is_log_softmax, false>( |
| output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size), nullptr/* not masked */); |
| input_ptr += chunk_size * dim_size; |
| output_ptr += chunk_size * dim_size; |
| remaining -= chunk_size; |
| } |
| } else { |
| constexpr int ILP = sizeof(float4) / sizeof(scalar_t); |
| dim3 block = SoftMax_getBlockSize(ILP, dim_size); |
| cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue> |
| <<<grid, block, block.x * sizeof(accscalar_t), stream>>>( |
| output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), dim_size); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| } else { |
| if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { |
| auto output_ptr = output.data_ptr<accscalar_t>(); |
| auto input_ptr = input.data_ptr<scalar_t>(); |
| int64_t remaining = outer_size; |
| int64_t chunk_size = (1<<30) / dim_size; |
| while(remaining > 0) { |
| dispatch_softmax_forward<scalar_t, accscalar_t, accscalar_t, is_log_softmax, false>( |
| output_ptr, input_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size), nullptr/* not masked */); |
| input_ptr += chunk_size * dim_size; |
| output_ptr += chunk_size * dim_size; |
| remaining -= chunk_size; |
| } |
| } else { |
| constexpr int ILP = sizeof(float4) / sizeof(accscalar_t); |
| dim3 block = SoftMax_getBlockSize(ILP, dim_size); |
| cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue> |
| <<<grid, block, block.x * sizeof(accscalar_t), stream>>>( |
| output.data_ptr<accscalar_t>(), input.data_ptr<scalar_t>(), dim_size); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| } |
| }); |
| // The spatial kernel below runs in a 2D grid: gridDim.y (together with blockDim.y) |
| // parallelizes over inner_size and gridDim.x over outer_size. Reductions over dim are |
| // performed by the threads that share a threadIdx.y (by a single thread when blockDim.x == 1). |
| } else { |
| uint32_t smem_size; |
| dim3 grid, block; |
| AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "host_softmax", [&] { |
| using accscalar_t = acc_type<scalar_t, true>; |
| if (!half_to_float) { |
| SpatialSoftMax_getLaunchSizes<accscalar_t>( |
| &cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, Epilogue>, |
| outer_size, dim_size, inner_size, |
| grid, block, smem_size); |
| cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, scalar_t, Epilogue> |
| <<<grid, block, smem_size, stream>>>( |
| output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), outer_size, dim_size, inner_size); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } else { |
| SpatialSoftMax_getLaunchSizes<accscalar_t>( |
| &cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, Epilogue>, |
| outer_size, dim_size, inner_size, |
| grid, block, smem_size); |
| cunn_SpatialSoftMaxForward<scalar_t, accscalar_t, accscalar_t, Epilogue> |
| <<<grid, block, smem_size, stream>>>( |
| output.data_ptr<accscalar_t>(), input.data_ptr<scalar_t>(), outer_size, dim_size, inner_size); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| }); |
| } |
| } |
| return output; |
| } |
| |
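| // Backward counterpart of host_softmax: the same three-way dispatch between the persistent |
| // backward kernel, cunn_SoftMaxBackward, and the spatial backward kernel. |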
| template<template<typename, typename, typename> class Epilogue, bool is_log_softmax> |
| void host_softmax_backward(const Tensor &grad_, const Tensor &output_, int64_t dim_, bool half_to_float, const Tensor &gI){ |
| int64_t dim = maybe_wrap_dim(dim_, grad_.dim()); |
| if (grad_.numel() == 0) { |
| return; |
| } |
| auto grad = grad_.contiguous(); |
| static_assert(std::is_same<acc_type<at::Half, true>, float>::value, "accscalar_t for half should be float"); |
| if (grad.dim() == 0) grad = grad.view(1); |
| TORCH_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions"); |
| auto output = output_.contiguous(); |
| if (output.dim() == 0) output = output.view(1); |
| int64_t outer_size = 1; |
| int64_t dim_size = output.size(dim); |
| int64_t inner_size = 1; |
| for (int64_t i = 0; i < dim; ++i) |
| outer_size *= output.size(i); |
| for (int64_t i = dim + 1; i < output.dim(); ++i) |
| inner_size *= output.size(i); |
| // See descriptions of kernels above. |
| cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |
| if (inner_size == 1) { |
| dim3 grid(outer_size); |
| AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, gI.scalar_type(), "host_softmax_backward", [&] { |
| using accscalar_t = acc_type<scalar_t, true>; |
| if (!half_to_float) { |
| if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { |
| auto gI_ptr = gI.data_ptr<scalar_t>(); |
| auto grad_ptr = grad.data_ptr<scalar_t>(); |
| auto output_ptr = output.data_ptr<scalar_t>(); |
| int64_t remaining = outer_size; |
| int64_t chunk_size = (1<<30) / dim_size; |
| while(remaining > 0) { |
| dispatch_softmax_backward<scalar_t, scalar_t, accscalar_t, is_log_softmax, false /* masked_softmax */>( |
| gI_ptr, grad_ptr, output_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size)); |
| gI_ptr += chunk_size * dim_size; |
| grad_ptr += chunk_size * dim_size; |
| output_ptr += chunk_size * dim_size; |
| remaining -= chunk_size; |
| } |
| } else { |
| constexpr int ILP = sizeof(float4) / sizeof(scalar_t); |
| dim3 block = SoftMax_getBlockSize(ILP, dim_size); |
| cunn_SoftMaxBackward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue> |
| <<<grid, block, block.x * sizeof(accscalar_t), stream>>>( |
| gI.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), dim_size |
| ); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| } else { |
| if (dim_size <= 1024 && dim_size*sizeof(scalar_t) <= 4096) { |
| auto gI_ptr = gI.data_ptr<scalar_t>(); |
| auto grad_ptr = grad.data_ptr<accscalar_t>(); |
| auto output_ptr = output.data_ptr<accscalar_t>(); |
| int64_t remaining = outer_size; |
| int64_t chunk_size = (1<<30) / dim_size; |
| while(remaining > 0) { |
| dispatch_softmax_backward<accscalar_t, scalar_t, accscalar_t, is_log_softmax, false /* masked_softmax */>( |
| gI_ptr, grad_ptr, output_ptr, dim_size, dim_size, std::min<int64_t>(remaining, chunk_size)); |
| gI_ptr += chunk_size * dim_size; |
| grad_ptr += chunk_size * dim_size; |
| output_ptr += chunk_size * dim_size; |
| remaining -= chunk_size; |
| } |
| } else { |
| constexpr int ILP = sizeof(float4) / sizeof(accscalar_t); |
| dim3 block = SoftMax_getBlockSize(ILP, dim_size); |
| cunn_SoftMaxBackward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue> |
| <<<grid, block, block.x * sizeof(accscalar_t), stream>>>( |
| gI.data_ptr<scalar_t>(), output.data_ptr<accscalar_t>(), grad.data_ptr<accscalar_t>(), dim_size |
| ); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| } |
| }); |
| } else { |
| uint32_t smem_size; |
| dim3 grid, block; |
| AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, gI.scalar_type(), "host_softmax_backward", [&] { |
| using accscalar_t = acc_type<scalar_t, true>; |
| if (!half_to_float) { |
| SpatialSoftMax_getLaunchSizes<accscalar_t>( |
| &cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, scalar_t, Epilogue>, |
| outer_size, dim_size, inner_size, |
| grid, block, smem_size); |
| |
| cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, scalar_t, Epilogue> |
| <<<grid, block, smem_size, stream>>>( |
| gI.data_ptr<scalar_t>(), output.data_ptr<scalar_t>(), grad.data_ptr<scalar_t>(), |
| outer_size, dim_size, inner_size |
| ); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } else { |
| SpatialSoftMax_getLaunchSizes<accscalar_t>( |
| &cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, accscalar_t, Epilogue>, |
| outer_size, dim_size, inner_size, |
| grid, block, smem_size); |
| |
| cunn_SpatialSoftMaxBackward<scalar_t, accscalar_t, accscalar_t, Epilogue> |
| <<<grid, block, smem_size, stream>>>( |
| gI.data_ptr<scalar_t>(), output.data_ptr<accscalar_t>(), grad.data_ptr<accscalar_t>(), |
| outer_size, dim_size, inner_size |
| ); |
| C10_CUDA_KERNEL_LAUNCH_CHECK(); |
| } |
| }); |
| } |
| } |
| } // namespace |
| |
| TORCH_IMPL_FUNC(log_softmax_cuda_out) ( |
| const Tensor &input, |
| const int64_t dim, |
| const bool half_to_float, |
| const Tensor &output) { |
| host_softmax<LogSoftMaxForwardEpilogue,true>(input, dim, half_to_float, output); |
| } |
| |
| TORCH_IMPL_FUNC(log_softmax_backward_cuda_out) ( |
| const Tensor& grad, |
| const Tensor& output, |
| int64_t dim, |
| ScalarType input_dtype, |
| const Tensor& grad_input) { |
| bool half_to_float = grad.scalar_type() != input_dtype; |
| if (half_to_float) { |
| TORCH_CHECK( |
| (grad.scalar_type() == ScalarType::Float && |
| input_dtype == ScalarType::Half), |
| "expected input and grad types to match, or input to be at::Half and grad to be at::Float"); |
| } |
| host_softmax_backward<LogSoftMaxBackwardEpilogue, true>(grad, output, dim, half_to_float, grad_input); |
| } |
| |
| TORCH_IMPL_FUNC(softmax_cuda_out) ( |
| const Tensor &input, |
| const int64_t dim, |
| const bool half_to_float, |
| const Tensor &output) { |
| host_softmax<SoftMaxForwardEpilogue,false>(input, dim, half_to_float, output); |
| } |
| |
| TORCH_IMPL_FUNC(softmax_backward_cuda_out) |
| (const Tensor& grad, |
| const Tensor& output, |
| int64_t dim, |
| ScalarType input_dtype, |
| const Tensor& grad_input) { |
| bool half_to_float = grad.scalar_type() != input_dtype; |
| if (half_to_float) { |
| TORCH_CHECK( |
| (grad.scalar_type() == ScalarType::Float && |
| input_dtype == ScalarType::Half), |
| "expected input and grad types to match, or input to be at::Half and grad to be at::Float"); |
| } |
| Tensor tmp = grad * output; |
| host_softmax_backward<SoftMaxBackwardEpilogue, false>(tmp, output, dim, half_to_float, grad_input); |
| } |
| |
| Tensor masked_softmax_cuda(const Tensor& input_, const Tensor& mask_, const c10::optional<int64_t> dim_, const c10::optional<int64_t> mask_type_) { |
| Tensor output = at::empty_like(input_, input_.options()); |
| TORCH_CHECK(mask_.scalar_type() == ScalarType::Bool, "Mask should be a boolean tensor"); |
| |
| TORCH_CHECK(mask_type_.has_value(), "Mask Type should be defined"); |
| int64_t mask_type = mask_type_.value(); |
| TORCH_CHECK((mask_type == 0) || (mask_type == 1) || (mask_type == 2), "Mask Type should be 0 (src_mask), 1 (src_key_padding_mask), or 2 (default_mask)"); |
| |
| // If input is [B, H, T, T] and mask is [B, T], |
| // we have a special fast kernel. |
| // mask_type == 1 => mask_ is a src_key_padding_mask |
| bool is_BxT_mask = (mask_type == 1) && (input_.dim() == 4 && mask_.dim() == 2 && input_.size(0) == mask_.size(0) && input_.size(2) == mask_.size(1) && input_.size(3) == mask_.size(1)); |
| |
| // If input is [B, H, T, T] and mask is [T, T], |
| // expand the mask to [B, H, T, T] and treat it like a regular mask. |
| // TODO We should have a special fast kernel for the TxT mask as well. |
| // mask_type == 0 => mask_ is a src_mask |
| bool is_TxT_mask = (mask_type == 0) && input_.dim() == 4 && mask_.dim() == 2 && input_.size(3) == mask_.size(1) && input_.size(2) == mask_.size(0) && mask_.size(0) == mask_.size(1); |
| // If mask_type == 2, then mask_.sizes() must equal input_.sizes() |
| TORCH_CHECK(mask_.sizes() == input_.sizes() || is_BxT_mask || is_TxT_mask, "Mask shape should match input. mask: ", mask_.sizes(), " input: ", input_.sizes()); |
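| // For example (illustrative shapes): with input [B=2, H=4, T=8, T=8], a [2, 8] bool mask with |
| // mask_type == 1 takes the BxT fast path below, an [8, 8] mask with mask_type == 0 is |
| // expanded to the full input shape, and with mask_type == 2 the mask must already be |
| // [2, 4, 8, 8]. |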
| |
| auto input = input_.dim() == 0 ? input_.view(1) : input_; |
| auto mask = mask_.dim() == 0 ? mask_.view(1) : mask_; |
| if (is_TxT_mask) { |
| mask = mask.expand(input.sizes()); |
| } |
| int64_t dim = dim_.has_value() ? dim_.value() : input.dim() - 1; |
| |
| int softmax_elements = input.size(dim); |
| // Persistent softmax is only supported when all of the following conditions hold: |
| // 1) softmax_elements <= 1024 |
| // 2) softmax_elements * input.element_size() <= 4096 |
| // 3) mask.is_contiguous() |
| // 4) dim == input.dim() - 1 |
| // Otherwise, we fall back to the vanilla softmax (where we do not support transformer_mask, |
| // since converting the mask is expensive). |
| if (softmax_elements > 1024 || softmax_elements * input.element_size() > 4096 || !mask.is_contiguous() || dim < input.dim()-1) { |
| if (is_BxT_mask) { |
| mask = mask.view({mask_.size(0), 1, 1, mask_.size(1)}).expand(input.sizes()); |
| } |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| ScalarType::Half, |
| ScalarType::BFloat16, |
| input.scalar_type(), |
| "masked_softmax", |
| [&] { |
| output = at::softmax(input.masked_fill(mask, -std::numeric_limits<scalar_t>::infinity()), dim); |
| }); |
| return output; |
| } |
| int batch_count = input.numel() / softmax_elements; |
| int chunk_size = input.numel() / input.size(0); |
| if (is_BxT_mask) { |
| // The BxT fast path only supports an even number of heads in the transformer. |
| TORCH_CHECK(input.size(1) % 2 == 0, "Only support when num_heads is even in transformer"); |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| ScalarType::Half, |
| ScalarType::BFloat16, |
| input.scalar_type(), |
| "masked_softmax", |
| [&] { |
| using accscalar_t = acc_type<scalar_t, true>; |
| dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, false/* is_log_softmax */, true/* is_masked */>( |
| output.data_ptr<scalar_t>(), // dst |
| input.data_ptr<scalar_t>(), // src |
| softmax_elements, |
| softmax_elements, |
| batch_count, |
| mask.data_ptr<bool>(), |
| chunk_size, |
| true // is_transformer_mask |
| ); |
| }); |
| |
| } else { |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| ScalarType::Half, |
| ScalarType::BFloat16, |
| input.scalar_type(), |
| "masked_softmax", |
| [&] { |
| using accscalar_t = acc_type<scalar_t, true>; |
| dispatch_softmax_forward<scalar_t, scalar_t, accscalar_t, false/* is_log_softmax */, true/* is_masked */>( |
| output.data_ptr<scalar_t>(), // dst |
| input.data_ptr<scalar_t>(), // src |
| softmax_elements, |
| softmax_elements, |
| batch_count, |
| mask.data_ptr<bool>() |
| ); |
| }); |
| } |
| return output; |
| } |
| |
| Tensor masked_softmax_backward_cuda( |
| const Tensor& grad_, |
| const Tensor& output_, |
| const Tensor& mask_, |
| const c10::optional<int64_t> dim_) { |
| Tensor grad_input = at::empty_like(grad_, grad_.options()); |
| if (grad_.numel() == 0) { |
| return grad_input; |
| } |
| |
| auto grad = grad_.contiguous(); |
| auto output = output_.contiguous(); |
| auto mask = mask_.contiguous(); |
| int64_t dim = dim_.has_value() ? maybe_wrap_dim(dim_.value(), output.dim()) : output.dim() - 1; |
| |
| grad = grad.dim() == 0 ? grad.view(1) : grad; |
| mask = mask.dim() == 0 ? mask.view(1) : mask; |
| output = output.dim() == 0 ? output.view(1) : output; |
| |
| TORCH_CHECK(dim >=0 && dim < grad.dim(), "dim must be non-negative and less than input dimensions"); |
| TORCH_CHECK(grad.sizes() == mask.sizes(), "Mask shape should match grad shape"); |
| TORCH_CHECK(mask.scalar_type() == ScalarType::Bool, "Mask should be a boolean tensor"); |
| |
| int softmax_elements = output.size(dim); |
| int64_t batch_count = grad.numel() / softmax_elements; |
| |
| if (softmax_elements > 1024 || softmax_elements * grad.element_size() > 4096 || dim < grad.dim()-1) { |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| ScalarType::Half, |
| ScalarType::BFloat16, |
| grad_input.scalar_type(), |
| "masked_softmax_backward", |
| [&] { |
| grad_input = at::_softmax_backward_data( |
| grad, |
| output.masked_fill(mask, 0), |
| dim, |
| grad.scalar_type() |
| ); |
| }); |
| } else { |
| grad = grad * output; |
| AT_DISPATCH_FLOATING_TYPES_AND2( |
| ScalarType::Half, |
| ScalarType::BFloat16, |
| grad_input.scalar_type(), |
| "masked_softmax_backward", |
| [&] { |
| using accscalar_t = acc_type<scalar_t, true>; |
| dispatch_softmax_backward<scalar_t, scalar_t, accscalar_t, false, true /* masked_softmax */>( |
| grad_input.data_ptr<scalar_t>(), // gI_ptr |
| grad.data_ptr<scalar_t>(), // grad_ptr |
| output.data_ptr<scalar_t>(), // output_ptr |
| softmax_elements, // softmax_elements |
| softmax_elements, // softmax_elements_stride |
| batch_count, // batch_count |
| mask.data_ptr<bool>() // mask |
| ); |
| }); |
| } |
| return grad_input; |
| } |
| } // namespace native |
| } // namespace at |