#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/ForeachUtils.h>
#include <c10/util/Exception.h>
#include <ATen/native/cuda/ForeachFunctors.cuh>
#include <ATen/native/cuda/MultiTensorApply.cuh>

namespace at::native {

namespace {

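// Elementwise SGD update applied to one register tile of kILP values.
// r_args[0] holds the parameters, r_args[1] the gradients and, when
// depth == 3, r_args[2] the momentum buffer.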
template <typename scalar_t, int depth>
C10_DEVICE __forceinline__ void sgd_math(
    scalar_t r_args[depth][kILP],
    const double weight_decay,
    const double momentum,
    const float* lr_ptr,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const float* grad_scale_ptr) {
  using opmath_t = at::opmath_type<scalar_t>;
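  // A device-side lr (lr_ptr) takes precedence over the host-side scalar lr.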
  const double double_lr = lr_ptr != nullptr ? *lr_ptr : lr;
#pragma unroll
  for (int ii = 0; ii < kILP; ii++) {
    auto p = static_cast<opmath_t>(r_args[0][ii]);
    auto g = static_cast<opmath_t>(r_args[1][ii]);
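    // Unscale the gradient and persist the unscaled value back to the tile.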
    if (grad_scale_ptr) {
      g /= static_cast<double>(*grad_scale_ptr);
      r_args[1][ii] = g;
    }
    if (maximize) {
      g *= -1.0;
    }
    if (weight_decay != 0) {
      g += weight_decay * p;
    }
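    // Momentum update: on the first step the buffer is seeded with the
    // gradient; afterwards buf = momentum * buf + (1 - dampening) * g.
    // With Nesterov, the step then uses g + momentum * buf.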
    if (depth > 2) {
      const auto momentum_buffer = is_first_step
          ? g
          : (momentum * static_cast<opmath_t>(r_args[2][ii]) +
             (1 - dampening) * g);
      r_args[2][ii] = momentum_buffer;

      if (nesterov) {
        g = g + momentum * momentum_buffer;
      } else {
        g = momentum_buffer;
      }
    }
    p -= double_lr * g;
    r_args[0][ii] = p;
  }
}

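// multi_tensor_apply functor: loads one chunk of each tensor list into
// registers, runs sgd_math on it, and writes the results back.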
template <typename scalar_t, int depth>
struct FusedSgdMathFunctor {
  static_assert(
      depth == 2 || depth == 3,
      "depth of 2 for SGD w/ momentum == 0, 3 for SGD w/ momentum != 0");
  C10_DEVICE __forceinline__ void operator()(
      const int chunk_size,
      TensorListMetadata<depth>& tl,
      const double weight_decay,
      const double momentum,
      const float* lr_ptr,
      const double lr,
      const double dampening,
      const bool nesterov,
      const bool maximize,
      const bool is_first_step,
      const float* grad_scale_ptr,
      const float* found_inf_ptr) {
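    // Skip the update entirely if the gradient scaler found an inf/nan.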
    if (found_inf_ptr && *found_inf_ptr == 1) {
      return;
    }
    const auto tensor_loc = tl.block_to_tensor[blockIdx.x];
    const auto chunk_idx = tl.block_to_chunk[blockIdx.x];

    scalar_t* args[depth];
    scalar_t r_args[depth][kILP];
    const auto all_aligned{
        init_args<depth>(args, tl, chunk_idx, chunk_size, tensor_loc)};
    const auto n = tl.numel_for_tensor[tensor_loc] - chunk_idx * chunk_size;

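    // The vectorized (kILP-wide) load/store path requires the chunk size and
    // element count to be multiples of kILP and all pointers to be aligned;
    // it is not used on ROCm. Otherwise fall back to scalar loads/stores.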
#ifndef USE_ROCM
    const auto use_faster_load_store =
        (n % kILP == 0) && (chunk_size % kILP == 0) && all_aligned;
#else
    const auto use_faster_load_store{false};
#endif
    if (use_faster_load_store) {
      for (auto i_start = threadIdx.x;
           i_start * kILP < n && i_start * kILP < chunk_size;
           i_start += blockDim.x) {
#pragma unroll
        for (auto i = 0; i < depth; i++) {
          load_store(r_args[i], args[i], 0, i_start);
        }
        sgd_math<scalar_t, depth>(
            r_args,
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr);
        load_store(args[0], r_args[0], i_start, 0);
        if (grad_scale_ptr) {
          load_store(args[1], r_args[1], i_start, 0);
        }
        if (depth > 2) {
          load_store(args[2], r_args[2], i_start, 0);
        }
      }
    } else {
      for (auto i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * kILP) {
        load_args<depth>(r_args, args, i_start, chunk_size, n);
        sgd_math<scalar_t, depth>(
            r_args,
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr);
        store_args(args[0], r_args[0], i_start, chunk_size, n);
        if (grad_scale_ptr) {
          store_args(args[1], r_args[1], i_start, chunk_size, n);
        }
        if (depth > 2) {
          store_args(args[2], r_args[2], i_start, chunk_size, n);
        }
      }
    }
  }
};

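// Fused SGD with momentum, scalar (double) lr.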
void _fused_sgd_with_momentum_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  TORCH_CHECK_GT(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions(
      {params, grads, momentum_buffer_list}));
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
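  // lr is a host-side double here, so no device lr pointer is needed.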
  float* lr_ptr = nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), momentum_buffer_list.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_with_momentum_kernel_cuda",
      [&]() {
        multi_tensor_apply<3>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 3>(),
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

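// Fused SGD with momentum, tensor lr. A CPU lr tensor is read on the host and
// forwarded to the scalar-lr overload; a GPU lr tensor is passed to the kernel
// by pointer so that its value is read on the device.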
void _fused_sgd_with_momentum_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const at::Tensor& lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  if (lr.is_cpu()) {
    _fused_sgd_with_momentum_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr.item<double>(),
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK_GT(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions(
      {params, grads, momentum_buffer_list}));
  if (grad_scale != c10::nullopt) {
    TORCH_CHECK(
        grad_scale->device() == params[0].device(),
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf != c10::nullopt) {
    TORCH_CHECK(
        found_inf->device() == params[0].device(),
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == params[0].device(),
      "lr must be on the same GPU device as the params");
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{
      params.vec(), grads.vec(), momentum_buffer_list.vec()};
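  // The scalar lr slot is a 1.0 placeholder; sgd_math reads the learning rate
  // through lr.data_ptr<float>() instead.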
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_with_momentum_kernel_cuda",
      [&]() {
        multi_tensor_apply<3>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 3>(),
            weight_decay,
            momentum,
            lr.data_ptr<float>(),
            1.0,
            dampening,
            nesterov,
            maximize,
            is_first_step,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

} // namespace

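// Public entry point, scalar (double) lr. Routes to the momentum variant when
// momentum buffers are provided, otherwise runs plain SGD (depth == 2).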
void _fused_sgd_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  if (!momentum_buffer_list.empty()) {
    _fused_sgd_with_momentum_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr,
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK_EQ(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads}));
  if (is_first_step) {
    TORCH_WARN_ONCE(
        "`is_first_step` argument has no effect when `momentum_buffer_list` is empty");
  }
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  float* lr_ptr = nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{params.vec(), grads.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_kernel_cuda",
      [&]() {
        multi_tensor_apply<2>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 2>(),
            weight_decay,
            momentum,
            lr_ptr,
            lr,
            dampening,
            nesterov,
            maximize,
            /* is_first_step */ false,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

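// Public entry point, tensor lr. A CPU lr is unwrapped to a double; a GPU lr
// is forwarded to the kernel by pointer.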
void _fused_sgd_kernel_cuda_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const at::Tensor& lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const c10::optional<at::Tensor>& grad_scale,
    const c10::optional<at::Tensor>& found_inf) {
  if (!momentum_buffer_list.empty()) {
    _fused_sgd_with_momentum_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr,
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  if (lr.is_cpu()) {
    _fused_sgd_kernel_cuda_(
        params,
        grads,
        momentum_buffer_list,
        weight_decay,
        momentum,
        lr.item<double>(),
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale,
        found_inf);
    return;
  }
  TORCH_CHECK_EQ(momentum, 0);
  TORCH_CHECK(at::native::check_fast_path_restrictions({params, grads}));
  if (is_first_step) {
    TORCH_WARN_ONCE(
        "`is_first_step` argument has no effect when `momentum_buffer_list` is empty");
  }
  if (grad_scale.has_value()) {
    TORCH_CHECK(
        grad_scale->device() == params[0].device(),
        "grad_scale must be on the same GPU device as the params");
  }
  if (found_inf.has_value()) {
    TORCH_CHECK(
        found_inf->device() == params[0].device(),
        "found_inf must be on the same GPU device as the params");
  }
  TORCH_CHECK(
      lr.device() == params[0].device(),
      "lr must be on the same GPU device as the params");
  float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;

  std::vector<std::vector<at::Tensor>> tensor_lists{params.vec(), grads.vec()};
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf,
      kBFloat16,
      params[0].scalar_type(),
      "fused_sgd_kernel_cuda",
      [&]() {
        multi_tensor_apply<2>(
            tensor_lists,
            FusedSgdMathFunctor<scalar_t, 2>(),
            weight_decay,
            momentum,
            lr.data_ptr<float>(),
            1.0,
            dampening,
            nesterov,
            maximize,
            /* is_first_step */ false,
            grad_scale_ptr,
            found_inf_ptr);
      });
}

} // namespace at::native