| #pragma once |
| |
| #include <c10/core/DeviceType.h> |
| #include <c10/macros/Export.h> |
| |
| #include <atomic> |
| #include <utility> |
| |
// Implements instruction set specific function dispatch.
//
// Kernels that may make use of specialized instruction sets (e.g. AVX2) are
// compiled multiple times with different compiler flags (e.g. -mavx2). A
// DispatchStub contains a table of function pointers for a kernel. At runtime,
// the fastest available kernel is chosen based on the features reported by
// cpuinfo.
//
// Example:
//
// In native/MyKernel.h:
//   using fn_type = void(*)(const Tensor& x);
//   DECLARE_DISPATCH(fn_type, stub);
//
// In native/MyKernel.cpp:
//   DEFINE_DISPATCH(stub);
//
// In native/cpu/MyKernel.cpp:
//   namespace {
//     // use anonymous namespace so that different cpu versions won't conflict
//     void kernel(const Tensor& x) { ... }
//   }
//   REGISTER_DISPATCH(stub, &kernel);
//
// To call:
//   stub(kCPU, tensor);
//
// TODO: CPU instruction set selection should be folded into whatever
// the main dispatch mechanism is.
| |
// Ignore warnings about the DispatchStub static members (DEFAULT, AVX2,
// AVX512, VSX, ZVECTOR) being defined elsewhere.
| #if defined(__clang__) |
| #pragma clang diagnostic push |
| #pragma clang diagnostic ignored "-Wundefined-var-template" |
| #endif |
| |
| namespace at { namespace native { |
| |
// Instruction-set variants a CPU kernel can be registered for. DEFAULT is
// always present; the remaining enumerators depend on the build:
//   - HAVE_VSX_CPU_DEFINITION defined        -> VSX
//   - else HAVE_ZVECTOR_CPU_DEFINITION       -> ZVECTOR
//   - otherwise                              -> AVX2 and AVX512
enum class CPUCapability {
  DEFAULT = 0,
#ifdef HAVE_VSX_CPU_DEFINITION
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  // Listed last, so its value equals the number of enumerators above it.
  NUM_OPTIONS
};
| |
| CPUCapability get_cpu_capability(); |
| |
| template <typename FnPtr, typename T> |
| struct DispatchStub; |
| |
| /** |
| * The sole purpose of this class is to outline methods that don't need to be |
| * specialized or otherwise inlined and duplicated (by the compiler due to |
| * template expansion), since it causes size bloat if there are a significant |
| * number of specialization of the DispatchStub<> class. |
| */ |
| struct TORCH_API DispatchStubImpl { |
| void* get_call_ptr( |
| DeviceType device_type |
| , void *DEFAULT |
| #ifdef HAVE_AVX512_CPU_DEFINITION |
| , void *AVX512 |
| #endif |
| #ifdef HAVE_AVX2_CPU_DEFINITION |
| , void *AVX2 |
| #endif |
| #ifdef HAVE_VSX_CPU_DEFINITION |
| , void *VSX |
| #endif |
| #ifdef HAVE_ZVECTOR_CPU_DEFINITION |
| , void *ZVECTOR |
| #endif |
| ); |
| |
| /** |
| * The CPU Dispatch actual method is chosen in decreasing order of preference by |
| * DispatchStubImpl::choose_cpu_impl() in case none is found by |
| * DispatchStubImpl::get_call_ptr() in cpu_dispatch_ptr. |
| */ |
| void* choose_cpu_impl( |
| void *DEFAULT |
| #ifdef HAVE_AVX512_CPU_DEFINITION |
| , void *AVX512 |
| #endif |
| #ifdef HAVE_AVX2_CPU_DEFINITION |
| , void *AVX2 |
| #endif |
| #ifdef HAVE_VSX_CPU_DEFINITION |
| , void *VSX |
| #endif |
| #ifdef HAVE_ZVECTOR_CPU_DEFINITION |
| , void *ZVECTOR |
| #endif |
| ); |
| |
| // Fixing dispatch error in Windows debug builds. |
| // See https://github.com/pytorch/pytorch/issues/22681 for more details. |
| #if defined(_MSC_VER) && defined(_DEBUG) |
| std::atomic<void*> cpu_dispatch_ptr; |
| void* cuda_dispatch_ptr; |
| void* hip_dispatch_ptr; |
| void* mps_dispatch_ptr; |
| #else |
| std::atomic<void*> cpu_dispatch_ptr{nullptr}; |
| void* cuda_dispatch_ptr = nullptr; |
| void* hip_dispatch_ptr = nullptr; |
| void* mps_dispatch_ptr = nullptr; |
| #endif |
| }; |
| |
| template <typename rT, typename T, typename... Args> |
| struct DispatchStub<rT (*)(Args...), T> { |
| using FnPtr = rT (*) (Args...); |
| |
| DispatchStub() = default; |
| DispatchStub(const DispatchStub&) = delete; |
| DispatchStub& operator=(const DispatchStub&) = delete; |
| |
| private: |
| FnPtr get_call_ptr(DeviceType device_type) { |
| return reinterpret_cast<FnPtr>( |
| impl.get_call_ptr(device_type |
| , reinterpret_cast<void*>(DEFAULT) |
| #ifdef HAVE_AVX512_CPU_DEFINITION |
| , reinterpret_cast<void*>(AVX512) |
| #endif |
| #ifdef HAVE_AVX2_CPU_DEFINITION |
| , reinterpret_cast<void*>(AVX2) |
| #endif |
| #ifdef HAVE_VSX_CPU_DEFINITION |
| , reinterpret_cast<void*>(VSX) |
| #endif |
| #ifdef HAVE_ZVECTOR_CPU_DEFINITION |
| , reinterpret_cast<void*>(ZVECTOR) |
| #endif |
| ) |
| ); |
| } |
| |
| public: |
| template <typename... ArgTypes> |
| rT operator()(DeviceType device_type, ArgTypes&&... args) { |
| FnPtr call_ptr = get_call_ptr(device_type); |
| return (*call_ptr)(std::forward<ArgTypes>(args)...); |
| } |
| |
| void set_cuda_dispatch_ptr(FnPtr fn_ptr) { |
| impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr); |
| } |
| |
| void set_hip_dispatch_ptr(FnPtr fn_ptr) { |
| impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr); |
| } |
| |
| void set_mps_dispatch_ptr(FnPtr fn_ptr) { |
| impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr); |
| } |
| |
| static TORCH_API FnPtr DEFAULT; |
| #ifdef HAVE_AVX512_CPU_DEFINITION |
| static TORCH_API FnPtr AVX512; |
| #endif |
| #ifdef HAVE_AVX2_CPU_DEFINITION |
| static TORCH_API FnPtr AVX2; |
| #endif |
| #ifdef HAVE_VSX_CPU_DEFINITION |
| static TORCH_API FnPtr VSX; |
| #endif |
| #ifdef HAVE_ZVECTOR_CPU_DEFINITION |
| static TORCH_API FnPtr ZVECTOR; |
| #endif |
| private: |
| DispatchStubImpl impl; |
| }; |
| |
| namespace { |
// Registration helper: constructing an instance installs `value` as the CUDA
// kernel of `stub`. REGISTER_CUDA_DISPATCH declares a static object of this
// type so that the registration happens during static initialization.
template <typename StubT>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(StubT& stub, typename StubT::FnPtr value) {
    stub.set_cuda_dispatch_ptr(value);
  }
};
| |
// Registration helper: constructing an instance installs `value` as the MPS
// kernel of `stub`. REGISTER_MPS_DISPATCH declares a static object of this
// type so that the registration happens during static initialization.
template <typename StubT>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(StubT& stub, typename StubT::FnPtr value) {
    stub.set_mps_dispatch_ptr(value);
  }
};
| |
// Registration helper for HIP builds. Note that, for now, it writes `value`
// into the CUDA slot of `stub` rather than the HIP slot.
template <typename StubT>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(StubT& stub, typename StubT::FnPtr value) {
    // TODO: make this point at hip_dispatch_ptr
    stub.set_cuda_dispatch_ptr(value);
  }
};
| |
| } // anonymous namespace |
// Declares a dispatch stub: a struct `name` deriving from
// DispatchStub<fn, name> (default-constructible, non-copyable) together with
// an extern object of the same name.
//
// The compiler will complain if you put things like std::tuple<Tensor, Tensor>
// in the `fn` argument. Some possible workarounds, e.g., adding parentheses
// and using a helper struct to get rid of the parentheses, do not work with
// MSVC. So do a `using`-declaration if you need to pass in such `fn`, e.g.,
// grid_sampler_2d_backward_cpu_kernel in GridSampleKernel.h.
#define DECLARE_DISPATCH(fn, name)         \
  struct name : DispatchStub<fn, name> {   \
    name() = default;                      \
    name(const name&) = delete;            \
    name& operator=(const name&) = delete; \
  };                                       \
  extern TORCH_API struct name name
| |
// Defines the stub object `name` of type `struct name`. The trailing
// semicolon is supplied at the point of use.
#define DEFINE_DISPATCH(name) \
  struct name name
| |
// Defines the `arch` kernel slot (DEFAULT, AVX2, ...) of stub `name`: an
// explicit specialization of the matching DispatchStub static member,
// initialized with `fn`. The expansion already ends in a semicolon.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name::FnPtr TORCH_API DispatchStub<name::FnPtr, struct name>::arch = fn;
| |
// Per-variant registration macros. Each one forwards to
// REGISTER_ARCH_DISPATCH when the corresponding HAVE_*_CPU_DEFINITION macro is
// defined, and expands to nothing otherwise, so callers can use them
// unconditionally.

#if defined(HAVE_AVX512_CPU_DEFINITION)
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif

#if defined(HAVE_AVX2_CPU_DEFINITION)
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif

#if defined(HAVE_VSX_CPU_DEFINITION)
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif

#if defined(HAVE_ZVECTOR_CPU_DEFINITION)
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif
| |
// Registers one and the same kernel `fn` for every CPU variant enabled in
// the build. Useful when a kernel gains nothing from being recompiled for the
// different arch types.
#define REGISTER_ALL_CPU_DISPATCH(name, fn)   \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn)   \
  REGISTER_AVX512_DISPATCH(name, fn)          \
  REGISTER_AVX2_DISPATCH(name, fn)            \
  REGISTER_VSX_DISPATCH(name, fn)             \
  REGISTER_ZVECTOR_DISPATCH(name, fn)

// Registers nullptr for every CPU variant of stub `name`.
#define REGISTER_NO_CPU_DISPATCH(name) \
  REGISTER_ALL_CPU_DISPATCH(name, nullptr)
| |
// Device registration macros. Each declares a file-local static object named
// `<name>__register` whose constructor stores `fn` into stub `name`. The
// expansion already ends in a semicolon.

#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name> name ## __register(name, fn);

#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name> name ## __register(name, fn);

#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name> name ## __register(name, fn);
| |
// Selects what REGISTER_DISPATCH means from the compiler that is processing
// the current translation unit. If none of the conditions below holds, the
// macro is left undefined.
//
// NB: The CUDA flavour must be used in an actual 'cu' file; if you try using
// it from a 'cpp' file it will not work!
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// TODO: cut this over to HIP dispatch once we stop pretending that CUDA
// is HIP in the PyTorch HIPify build. The intended definition is kept below
// for reference.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
// #define REGISTER_DISPATCH(name, fn) REGISTER_HIP_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// CPU kernel sources: CPU_CAPABILITY is passed on as the `arch` argument, so
// the kernel lands in the slot of the variant being compiled.
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
// Registers nullptr for the AVX512 slot of stub `name`.
#define REGISTER_NO_AVX512_DISPATCH(name) \
  REGISTER_AVX512_DISPATCH(name, nullptr)
#endif
| |
| |
| }} // namespace at::native |
| |
| |
| #if defined(__clang__) |
| #pragma clang diagnostic pop |
| #endif |