#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/NumericLimits.cuh>
#include <ATen/cuda/detail/IndexUtils.cuh>
#include <ATen/cuda/detail/TensorInfo.cuh>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/NumericUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Utils.h>
#include <ATen/native/FractionalMaxPooling.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/fractional_max_pool3d_backward_native.h>
#include <ATen/ops/fractional_max_pool3d_native.h>
#endif
#include <algorithm>
#include <cfloat>
#include <cmath>

namespace at::native {
using namespace at::cuda::detail;
namespace {
template <typename scalar_t, typename accscalar_t>
__device__ inline int64_t get_intervals(
accscalar_t sample,
int64_t index,
int64_t inputSize,
int64_t outputSize,
int64_t poolSize) {
accscalar_t alpha = static_cast<accscalar_t>(inputSize - poolSize) /
static_cast<accscalar_t>(outputSize - 1);
if (index == outputSize - 1) {
return inputSize - poolSize;
} else {
    return static_cast<int64_t>((index + sample) * alpha) -
        static_cast<int64_t>(sample * alpha);
}
}
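
// Illustrative arithmetic for get_intervals (not part of the original file):
// with inputSize = 11, outputSize = 5 and poolSize = 3, alpha = (11 - 3) /
// (5 - 1) = 2. For sample = 0.25 the window starts for indices 0..3 are
// floor((i + 0.25) * 2) - floor(0.25 * 2) = 0, 2, 4, 6, and the last index is
// pinned to inputSize - poolSize = 8, so every poolSize-wide window stays
// inside the input.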
template <typename scalar_t>
__global__ void fractional_max_pool3d_out_frame(
PackedTensorAccessor64<scalar_t, 5> input,
PackedTensorAccessor64<scalar_t, 5> output,
PackedTensorAccessor64<int64_t, 5> indices,
PackedTensorAccessor64<scalar_t, 3> samples,
int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) {
using accscalar_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
// Output (t, h, w) point that this thread is responsible for
int64_t ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
int64_t plane = blockIdx.y;
int64_t batch = blockIdx.z;
// Each thread generates a specific output point
if (ourOutputPoint < output.size(2) * output.size(3) *
output.size(4)){
int64_t outputT = ourOutputPoint / (output.size(3) *
output.size(4));
int64_t outputH = (ourOutputPoint / output.size(4)) %
output.size(3);
int64_t outputW = ourOutputPoint % output.size(4);
int64_t poolT = get_intervals<scalar_t,accscalar_t>(
static_cast<accscalar_t>(samples[batch][plane][0]),
outputT, input.size(2), output.size(2), poolSizeT);
int64_t poolH = get_intervals<scalar_t, accscalar_t>(
static_cast<accscalar_t>(samples[batch][plane][1]),
outputH, input.size(3), output.size(3), poolSizeH);
int64_t poolW = get_intervals<scalar_t, accscalar_t>(
static_cast<accscalar_t>(samples[batch][plane][2]),
outputW, input.size(4), output.size(4), poolSizeW);
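    // Scan the poolSizeT x poolSizeH x poolSizeW window starting at
    // (poolT, poolH, poolW). The running argmax is stored as a flat offset
    // into the input's T*H*W plane and is seeded with the window's corner.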
scalar_t maxVal = at::numeric_limits<scalar_t>::lower_bound();
int64_t maxIndex = poolT * input.size(3) * input.size(4) + poolH * input.size(4) + poolW;
    for (int64_t t = poolT; t < poolT + poolSizeT; ++t) {
for (int64_t h = poolH; h < poolH + poolSizeH; ++h) {
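        // Both branches below compute the same maximum; they differ only in
        // how the innermost width loop is written (direct indexing versus an
        // offset loop reserved for pool widths between 2 and 7).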
if(poolSizeW < 2 || poolSizeW > 7) {
for (int64_t w = poolW; w < poolW + poolSizeW; ++w) {
scalar_t val = input[batch][plane][t][h][w];
// for consistency with THNN, favor the first max
if (val > maxVal || at::_isnan(val)) {
maxIndex = t * input.size(3) *
input.size(4) + h * input.size(4) + w;
maxVal = val;
}
}
} else {
for (int64_t i = 0; i < poolSizeW; ++i) {
int64_t w = i + poolW;
scalar_t val = input[batch][plane][t][h][w];
// for consistency with THNN, favor the first max
if (val > maxVal || at::_isnan(val)) {
maxIndex = t * input.size(3) * input.size(4) +
h * input.size(4) + w;
maxVal = val;
}
}
}
}
}
indices[batch][plane][outputT][outputH][outputW] = maxIndex;
output[batch][plane][outputT][outputH][outputW] = maxVal;
}
}

template <typename scalar_t>
__global__ void fractional_max_pool3d_backward_out_frame(
PackedTensorAccessor64<scalar_t, 5> gradInput,
PackedTensorAccessor64<scalar_t, 5> gradOutput,
PackedTensorAccessor64<int64_t, 5> indices) {
  // Output (t, h, w) point that this thread is responsible for
int64_t ourOutputPoint = threadIdx.x + blockIdx.x * blockDim.x;
int64_t plane = blockIdx.y;
int64_t batch = blockIdx.z;
// Each thread generates a specific output point
if (ourOutputPoint < gradOutput.size(2) *
gradOutput.size(3) * gradOutput.size(4)) {
int64_t outputW = ourOutputPoint % gradOutput.size(4);
int64_t outputH = (ourOutputPoint / gradOutput.size(4)) %
gradOutput.size(3);
int64_t outputT = ourOutputPoint / (gradOutput.size(3) *
gradOutput.size(4));
int64_t index = indices[batch][plane][outputT][outputH][outputW];
CUDA_KERNEL_ASSERT(index >= 0);
int64_t inputW = index % gradInput.size(4);
int64_t inputH = (index / gradInput.size(4)) %
gradInput.size(3);
int64_t inputT = index / (gradInput.size(3) *
gradInput.size(4));
CUDA_KERNEL_ASSERT(inputT < gradInput.size(2));
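    // Several output positions can select the same input element, so the
    // gradient is scattered with an atomic add; the accumulation order is not
    // fixed, which is why the callers below flag this backward pass as
    // nondeterministic.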
gpuAtomicAddNoReturn(
&gradInput[batch][plane][inputT][inputH][inputW],
gradOutput[batch][plane][outputT][outputH][outputW]
);
}
}

void fractional_max_pool3d_backward_out_cuda_template(
Tensor& gradInput,
const Tensor& gradOutput,
const Tensor& input,
IntArrayRef output_size,
const Tensor& indices) {
int64_t dimt = 1;
int64_t dimh = 2;
int64_t dimw = 3;
int64_t outputT = output_size[0];
int64_t outputH = output_size[1];
int64_t outputW = output_size[2];
int64_t ndims = input.ndimension();
if (ndims == 5) {
dimt++;
dimh++;
dimw++;
}
/* sizes */
int64_t inputT = input.size(dimt);
int64_t inputH = input.size(dimh);
int64_t inputW = input.size(dimw);
TORCH_CHECK(
outputT == gradOutput.size(dimt),
"fractional_max_pool3d_backward_out_cuda_template(): ",
"gradOutput time unexpected"
);
TORCH_CHECK(
outputH == gradOutput.size(dimh),
"fractional_max_pool3d_backward_out_cuda_template(): ",
"gradOutput height unexpected"
);
TORCH_CHECK(
outputW == gradOutput.size(dimw),
"fractional_max_pool3d_backward_out_cuda_template(): ",
"gradOutput width unexpected"
);
/* resize */
gradInput.resize_as_(input);
gradInput.zero_();
auto gradInput_ = gradInput;
auto gradOutput_ = gradOutput;
auto indices_ = indices;
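  // A non-batched (4-d) input is viewed as a batch of size one so that the
  // 5-d accessors and kernel below cover both layouts.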
if(ndims == 4) {
gradInput_ = gradInput_.reshape({1, gradInput.size(0), inputT,
inputH, inputW});
gradOutput_ = gradOutput_.reshape({1, gradOutput.size(0), outputT,
outputH, outputW});
indices_ = indices_.reshape({1, indices.size(0), outputT, outputH,
outputW});
}
if (gradInput.numel() == 0) {
return;
}
/* backprop */
// block is limited to 4 warps
// grid handles overflow per each plane
int64_t outputPlaneSize = gradOutput_.size(2) *
gradOutput_.size(3) * gradOutput_.size(4);
dim3 grid(
(outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
gradInput_.size(1),
gradInput_.size(0));
dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
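  // grid.x tiles the flattened (t, h, w) output positions in chunks of 128
  // threads (4 warps), grid.y walks the planes and grid.z the batches,
  // matching the blockIdx.y / blockIdx.z reads in the kernel.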
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
gradOutput.scalar_type(),
"fractional_max_pool3d_backward_out_frame",
[&] {
fractional_max_pool3d_backward_out_frame<scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
gradInput_.packed_accessor64<scalar_t, 5>(),
gradOutput_.packed_accessor64<scalar_t, 5>(),
indices_.packed_accessor64<int64_t, 5>()
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
);
}
}// namespace

TORCH_IMPL_FUNC(fractional_max_pool3d_out_cuda) (
const Tensor& input,
int64_t poolSizeT,
int64_t poolSizeH,
int64_t poolSizeW,
int64_t outputT,
int64_t outputH,
int64_t outputW,
const Tensor& randomSamples,
int64_t numBatch,
int64_t numPlanes,
int64_t inputT,
int64_t inputH,
int64_t inputW,
const Tensor& output,
const Tensor& indices) {
fractional_max_pool_check_shape</*ndim*/ 3>(input, randomSamples);
auto output_ = output;
auto indices_ = indices;
auto input_ = input;
int64_t ndims = input_.ndimension();
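  // Non-batched (4-d) inputs are promoted to a batch of size one so a single
  // 5-d kernel handles both cases.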
if(ndims == 4) {
output_ = output_.reshape({1, numPlanes, outputT, outputH, outputW});
indices_ = indices_.reshape({1, numPlanes, outputT, outputH, outputW});
input_ = input_.reshape({1, numPlanes, inputT, inputH, inputW});
}
if (output_.numel() == 0) {
return;
}
// block is limited to 4 warps
// grid handles overflow per each plane
int64_t outputPlaneSize = output_.size(2) *
output_.size(3) * output_.size(4);
dim3 grid(
(outputPlaneSize + 127) / 128, // ceil(outputPlaneSize / 128)
input_.size(1),
input_.size(0));
dim3 block(outputPlaneSize > 128 ? 128 : outputPlaneSize);
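  // The empty-output case returned above, so outputPlaneSize is at least one
  // and the block dimension is never zero.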
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
input.scalar_type(),
"fractional_max_pool3d_out_frame",
[&]{
fractional_max_pool3d_out_frame<scalar_t>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
input_.packed_accessor64<scalar_t, 5>(),
output_.packed_accessor64<scalar_t, 5>(),
indices_.packed_accessor64<int64_t, 5>(),
randomSamples.packed_accessor64<scalar_t, 3>(),
poolSizeT, poolSizeH, poolSizeW
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
);
}

Tensor& fractional_max_pool3d_backward_out_cuda(const at::Tensor& gradOutput_,
const at::Tensor& input,
IntArrayRef /*pool_size*/,
IntArrayRef output_size,
const at::Tensor& indices,
at::Tensor& gradInput) {
// See Note [Writing Nondeterministic Operations]
// Nondeterministic because of atomicAdd usage
globalContext().alertNotDeterministic("fractional_max_pool3d_backward_out_cuda");
fractional_max_pool3d_backward_out_cuda_template(
gradInput,
gradOutput_,
input,
output_size,
indices
);
return gradInput;
}

Tensor fractional_max_pool3d_backward_cuda(
const at::Tensor& gradOutput,
const at::Tensor& input,
IntArrayRef pool_size,
IntArrayRef output_size,
const at::Tensor& indices) {
// See Note [Writing Nondeterministic Operations]
// Nondeterministic because of atomicAdd usage
globalContext().alertNotDeterministic("fractional_max_pool3d_backward_cuda");
Tensor gradInput = at::empty({0}, input.options());
fractional_max_pool3d_backward_out_cuda_template(
gradInput,
gradOutput,
input,
output_size,
indices
);
return gradInput;
}
}// namespace at::native
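
// Usage sketch (illustrative only, not part of the original file): invoking
// the forward and backward ops through the ATen dispatcher on CUDA tensors.
// The shapes and pool parameters are arbitrary example values.
//
//   at::Tensor input = at::randn({2, 3, 8, 8, 8},
//                                at::TensorOptions().device(at::kCUDA));
//   // random_samples holds per-(batch, plane) offsets in [0, 1) for T, H, W.
//   at::Tensor samples = at::rand({2, 3, 3}, input.options());
//   auto [out, idx] = at::fractional_max_pool3d(
//       input, /*kernel_size=*/{2, 2, 2}, /*output_size=*/{4, 4, 4}, samples);
//   at::Tensor grad_in = at::fractional_max_pool3d_backward(
//       at::ones_like(out), input, {2, 2, 2}, {4, 4, 4}, idx);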