| #define TORCH_ASSERT_ONLY_METHOD_OPERATORS |
| #include <ATen/core/Tensor.h> |
| #include <ATen/cuda/CUDAContext.h> |
| #include <ATen/MemoryOverlap.h> |
| #include <ATen/cuda/detail/IndexUtils.cuh> |
| #include <ATen/native/Resize.h> |
| #include <ATen/native/TypeProperties.h> |
| #include <ATen/native/TensorShape.h> |
| #include <ATen/Dispatch.h> |
| #include <c10/core/MemoryFormat.h> |
| #include <c10/util/Optional.h> |
| |
| #ifndef AT_PER_OPERATOR_HEADERS |
| #include <ATen/Functions.h> |
| #include <ATen/NativeFunctions.h> |
| #else |
| #include <ATen/ops/cat_native.h> |
| #include <ATen/ops/copy_native.h> |
| #include <ATen/ops/empty.h> |
| #include <ATen/ops/empty_like.h> |
| #include <ATen/ops/narrow.h> |
| #endif |
| |
| namespace at { |
| namespace native { |
| |
// Maximum number of inputs described by a single CatArrayBatchedCopy launch
// when all inputs are contiguous (the non-contiguous path uses half of it).
constexpr int CAT_ARRAY_BATCH_SIZE{128};
// Largest tensor rank handled by the batched kernel; higher ranks fall back to
// narrow() + copy_().
constexpr int CAT_ARRAY_MAX_INPUT_DIMS{4};
| |
| namespace { |
| |
// Computes the launch grid for CatArrayBatchedCopy and stores it in `grid`.
// grid.x is twice the multiprocessor count of the current device; the blocks
// along x all work on the same input tensor. grid.y has one entry per input
// tensor in the batch (`nTensors`). Always returns true.
inline bool getCatGrid(ptrdiff_t nTensors, dim3& grid) {
  const auto* deviceProps = at::cuda::getCurrentDeviceProperties();
  const long long blocksPerTensor = 2LL * deviceProps->multiProcessorCount;
  const long long tensorsInBatch = static_cast<long long>(nTensors);
  grid = dim3(blocksPerTensor, tensorsInBatch);
  return true;
}
| |
// Converts an element index into a memory offset, in the same spirit as other
// IndexToOffset helpers used for copying along a given dimension.
//
// The index is decomposed one dimension at a time, from the innermost
// (Dims - 1) to the outermost (0). Each digit is multiplied by the matching
// stride. For the concatenation dimension the extent `dimSize` is used instead
// of tensorSize[concatDim].
template <typename IndexType, int Dims>
struct CatArrIndexToOffset {
  // `linearIndex` is the element's offset within the input tensor. For a
  // contiguous input this is its linear index. For a channels-last input it is
  // the linear index of the permuted (contiguous) view of that tensor.
  static inline __device__ IndexType compute(
      const IndexType tensorSize[Dims],
      const IndexType tensorStride[Dims],
      const IndexType dimSize,
      const unsigned int concatDim,
      IndexType linearIndex) {
    IndexType accumulated = 0;
    IndexType remaining = linearIndex;

#pragma unroll
    for (int d = Dims - 1; d > 0; --d) {
      const IndexType extent = (d == concatDim) ? dimSize : tensorSize[d];
      const IndexType quotient = remaining / extent;
      // remaining - quotient * extent is the coordinate along dimension d.
      accumulated += (remaining - quotient * extent) * tensorStride[d];
      remaining = quotient;
    }

    // Whatever is left is the coordinate along the outermost dimension.
    return accumulated + remaining * tensorStride[0];
  }
};
| |
// Fixed-capacity size and stride arrays describing one tensor. The host code
// fills only the first nDims entries of each array.
template <typename Index, unsigned int Capacity>
struct TensorSizeStride {
  Index tensorSize[Capacity];    // extent of each dimension
  Index tensorStride[Capacity];  // stride of each dimension
};
| |
/**
 * Kernel used to concatenate gridDim.y tensors into an output tensor. Uses a
 * grid-stride loop based on blockIdx.x and threadIdx.x for each input to
 * copy each element from each input tensor into the output.
 *
 * output: base pointer to the storage associated with the output tensor
 * inputs: metadata for each input to concatenate in this launch, passed by
 *         value as a kernel argument
 * os: the size/stride vectors for the output tensor
 * concatDim: dimension along which we are concatenating
 * dimStride: the stride of the output tensor at the concatDim
 *
 * Inputs whose isContiguous flag is set are read at their linear index.
 * Otherwise their per-input sizes/strides in inputs.tensorStride are used to
 * locate each element.
 */
| |
| |
// Per-launch description of up to `NumInputs` input tensors. It is handed to
// CatArrayBatchedCopy directly as a kernel argument instead of being staged
// through pinned memory. Entry k of each array describes the k-th input of the
// batch.
//
// `NumStrideEntries` is the number of per-input TensorSizeStride records.
// Launches whose inputs are all contiguous never read them, so they pass 1 as
// a placeholder that keeps the array well-formed.
template <typename T, typename IndexType, int NumInputs, int NumStrideEntries>
struct CatArrInputTensorMetadata {
  T* input[NumInputs];              // data pointer of each input
  IndexType offset[NumInputs];      // start of each input along the concat dim of the output
  IndexType dimSize[NumInputs];     // extent of each input along the concat dim
  IndexType nElements[NumInputs];   // number of elements in each input
  bool isContiguous[NumInputs];     // true => index the input by its linear index
  TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> tensorStride[NumStrideEntries];
};
| |
// Copies one batch of inputs into `output`. Block row blockIdx.y handles input
// blockIdx.y of `inputs`. Threads walk that input's elements in a grid-stride
// loop of step gridDim.x * blockDim.x.
//
// Each element's destination is `inputs.offset * dimStride` plus the offset
// obtained by decomposing its index against the output's sizes/strides (`os`).
// The extent along `concatDim` is replaced by the input's own dimSize.
template <typename T, typename IndexType, int Dims, int batch_size, int stride_size>
__global__ void CatArrayBatchedCopy(
    T* output,
    CatArrInputTensorMetadata<T, IndexType, batch_size, stride_size> inputs,
    TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
    const int concatDim,
    IndexType dimStride) {
  const unsigned int which = blockIdx.y;
  const IndexType count = inputs.nElements[which];
  IndexType linear = blockIdx.x * blockDim.x + threadIdx.x;

  if (linear >= count) {
    return;
  }

  // Only batches built with stride_size > 1 carry one record per input.
  const TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> ins =
      inputs.tensorStride[stride_size > 1 ? which : 0];
  const bool contiguousInput = inputs.isContiguous[which];
  const T* src = inputs.input[which];
  const IndexType sliceDimSize = inputs.dimSize[which];
  const IndexType base = inputs.offset[which] * dimStride;
  const IndexType step = gridDim.x * blockDim.x;

  for (; linear < count; linear += step) {
    const IndexType dst = base + CatArrIndexToOffset<IndexType, Dims>::compute(
        os.tensorSize, os.tensorStride, sliceDimSize, concatDim, linear);
    const IndexType srcIdx = contiguousInput
        ? linear
        : CatArrIndexToOffset<IndexType, Dims>::compute(
              ins.tensorSize, ins.tensorStride, sliceDimSize, concatDim, linear);
    output[dst] = src[srcIdx];
  }
}
| |
// Concatenates `inputs` into `out` along `dimension`. CatArrayBatchedCopy is
// launched once per batch of at most `batch_size` inputs.
//
// Template parameters:
//   scalar_t    - element type; `out` and every input are read as scalar_t.
//   batch_size  - maximum number of inputs described per kernel launch.
//   stride_size - 1 when every input is contiguous (no per-input size/stride
//                 records are filled); otherwise one TensorSizeStride record
//                 per batched input is filled in.
// Parameters:
//   out           - destination tensor; its sizes/strides are narrowed to
//                   unsigned int, so the caller must ensure 32-bit indexing.
//   inputs        - tensors to concatenate, in order.
//   dimension     - concatenation dimension in the tensors' logical dim order.
//   nDims         - number of dimensions of `out` (1..CAT_ARRAY_MAX_INPUT_DIMS).
//   memory_format - Contiguous, ChannelsLast or ChannelsLast3d; anything else
//                   fails TORCH_CHECK.
template <typename scalar_t, int batch_size, int stride_size>
void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, int64_t dimension,
                  int nDims, c10::MemoryFormat memory_format) {
  // Raw pointer to the output storage, plus the kernel metadata structs.
  scalar_t *data = out.data_ptr<scalar_t>();
  CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
  TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;

  // Size/stride arrays for the output tensor.
  if (memory_format == c10::MemoryFormat::Contiguous) {
    for (int i = 0; i < nDims; ++i) {
      outputParam.tensorSize[i] = out.size(i);
      outputParam.tensorStride[i] = out.stride(i);
    }
  } else if (memory_format == c10::MemoryFormat::ChannelsLast ||
             memory_format == c10::MemoryFormat::ChannelsLast3d) {
    // Permute the dims from NCHW to NHWC order so that the channels-last
    // inputs can be walked as if they were contiguous.
    outputParam.tensorSize[0] = out.size(0);
    outputParam.tensorStride[0] = out.stride(0);
    for (int i = 1; i < nDims - 1; ++i) {
      outputParam.tensorSize[i] = out.size(i + 1);
      outputParam.tensorStride[i] = out.stride(i + 1);
    }
    outputParam.tensorSize[nDims - 1] = out.size(1);
    outputParam.tensorStride[nDims - 1] = out.stride(1);
  } else {
    TORCH_CHECK(false, "unsupported memory format");
  }

  // Translate the concat dimension into the dim order used by outputParam.
  // This is done exactly once and into a separate variable. `dimension` must
  // keep the logical order because every batch reads input sizes with it;
  // remapping it in place would corrupt all batches after the first.
  int64_t kernelDim = dimension;
  if (memory_format != c10::MemoryFormat::Contiguous) {
    switch (dimension) {
      case 0:
        break;
      case 1:
        kernelDim = nDims - 1;  // channels moved to the last position
        break;
      default:
        kernelDim = dimension - 1;  // remaining dims shifted left by one
    }
  }
  const unsigned int dimStride = outputParam.tensorStride[kernelDim];

  at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  // Block size borrowed from THCApply, on which the kernel's indexing is based.
  const dim3 applyBlock = dim3(32 * 16);

  int64_t offset = 0;
  const int64_t numInputs = static_cast<int64_t>(inputs.size());
  for (int64_t i = 0; i < numInputs; i += batch_size) {
    int batchCounter = 0;
    for (; batchCounter < batch_size && i + batchCounter < numInputs; ++batchCounter) {
      const Tensor& input = inputs[i + batchCounter].get();
      int64_t dimSize = 0;
      // Legacy case: a 1-D empty tensor may be concatenated with a
      // higher-dimensional tensor, so only non-empty inputs are queried.
      if (input.numel() > 0) {
        dimSize = input.size(dimension);
      }
      catMetaData.input[batchCounter] = input.data_ptr<scalar_t>();
      catMetaData.offset[batchCounter] = offset;
      catMetaData.dimSize[batchCounter] = dimSize;
      catMetaData.nElements[batchCounter] = input.numel();
      if (stride_size > 1) {
        const auto sizes = input.sizes();
        const auto strides = input.strides();
        const int64_t inputDims = input.dim();
        for (int j = 0; j < nDims; ++j) {
          // The legacy 1-D empty input has fewer than nDims dims. Do not read
          // past the end of its sizes/strides. The padding is never used by
          // the kernel because that input's nElements is 0.
          const bool hasDim = j < inputDims;
          catMetaData.tensorStride[batchCounter].tensorSize[j] = hasDim ? sizes[j] : 1;
          catMetaData.tensorStride[batchCounter].tensorStride[j] = hasDim ? strides[j] : 0;
        }
        catMetaData.isContiguous[batchCounter] = false;
      } else {
        catMetaData.isContiguous[batchCounter] = true;
      }
      offset += dimSize;
    }

    // Grid: x covers twice the SM count, y is the number of tensors in this
    // batch. Concatenating two tensors therefore fills the whole grid, while
    // small inputs do not make many threads load metadata needlessly.
    dim3 catGrid;
    getCatGrid(batchCounter, catGrid);

    // Template instantiations for nDims = 1, 2, 3, 4.
#define HANDLE_CASE(DIMS)                                                      \
    CatArrayBatchedCopy<scalar_t, unsigned int, DIMS, batch_size, stride_size> \
        <<<catGrid, applyBlock, 0, stream.stream()>>>(                         \
            data, catMetaData, outputParam, kernelDim, dimStride);             \
    C10_CUDA_KERNEL_LAUNCH_CHECK();
    switch (nDims) {
      case 1:
        HANDLE_CASE(1);
        break;
      case 2:
        HANDLE_CASE(2);
        break;
      case 3:
        HANDLE_CASE(3);
        break;
      case 4:
        HANDLE_CASE(4);
        break;
      default:
        TORCH_INTERNAL_ASSERT(false, "parallel_cat: unsupported number of dims ", nDims);
    }
#undef HANDLE_CASE
  }
}
| } // namespace |
| |
// CUDA implementation of cat.out. Copies every tensor in `tensors` into its
// slice of `result` along `dim`.
//
// `valid` indexes an input whose rank is used as nDims for the batched kernel.
// The boolean flags and `memory_format` are presumably precomputed by the
// structured meta function (TODO confirm). They choose among two batched-kernel
// paths and a narrow() + copy_() fallback. `all_same_sizes_and_stride` is not
// used here.
TORCH_IMPL_FUNC(cat_out_cuda)
(const ITensorListRef& tensors,
 int64_t dim,
 int64_t valid,
 bool all_contiguous,
 bool all_same_dtype,
 bool all_same_sizes_and_stride,
 MemoryFormat memory_format,
 const Tensor& result) {
  if (result.numel() == 0) {
    return;
  }

  auto materialized = tensors.materialize();

  const bool all32BitIndexable = std::all_of(materialized.begin(), materialized.end(),
    [] (const Tensor& t) {
      return at::cuda::detail::canUse32BitIndexMath(t);
    });

  int nDims = materialized[valid].get().dim();

  // Both batched-kernel paths require all of:
  //   1. more than one input tensor,
  //   2. an output with at most CAT_ARRAY_MAX_INPUT_DIMS dims,
  //   3. an output that can use 32-bit index math,
  //   4. inputs that can all use 32-bit index math,
  //   5. all_same_dtype (both paths read every input as result.scalar_type()).
  const bool canUseBatchedKernel = materialized.size() > 1 &&
      result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
      at::cuda::detail::canUse32BitIndexMath(result) &&
      all32BitIndexable &&
      all_same_dtype;

  if (canUseBatchedKernel && all_contiguous) {
    // All inputs contiguous: no per-input size/stride metadata has to travel
    // in the kernel arguments, so a full CAT_ARRAY_BATCH_SIZE inputs fit in
    // one launch. Channels-last memory formats are handled here as well.
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
        kComplexHalf, kHalf, kBool, kBFloat16,
        result.scalar_type(), "cat_cuda", [&]() {
      parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
    });
  } else if (canUseBatchedKernel &&
             nDims <= CAT_ARRAY_MAX_INPUT_DIMS &&
             memory_format == c10::MemoryFormat::Contiguous) {
    // Some inputs non-contiguous (contiguous memory format only): every input
    // also carries its sizes/strides in the kernel arguments, so the batch is
    // halved to stay within the kernel-argument size limit.
    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
        kComplexHalf, kHalf, kBool, kBFloat16,
        result.scalar_type(), "cat_cuda", [&]() {
      parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
    });
  } else {
    // Generic fallback: copy each non-skipped input into its slice of result.
    int64_t offset = 0;
    for (const Tensor& t : materialized) {
      if (cat_should_skip_tensor(t)) continue;
      int64_t dimSize = t.size(dim);
      Tensor nt = at::narrow(result, dim, offset, dimSize);
      copy_(nt, t);
      offset += dimSize;
    }
  }
}
| |
| } // namespace native |
| } // namespace at |