blob: a626c82b5e6ea4dc7f794cc746c38048305cec8b [file] [log] [blame]
/*
 * CuDNN ReLU extension. Simple function but contains the general structure of
 * most CuDNN extensions:
 * 1) Check arguments. torch::check* functions provide a standard way to
 *    validate input and provide pretty errors.
 * 2) Create descriptors. Most CuDNN functions require creating and setting a
 *    variety of descriptors.
 * 3) Apply the CuDNN function.
 * 4) Destroy your descriptors.
 * 5) Return something (optional).
 */
#include <torch/extension.h>
#include <ATen/cuda/Exceptions.h> // for CUDNN_CHECK
#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
#include <ATen/cudnn/Handle.h> // for getCudnnHandle

#include <memory> // for std::unique_ptr
// Single source of truth for the function's name. It is used both as the name
// the function is registered under in the Python module (see the m.def call in
// PYBIND11_MODULE) and as the caller name passed to the torch::check*
// functions, so their error messages match what the Python caller invoked.
const char* cudnn_relu_name = "cudnn_relu";
// Validates the tensors handed to cudnn_relu.
//
// Nothing is returned: every torch::check* call below throws on failure, and
// the error message is built from the TensorArg metadata (parameter name and
// position) together with cudnn_relu_name.
void cudnn_relu_check(
    const torch::Tensor& inputs,
    const torch::Tensor& outputs) {
  // Requirements that apply to each tensor on its own: contiguous memory,
  // float scalar type, CUDA backend.
  const auto check_one = [](const torch::TensorArg& arg) {
    torch::checkContiguous(cudnn_relu_name, arg);
    torch::checkScalarType(cudnn_relu_name, arg, torch::kFloat);
    torch::checkBackend(cudnn_relu_name, arg.tensor, torch::Backend::CUDA);
  };

  // A TensorArg records the name and position of a tensor as a parameter.
  const torch::TensorArg in_arg(inputs, "inputs", 0);
  const torch::TensorArg out_arg(outputs, "outputs", 1);

  check_one(in_arg);
  check_one(out_arg);

  // Requirement that relates the two tensors: they must have the same size.
  torch::checkSameSize(cudnn_relu_name, in_arg, out_arg);
}
// Applies ReLU to `inputs` through cuDNN and writes the result into `outputs`.
//
// inputs:  tensor to read; must satisfy cudnn_relu_check.
// outputs: tensor whose storage is overwritten with the result; must satisfy
//          cudnn_relu_check (in particular, same size as `inputs`).
//
// Throws if argument validation fails or if any checked cuDNN call reports an
// error. The activation descriptor is released on every exit path.
void cudnn_relu(const torch::Tensor& inputs, const torch::Tensor& outputs) {
  // Most CuDNN extensions will follow a similar pattern.
  // Step 1: Check inputs. This will throw an error if inputs are invalid, so no
  // need to check return codes here.
  cudnn_relu_check(inputs, outputs);

  // Nothing to compute for empty tensors (outputs has the same size, so it is
  // empty too). cuDNN tensor descriptors do not accept zero-sized dimensions,
  // so return before trying to build one.
  if (inputs.numel() == 0) {
    return;
  }

  // Step 2: Create descriptors
  cudnnHandle_t cuDnn = torch::native::getCudnnHandle();
  // Note: 4 is minimum dim for a TensorDescriptor. Input and output are same
  // size and type and contiguous, so one descriptor is sufficient.
  torch::native::TensorDescriptor input_tensor_desc(inputs, 4);

  // Note: Always check return value of cudnn functions using AT_CUDNN_CHECK
  cudnnActivationDescriptor_t raw_activation_desc = nullptr;
  AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&raw_activation_desc));
  // AT_CUDNN_CHECK throws on failure, so hand the descriptor to an owner right
  // away; otherwise a failing call below would leak it.
  std::unique_ptr<
      cudnnActivationStruct,
      decltype(&cudnnDestroyActivationDescriptor)>
      activation_desc(raw_activation_desc, &cudnnDestroyActivationDescriptor);

  AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
      activation_desc.get(),
      /*mode=*/CUDNN_ACTIVATION_RELU,
      /*reluNanOpt=*/CUDNN_PROPAGATE_NAN,
      /*coef=*/1.));

  // Step 3: Apply CuDNN function
  // alpha/beta are cuDNN's blending factors: with alpha = 1 and beta = 0 the
  // result replaces whatever was previously stored in outputs.
  float alpha = 1.;
  float beta = 0.;
  AT_CUDNN_CHECK(cudnnActivationForward(
      cuDnn,
      activation_desc.get(),
      &alpha,
      input_tensor_desc.desc(),
      inputs.data_ptr(),
      &beta,
      input_tensor_desc.desc(), // output descriptor same as input
      outputs.data_ptr()));

  // Step 4: Destroy descriptors
  // On the success path, release ownership and destroy explicitly so that a
  // failure to destroy is still reported to the caller.
  AT_CUDNN_CHECK(cudnnDestroyActivationDescriptor(activation_desc.release()));

  // Step 5: Return something (optional)
}
// Create the pybind11 module. TORCH_EXTENSION_NAME is not defined in this
// file (presumably it is supplied by the extension build -- confirm).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Register cudnn_relu under cudnn_relu_name with the docstring "CuDNN ReLU".
// Use the same name as the check functions so error messages make sense
m.def(cudnn_relu_name, &cudnn_relu, "CuDNN ReLU");
}