Revert "[jiterator] Reduce templating in jitted_gpu_kernel_impl (#80103)"
This reverts commit df665b1a9d7506ee147524d02fc4639a3b27c56f.
Reverted https://github.com/pytorch/pytorch/pull/80103 on behalf of https://github.com/mehtanirav due to internal breakage
diff --git a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
index 81ce416..a235d00 100644
--- a/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
+++ b/aten/src/ATen/native/cuda/CUDAJitLoops.cuh
@@ -17,7 +17,6 @@
#include <c10/macros/Macros.h>
#include <c10/core/ScalarType.h>
-#include <c10/util/SmallBuffer.h>
#include <type_traits>
#include <tuple>
@@ -26,6 +25,8 @@
namespace at {
namespace native {
+namespace {
+
template <typename Tuple, std::size_t... I>
constexpr auto tuple_to_array_helper(Tuple& t, std::index_sequence<I...> seq) {
constexpr auto size = seq.size();
@@ -44,90 +45,111 @@
return tuple_to_array_helper(extra_args, std::make_index_sequence<tuple_size>{});
}
-struct JittedVecKernelCache {
- // Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
- at::cuda::jit::NvrtcFunction vec1;
- at::cuda::jit::NvrtcFunction vec2;
- at::cuda::jit::NvrtcFunction vec4;
-};
-
-struct JittedKernelVariantCache {
- JittedVecKernelCache vec;
- at::cuda::jit::NvrtcFunction noncontiguous;
- at::cuda::jit::NvrtcFunction dynamic_contiguous;
- at::cuda::jit::NvrtcFunction dynamic_noncontiguous;
-};
-
+// Helper function to return a vector<string>
+// corresponding to the type of the arguments in parameter pack.
template <typename... Args>
-c10::SmallBuffer<void*, 64> pack_kernel_args(c10::ArrayRef<void*> extra_args, Args... args) {
- std::array<void*, sizeof...(Args)> args_array({static_cast<void*>(args)...});
- c10::SmallBuffer<void*, 64> ret(args_array.size() + extra_args.size());
- std::copy_n(args_array.data(), args_array.size(), ret.data());
- std::copy_n(extra_args.data(), extra_args.size(), ret.data() + args_array.size());
- return ret;
+c10::SmallVector<std::string> get_extra_args_typenames() {
+ return {at::cuda::jit::typeName<Args>()...};
}
-template<typename array_t,
+} // namespace
+
+template<char const *name,
+ typename result_type,
+ typename f_inputs_type,
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
+ typename array_t,
typename inp_calc_t,
typename out_calc_t,
typename loader_t,
- typename storer_t>
-void launch_jitted_unrolled_kernel(
- std::mutex &jiterator_mutex,
- at::cuda::jit::NvrtcFunction &fn_cache,
- const at::cuda::jit::KernelDescriptor &desc,
- int64_t N,
- array_t data,
- inp_calc_t ic,
- out_calc_t oc,
- loader_t l,
- storer_t s,
- bool contiguous,
- at::cuda::jit::BinaryFuncVariant scalar_pos,
- void* scalar_val,
- c10::ArrayRef<void*> extra_args) {
+ typename storer_t,
+ typename ... Args>
+static inline void launch_jitted_unrolled_kernel(
+ DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
+ inp_calc_t ic, out_calc_t oc, loader_t l, storer_t s, bool contiguous,
+ at::opmath_type<f_inputs_type> scalar_val,
+ std::tuple<Args...> extra_args) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
//casting result to int is always safe, intermediate is int64 and won't overflow
const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
- if (!fn_cache.function) {
- const std::lock_guard<std::mutex> lock{jiterator_mutex};
- if (!fn_cache.function) {
+ static std::mutex _jiterator_mutex;
+ static std::vector<at::cuda::jit::NvrtcFunction> fns(c10::cuda::device_count());
+
+ at::cuda::jit::NvrtcFunction* fn_ptr = &fns[dev_idx];
+ if (!fn_ptr->function) {
+ const std::lock_guard<std::mutex> lock{_jiterator_mutex};
+ if (!fn_ptr->function) {
+ constexpr int nInputs = array_t::size() - 1;
+ constexpr int nOutputs = 1; // fix me
constexpr bool dynamic_casting = !std::is_same<decltype(l), memory::LoadWithoutCast>() ||
!std::is_same<decltype(s), memory::StoreWithoutCast>();
- auto code = at::cuda::jit::generate_code(
- desc, contiguous, dynamic_casting, scalar_pos);
- fn_cache = at::cuda::jit::jit_pwise_function(code, desc.name);
+ std::string string_name{name};
+ std::string f_inputs_type_str = at::cuda::jit::typeName<f_inputs_type>();
+ std::string compute_type_str = at::cuda::jit::typeName<at::opmath_type<f_inputs_type>>();
+ std::string result_type_str = at::cuda::jit::typeName<result_type>();
+ c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames<Args...>();
+ auto code = at::cuda::jit::generate_code(nInputs, nOutputs, f, string_name,
+ f_inputs_type_str, compute_type_str, result_type_str,
+ contiguous, dynamic_casting, scalar_pos, extra_args_types);
+ *fn_ptr = at::cuda::jit::jit_pwise_function(code, name);
}
}
- auto args = pack_kernel_args(extra_args, &N, &data, &ic, &oc, &l, &s, scalar_val);
- at::cuda::jit::launch_jitted_pwise_function(fn_cache, args.data(), {grid, 1u, 1u},
+ // pack args for kernel launch
+ constexpr int kernel_args = 7;
+ // size of `extra_args` is known at compile-time
+ constexpr auto extra_args_size = sizeof...(Args);
+ void* args[kernel_args + extra_args_size];
+ args[0] = static_cast<void*>(&N);
+ args[1] = static_cast<void*>(&data);
+ args[2] = static_cast<void*>(&ic);
+ args[3] = static_cast<void*>(&oc);
+ args[4] = static_cast<void*>(&l);
+ args[5] = static_cast<void*>(&s);
+ args[6] = static_cast<void*>(&scalar_val);
+
+ auto extra_args_array = tuple_to_array(extra_args);
+ for (const auto i : c10::irange(extra_args_size)) {
+ // since 7 slots are already filled in `args`
+ args[i + 7] = extra_args_array[i];
+ }
+ at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1u, 1u},
{num_threads(), 1u, 1u});
}
-template<int arity, typename array_t>
-void launch_jitted_vectorized_kernel(
- std::mutex &jiterator_mutex, JittedVecKernelCache &fn_cache,
- const at::cuda::jit::KernelDescriptor &desc, int64_t N, array_t data,
- at::cuda::jit::BinaryFuncVariant scalar_pos,
- void *scalar_val, c10::ArrayRef<void*> extra_args) {
+template<
+ char const *name,
+ typename result_type,
+ typename f_inputs_type,
+ int arity,
+ at::cuda::jit::BinaryFuncVariant scalar_pos,
+ typename array_t, typename ... Args>
+static inline void launch_jitted_vectorized_kernel(DeviceIndex dev_idx, int64_t N, const std::string& f, array_t data,
+at::opmath_type<f_inputs_type> scalar_val, std::tuple<Args...> extra_args) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
// N is still int64_t for the computation, but it's always safe to cast result to int
const uint32_t grid = (N + block_work_size() - 1) / block_work_size();
- const int vec_size = at::cuda::jit::can_vectorize_up_to(
- desc, c10::ArrayRef<char*>(data.data, data.size()));
+ const int vec_size = memory::jitted_can_vectorize_up_to<result_type, f_inputs_type, arity>(data);
// Different kernels are compiled depending on what we're vectorizing up to (1, 2 or 4 elements)
// fn_ptr is set to the appropriate function based on the vec size and GPU used
+ // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
+ // the same compute capability
+ static std::mutex _jiterator_mutex;
+ static std::vector<at::cuda::jit::NvrtcFunction> fns4(c10::cuda::device_count());
+ static std::vector<at::cuda::jit::NvrtcFunction> fns2(c10::cuda::device_count());
+ static std::vector<at::cuda::jit::NvrtcFunction> fns1(c10::cuda::device_count());
+
+
at::cuda::jit::NvrtcFunction* fn_ptr;
if (vec_size == 4) {
- fn_ptr = &fn_cache.vec4;
+ fn_ptr = &fns4[dev_idx];
} else if (vec_size == 2) {
- fn_ptr = &fn_cache.vec2;
+ fn_ptr = &fns2[dev_idx];
} else if (vec_size ==1) {
- fn_ptr = &fn_cache.vec1;
+ fn_ptr = &fns1[dev_idx];
} else {
TORCH_INTERNAL_ASSERT(false, "unexpected vec_size for jitter vectorized kernel");
}
@@ -135,54 +157,94 @@
bool vectorized = vec_size > 1;
if (!fn_ptr->function) {
- const std::lock_guard<std::mutex> lock{jiterator_mutex};
+ const std::lock_guard<std::mutex> lock{_jiterator_mutex};
if (!fn_ptr->function) { // cache miss!
// Generates program
- auto code = at::cuda::jit::generate_code(
- desc, /*contiguous=*/true, /*dynamic_casting=*/false,
- scalar_pos, vectorized, vec_size);
- std::string kernel_name = vectorized ? desc.name + "_vectorized" + std::to_string(vec_size) : desc.name;
+ constexpr int nInputs = array_t::size() - 1;
+ constexpr int nOutputs = 1; // fix me
+ std::string string_name{name};
+ std::string f_inputs_type_str = at::cuda::jit::typeName<f_inputs_type>();
+ std::string compute_type_str = at::cuda::jit::typeName<at::opmath_type<f_inputs_type>>();
+ std::string result_type_str = at::cuda::jit::typeName<result_type>();
+ c10::SmallVector<std::string> extra_args_types = get_extra_args_typenames<Args...>();
+ auto code = at::cuda::jit::generate_code(nInputs, nOutputs, f, string_name,
+ f_inputs_type_str, compute_type_str, result_type_str,
+ /*contiguous=*/true, /*dynamic_casting=*/false,
+ scalar_pos,
+ extra_args_types,
+ vectorized, vec_size);
+ std::string kernel_name = vectorized ? string_name + "_vectorized" + std::to_string(vec_size) : string_name;
// Acquires the program
*fn_ptr = at::cuda::jit::jit_pwise_function(code, kernel_name);
}
}
+ // size of `extra_args` is known at compile-time
+ constexpr auto extra_args_size = sizeof...(Args);
+ auto extra_args_array = tuple_to_array(extra_args);
+
if (vectorized) {
- auto args = pack_kernel_args(extra_args, &N, &data, scalar_val);
- at::cuda::jit::launch_jitted_pwise_function(
- *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
+ // pack args for kernel launch
+ constexpr int kernel_args = 3;
+ void* args[kernel_args + extra_args_size];
+ args[0] = static_cast<void*>(&N);
+ args[1] = static_cast<void*>(&data);
+ args[2] = static_cast<void*>(&scalar_val);
+
+ for (const auto i : c10::irange(extra_args_size)) {
+ // since 3 slots are already filled in `args`
+ args[i + 3] = extra_args_array[i];
+ }
+ at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1u, 1u}, {num_threads(), 1u, 1u});
} else {
auto ic = TrivialOffsetCalculator<arity>();
auto oc = TrivialOffsetCalculator<1>();
auto l = memory::LoadWithoutCast();
auto s = memory::StoreWithoutCast();
- auto args = pack_kernel_args(
- extra_args, &N, &data, &ic, &oc, &l, &s, scalar_val);
- at::cuda::jit::launch_jitted_pwise_function(
- *fn_ptr, args.data(), {grid, 1u, 1u}, {num_threads(), 1u, 1u});
+ // pack args for kernel launch
+ constexpr int kernel_args = 7;
+ void* args[kernel_args + extra_args_size];
+ args[0] = static_cast<void*>(&N);
+ args[1] = static_cast<void*>(&data);
+ args[2] = static_cast<void*>(&ic);
+ args[3] = static_cast<void*>(&oc);
+ args[4] = static_cast<void*>(&l);
+ args[5] = static_cast<void*>(&s);
+ args[6] = static_cast<void*>(&scalar_val);
+
+ for (const auto i : c10::irange(extra_args_size)) {
+ // since 7 slots are already filled in `args`
+ args[i + 7] = extra_args_array[i];
+ }
+
+ at::cuda::jit::launch_jitted_pwise_function(*fn_ptr, args, {grid, 1u, 1u}, {num_threads(), 1u, 1u});
}
}
-template <int arity>
-void jitted_gpu_kernel_generic(
- std::mutex &jiterator_mutex,
- JittedKernelVariantCache &cache,
- const at::cuda::jit::KernelDescriptor &desc,
- at::cuda::jit::BinaryFuncVariant scalar_pos,
- c10::ArrayRef<void*> extra_args,
+template <
+ char const* name,
+ typename result_type,
+ typename f_inputs_type,
+ int arity,
+ at::cuda::jit::BinaryFuncVariant scalar_pos =
+ at::cuda::jit::BinaryFuncVariant::NoScalar,
+ typename... Args>
+void jitted_gpu_kernel_impl(
TensorIteratorBase& iter,
+ const std::string& f,
const bool dynamic_casting,
- void *scalar_val) {
+ at::opmath_type<f_inputs_type> scalar_val,
+ std::tuple<Args...> extra_args) {
TORCH_INTERNAL_ASSERT(iter.can_use_32bit_indexing());
TORCH_INTERNAL_ASSERT(iter.ninputs() == arity);
TORCH_INTERNAL_ASSERT(iter.noutputs() == 1);
constexpr int ntensors = arity + 1;
at::detail::Array<char*, ntensors> data;
- for (auto i : c10::irange(ntensors)) {
+ for (auto i = decltype(ntensors){0}; i < ntensors; ++i) {
data[i] = (char*)iter.data_ptr(i);
}
@@ -200,9 +262,8 @@
if (!dynamic_casting) {
if (contiguous) {
// Case 1: no dynamic casting and contiguous
- launch_jitted_vectorized_kernel<arity>(
- jiterator_mutex, cache.vec, desc,
- numel, data, scalar_pos, scalar_val, extra_args);
+ launch_jitted_vectorized_kernel<name, result_type, f_inputs_type, arity, scalar_pos>(
+ iter.device().index(), numel, f, data, scalar_val, extra_args);
return;
}
@@ -211,10 +272,9 @@
auto output_offset_calculator = make_output_offset_calculator(iter);
auto loader = memory::LoadWithoutCast();
auto storer = memory::StoreWithoutCast();
- launch_jitted_unrolled_kernel(
- jiterator_mutex, cache.noncontiguous, desc, numel, data,
- input_offset_calculator, output_offset_calculator, loader,
- storer, contiguous, scalar_pos, scalar_val, extra_args);
+ launch_jitted_unrolled_kernel<name, result_type, f_inputs_type, scalar_pos>(
+ iter.device().index(), numel, f, data, input_offset_calculator,
+ output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args);
return;
}
@@ -231,55 +291,18 @@
// Case 3: dynamic casting and contiguous
auto input_offset_calculator = TrivialOffsetCalculator<arity>();
auto output_offset_calculator = TrivialOffsetCalculator<1>();
- launch_jitted_unrolled_kernel(
- jiterator_mutex, cache.dynamic_contiguous, desc, numel, data, input_offset_calculator,
- output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
+ launch_jitted_unrolled_kernel<name, result_type, f_inputs_type, scalar_pos>(
+ iter.device().index(), numel, f, data, input_offset_calculator,
+ output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args);
return;
}
// Case 4: dynamic casting and noncontiguous
auto input_offset_calculator = make_input_offset_calculator<arity>(iter);
auto output_offset_calculator = make_output_offset_calculator(iter);
- launch_jitted_unrolled_kernel(
- jiterator_mutex, cache.dynamic_noncontiguous, desc, numel, data, input_offset_calculator,
- output_offset_calculator, loader, storer, contiguous, scalar_pos, scalar_val, extra_args);
-}
-
-// NOTE: static to reduce chances of name collision.
-template <
- char const* name,
- typename result_type,
- typename f_inputs_type,
- int arity,
- at::cuda::jit::BinaryFuncVariant scalar_pos =
- at::cuda::jit::BinaryFuncVariant::NoScalar,
- typename... ExtraArgs>
-static void jitted_gpu_kernel_impl(
- TensorIteratorBase& iter,
- const std::string &f,
- const bool dynamic_casting,
- at::opmath_type<f_inputs_type> scalar_val,
- std::tuple<ExtraArgs...> extra_args) {
-
- // TODO: Memory use can probably be optimized by re-using kernels across GPUs with
- // the same compute capability
- static std::mutex jiterator_mutex;
- static std::vector<JittedKernelVariantCache> device_caches(c10::cuda::device_count());
- static const auto desc = at::cuda::jit::make_kernel_descriptor<
- result_type, f_inputs_type, arity, ExtraArgs...>(name, f);
-
- auto &cache = device_caches[iter.device().index()];
- auto extra_args_array = tuple_to_array(extra_args);
- return jitted_gpu_kernel_generic<arity>(
- jiterator_mutex,
- cache,
- desc,
- scalar_pos,
- extra_args_array,
- iter,
- dynamic_casting,
- &scalar_val
- );
+ launch_jitted_unrolled_kernel<name, result_type, f_inputs_type, scalar_pos>(
+ iter.device().index(), numel, f, data, input_offset_calculator,
+ output_offset_calculator, loader, storer, contiguous, scalar_val, extra_args);
}
}} // at::native
diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh
index 355db34..409354b 100644
--- a/aten/src/ATen/native/cuda/MemoryAccess.cuh
+++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -382,4 +382,19 @@
return result;
}
+// jitted version of the above
+// See Note [Jiterator], this relies on the assumptions enumerated there
+template<typename result_type, typename common_type, int arity, typename array_t>
+inline int jitted_can_vectorize_up_to(array_t pointers) {
+ // Deals with output
+ int result = can_vectorize_up_to<result_type>(pointers[0]);
+
+ // Incorporates input(s)
+ for (auto i = decltype(arity){1}; i < (arity + 1); ++i) {
+ result = std::min<int>(result, can_vectorize_up_to<common_type>(pointers[i]));
+ }
+
+ return result;
+}
+
}}} // namespace at::native::memory
diff --git a/aten/src/ATen/native/cuda/jit_utils.cpp b/aten/src/ATen/native/cuda/jit_utils.cpp
index 8decd58..10270bd 100644
--- a/aten/src/ATen/native/cuda/jit_utils.cpp
+++ b/aten/src/ATen/native/cuda/jit_utils.cpp
@@ -8,7 +8,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/code_template.h>
-#include <ATen/OpMathType.h>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/cuda/llvm_jit_strings.h>
#include <ATen/native/cuda/reduction_template.cuh>
@@ -727,36 +726,6 @@
}
}
-std::string generate_code(
- const KernelDescriptor &desc,
- bool contiguous,
- bool dynamic_casting,
- BinaryFuncVariant scalar_pos,
- bool vectorized,
- int vec_size,
- bool return_by_ref) {
- c10::SmallVector<std::string> extra_args_typenames(desc.extra_args_types.size());
- for (auto i : c10::irange(extra_args_typenames.size())) {
- extra_args_typenames[i] = typeName(desc.extra_args_types[i]);
- }
-
- return generate_code(
- desc.nInputs,
- desc.nOutputs,
- desc.f,
- desc.name,
- typeName(desc.f_inputs_type),
- typeName(toOpMathType(desc.f_inputs_type)),
- typeName(desc.result_type),
- contiguous,
- dynamic_casting,
- scalar_pos,
- extra_args_typenames,
- vectorized,
- vec_size,
- return_by_ref);
-}
-
//FIXME - this are defined in Loops.cuh, but including Loops.cuh here would lead to circular includes Loops.cuh -> CUDALoops.cuh -> jit_utils.h -> Loops.cuh
#define THREAD_WORK_SIZE 4
constexpr int thread_work_size = THREAD_WORK_SIZE;
diff --git a/aten/src/ATen/native/cuda/jit_utils.h b/aten/src/ATen/native/cuda/jit_utils.h
index 9586ecf..129ad3e 100644
--- a/aten/src/ATen/native/cuda/jit_utils.h
+++ b/aten/src/ATen/native/cuda/jit_utils.h
@@ -19,67 +19,6 @@
CUfunction function = nullptr;
};
-struct KernelDescriptor {
- std::string name;
- std::string f;
- c10::ScalarType f_inputs_type;
- c10::ScalarType result_type;
- c10::SmallVector<c10::ScalarType> extra_args_types;
- int nInputs, nOutputs;
-};
-
-// Helper function to return a vector<string>
-// corresponding to the type of the arguments in parameter pack.
-template <typename... Args>
-c10::SmallVector<at::ScalarType> get_extra_args_types() {
- return {c10::CppTypeToScalarType<Args>::value ...};
-}
-
-template <
- typename result_type,
- typename f_inputs_type,
- int arity,
- typename... ExtraArgs>
-KernelDescriptor make_kernel_descriptor(std::string name, std::string f) {
- KernelDescriptor ret;
- ret.name = std::move(name);
- ret.f = std::move(f);
- ret.f_inputs_type = c10::CppTypeToScalarType<f_inputs_type>::value;
- ret.result_type = c10::CppTypeToScalarType<result_type>::value;
- ret.extra_args_types = get_extra_args_types<ExtraArgs...>();
- ret.nInputs = arity;
- ret.nOutputs = 1; // TODO: Support more than 1 output
- return ret;
-}
-
-inline int can_vectorize_up_to(size_t default_alignment, void *pointer) {
- auto ip = reinterpret_cast<uintptr_t>(pointer);
- if (ip % (4 * default_alignment) == 0) {
- return 4;
- }
- if (ip % (2 * default_alignment) == 0) {
- return 2;
- }
- return 1;
-}
-
-inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef<char*> pointers) {
- TORCH_INTERNAL_ASSERT(desc.nOutputs == 1);
- TORCH_INTERNAL_ASSERT(static_cast<int64_t>(pointers.size()) == 1 + desc.nInputs);
-
- // Deals with output
- auto result_size = c10::scalarTypeToTypeMeta(desc.result_type).itemsize();
- int result = can_vectorize_up_to(result_size, pointers[0]);
-
- // Incorporates input(s)
- auto input_size = c10::scalarTypeToTypeMeta(desc.f_inputs_type).itemsize();
- for (auto i : c10::irange(1, pointers.size())) {
- result = std::min(result, can_vectorize_up_to(input_size, pointers[i]));
- }
-
- return result;
-}
-
std::string generate_code(
int nInputs,
int nOutputs,
@@ -96,15 +35,6 @@
int vec_size=0,
bool return_by_ref=false);
-std::string generate_code(
- const KernelDescriptor &desc,
- bool contiguous,
- bool dynamic_casting,
- BinaryFuncVariant scalar_pos,
- bool vectorized=false,
- int vec_size=0,
- bool return_by_ref=false);
-
std::string generate_reduction_code(
int nOutputs,
const std::string& func,
@@ -178,12 +108,17 @@
}
#define TYPE_NAME_CASE(ctype, scalartype) \
- case ScalarType::scalartype: return typeName<ctype>();
+ case ScalarType::scalartype: return std::string(#ctype);
inline std::string typeName(ScalarType t) {
switch (t) {
- AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(TYPE_NAME_CASE)
- default:
- TORCH_CHECK(false, "invalid type for jiterator");
+ AT_FORALL_SCALAR_TYPES(TYPE_NAME_CASE)
+ case ScalarType::Bool : return "bool";
+ case ScalarType::Half : return "at::Half";
+ case ScalarType::BFloat16 : return "at::BFloat16";
+ case ScalarType::ComplexFloat : return "std::complex<float>";
+ case ScalarType::ComplexDouble : return "std::complex<double>";
+ default:
+ TORCH_CHECK(false, "invalid type for jiterator");
}
}
#undef TYPE_NAME_CASE