#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/TypeDefault.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cuda/fused_adam_amsgrad_impl.cuh>
#include <ATen/native/cuda/fused_adam_impl.cuh>
namespace at::native {
// note(crcrpar): To stay within the CI rule of 20 minutes of compile time per
// file, the instantiations are defensively split into _impl files. This is
// only an issue for CUDA 11.3, with which compilation took about 20 minutes on
// my workstation and about 28 minutes in CI. As a data point, it took about
// 20 seconds with CUDA 11.7 installed in my environment. See
// https://github.com/pytorch/pytorch/pull/81705 for details.
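//
// For reference, the kernels dispatched below perform the standard Adam
// update; the exact handling of weight_decay, maximize, and gradient
// unscaling lives in the _impl files. Roughly, for each parameter p with
// gradient g at step t:
//
//   exp_avg    = beta1 * exp_avg    + (1 - beta1) * g
//   exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
//   p         -= lr * (exp_avg / (1 - beta1^t)) /
//                (sqrt(exp_avg_sq / (1 - beta2^t)) + eps)
//
// With amsgrad, max_exp_avg_sq = max(max_exp_avg_sq, exp_avg_sq) replaces
// exp_avg_sq in the denominator.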
void _fused_adam_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const double lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_amsgrad_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}
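
// A minimal, hypothetical call-site sketch for the overload above (not part
// of this file); the tensor lists are assumed to already satisfy the fast
// path restrictions checked by check_fast_path_restrictions():
//
//   at::native::_fused_adam_kernel_cuda_(
//       params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps,
//       /*lr=*/1e-3, /*beta1=*/0.9, /*beta2=*/0.999, /*weight_decay=*/0.0,
//       /*eps=*/1e-8, /*amsgrad=*/false, /*maximize=*/false,
//       /*grad_scale=*/c10::nullopt, /*found_inf=*/c10::nullopt);
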
// The following overload simply has a Tensor lr
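// If lr is a CPU scalar tensor, it is unwrapped with lr.item<double>() and
// forwarded to the double overload above; otherwise lr is checked to live on
// the same device as the params and passed through to the _impl kernels,
// which avoids a host sync when lr lives on the GPU.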
void _fused_adam_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList exp_avgs,
    at::TensorList exp_avg_sqs,
    at::TensorList max_exp_avg_sqs,
    at::TensorList state_steps,
    const at::Tensor& lr,
    const double beta1,
    const double beta2,
    const double weight_decay,
    const double eps,
    const bool amsgrad,
    const bool maximize,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  if (lr.is_cpu()) {
    _fused_adam_kernel_cuda_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr.item<double>(),
        beta1,
        beta2,
        weight_decay,
        eps,
        amsgrad,
        maximize,
        grad_scale,
        found_inf);
    return;
  }

  // Manually check devices since we specify no device check in
  // native_functions.yaml
  Device param_device = params[0].device();
  if (grad_scale != c10::nullopt) {
    TORCH_CHECK(
        grad_scale->device() == param_device,
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf != c10::nullopt) {
    TORCH_CHECK(
        found_inf->device() == param_device,
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == param_device,
      "lr must be on the same GPU device as the params");

  if (amsgrad) {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs}),
        "params, grads, exp_avgs, exp_avg_sqs, and max_exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_amsgrad_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  } else {
    TORCH_CHECK(
        at::native::check_fast_path_restrictions(
            {params, grads, exp_avgs, exp_avg_sqs}),
        "params, grads, exp_avgs, and exp_avg_sqs must have same dtype, device, and layout");
    _fused_adam_cuda_impl_(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        state_steps,
        lr,
        beta1,
        beta2,
        weight_decay,
        eps,
        maximize,
        grad_scale,
        found_inf);
  }
}
} // namespace at::native