| #include <torch/csrc/jit/frontend/function_schema_parser.h> |
| #include <torch/csrc/utils/python_dispatch.h> |
| |
| #include <ATen/ATen.h> |
| #include <ATen/FuncTorchTLS.h> |
| #include <ATen/FunctionalTensorWrapper.h> |
| #include <ATen/TensorSubclassLikeUtils.h> |
| #include <ATen/core/NestedIntSymNodeImpl.h> |
| #include <ATen/core/PythonOpRegistrationTrampoline.h> |
| #include <ATen/core/dispatch/Dispatcher.h> |
| |
| #include <ATen/functorch/BatchedTensorImpl.h> |
| #include <torch/library.h> |
| |
| #include <c10/core/SafePyObject.h> |
| #include <torch/csrc/PyInterpreter.h> |
| #include <torch/csrc/autograd/python_variable.h> |
| #include <torch/csrc/jit/python/pybind_utils.h> |
| #include <torch/csrc/utils/tensor_new.h> |
| |
| #include <c10/util/flat_hash_map.h> |
| #include <pybind11/operators.h> |
| #include <pybind11/stl.h> |
| #include <torch/csrc/inductor/aoti_eager/kernel_holder.h> |
| #include <torch/csrc/utils/pybind.h> |
| #include <torch/csrc/utils/python_raii.h> |
| |
| #include <iostream> |
| #include <utility> |
| |
| namespace py = pybind11; |
| |
| namespace torch::impl::dispatch { |
| |
// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations
// before the other interpreters start banging on it
| static ska::flat_hash_map< |
| c10::OperatorName, |
| ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>> |
| python_registrations_; |
| |
| static torch::Library::Kind parseKind(const std::string& k) { |
| static std::unordered_map<std::string, torch::Library::Kind> kind_map = { |
| {"DEF", torch::Library::DEF}, |
| {"IMPL", torch::Library::IMPL}, |
| {"FRAGMENT", torch::Library::FRAGMENT}, |
| }; |
| auto it = kind_map.find(k); |
| TORCH_CHECK(it != kind_map.end(), "could not parse ", k); |
| return it->second; |
| } |
| static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) { |
| static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = { |
| {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE}, |
| {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA}, |
| {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION}, |
| {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default |
| }; |
| auto it = key_map.find(k); |
| TORCH_CHECK(it != key_map.end(), "could not parse ", k); |
| return it->second; |
| } |
| |
| template <typename Func> |
| inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) { |
| auto mb_key = std::string(key).empty() |
| ? c10::nullopt |
| : c10::make_optional(c10::parseDispatchKey(key)); |
| if (mb_key) { |
| return torch::dispatch(*mb_key, std::forward<Func>(raw_f)); |
| } else { |
| torch::CppFunction f(std::forward<Func>(raw_f)); |
| return f; |
| } |
| } |
| |
| struct EnableHermeticPyObject { |
| EnableHermeticPyObject() |
| : old_(c10::impl::HermeticPyObjectTLS::get_state()), |
| old_excluded_python_( |
| c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)), |
| old_python_( |
| c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)), |
| old_python_snapshot_(c10::impl::tls_is_dispatch_key_included( |
| at::DispatchKey::PythonTLSSnapshot)) { |
| c10::impl::HermeticPyObjectTLS::set_state(true); |
| c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true); |
| c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false); |
| c10::impl::tls_set_dispatch_key_included( |
| at::DispatchKey::PythonTLSSnapshot, false); |
| } |
| ~EnableHermeticPyObject() { |
| c10::impl::HermeticPyObjectTLS::set_state(old_); |
| c10::impl::tls_set_dispatch_key_excluded( |
| at::DispatchKey::Python, old_excluded_python_); |
| c10::impl::tls_set_dispatch_key_included( |
| at::DispatchKey::Python, old_python_); |
| c10::impl::tls_set_dispatch_key_included( |
| at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_); |
| } |
| bool old_; |
| bool old_excluded_python_; |
| bool old_python_; |
| bool old_python_snapshot_; |
| }; |
| |
| class PythonKernelHolder : public c10::OperatorKernel { |
| c10::SafePyObject func_; |
| c10::DispatchKey dispatch_key_; |
| // If "with_keyset", then we expect a keyset as the first arg. |
| bool with_keyset_; |
| |
| public: |
| PythonKernelHolder( |
| py::object func, |
| c10::DispatchKey dispatch_key, |
| bool with_keyset = false) |
| : func_(func.release().ptr(), getPyInterpreter()), |
| dispatch_key_(dispatch_key), |
| with_keyset_(with_keyset) {} |
| |
| void operator()( |
| const c10::OperatorHandle& op, |
| c10::DispatchKeySet keyset, |
| torch::jit::Stack* stack) { |
| // Figure out if we can handle it hermetically, or if we have |
| // to double dispatch |
| |
| // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch |
| const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len(); |
| if (mode_stack_len > 0) { |
| const auto& cur_torch_dispatch_mode_state = |
| c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1); |
| cur_torch_dispatch_mode_state->pyinterpreter() |
| ->python_op_registration_trampoline( |
| op, dispatch_key_, keyset, stack, with_keyset_); |
| return; |
| } |
| |
| const auto& schema = op.schema(); |
| const auto num_arguments = schema.arguments().size(); |
| |
// Otherwise, find a PyInterpreter on a Tensor if it has the Python key
// (which means it's a nontrivial tensor subclass)
| for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) { |
| if (ivalue.isTensor()) { |
| auto* interpreter = |
| ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter(); |
| if (interpreter && |
| ivalue.unsafeToTensorImpl()->key_set().has( |
| at::DispatchKey::Python)) { |
| (*interpreter) |
| ->python_op_registration_trampoline( |
| op, dispatch_key_, keyset, stack, with_keyset_); |
| return; |
| } |
| } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) { |
| // NB: use toListRef as it doesn't induce refcount bumps |
| // (toTensorListRef is not a thing) |
| for (const auto& nv : ivalue.toListRef()) { |
| if (nv.isNone()) { |
| continue; |
| } |
| auto* interpreter = |
| nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter(); |
| if (interpreter && |
| nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) { |
| (*interpreter) |
| ->python_op_registration_trampoline( |
| op, dispatch_key_, keyset, stack, with_keyset_); |
| return; |
| } |
| } |
| } |
| } |
| |
| // Nothing requires the operator to be homed to a specific interpreter, so |
| // run it on the current interpreter |
| |
| auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); |
| py::gil_scoped_acquire g; |
// Jan 2024: We're slated to get rid of multipy, so we no longer force
// hermetic mode unconditionally in all situations; it is only enabled when
// building with multipy (USE_DEPLOY). Eventually just delete this entirely.
// (Note that multipy may break anyway this way with dispatcher-registered
// functions that require hermetic mode to be off.)
| #if defined(USE_DEPLOY) |
| EnableHermeticPyObject g2; |
| #endif |
| auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); |
| auto func = |
| py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter())); |
| auto obj = with_keyset_ |
| ? func(keyset, *args_kwargs.first, **args_kwargs.second) |
| : func(*args_kwargs.first, **args_kwargs.second); |
| if (!obj) { |
| throw python_error(); |
| } |
| pushPyOutToStack(op, stack, obj, "PythonKernelHolder"); |
| } |
| }; |
| |
| static torch::_RegisterOrVerify register_or_verify() { |
| if (isMainPyInterpreter()) { |
| return torch::_RegisterOrVerify::REGISTER; |
| } else { |
| return torch::_RegisterOrVerify::VERIFY; |
| } |
| } |
| |
| static py::object ophandle_call_boxed( |
| const c10::OperatorHandle& handle, |
| py::args args, |
| const py::kwargs& kwargs) { |
| auto stack = torch::jit::createStackForSchema( |
| handle.schema(), |
| std::move(args), |
| kwargs, |
| /*self=*/c10::nullopt); |
| { |
| pybind11::gil_scoped_release no_gil_guard; |
| handle.callBoxed(stack); |
| } |
| return torch::jit::createPyObjectForStack(std::move(stack)); |
| } |
| |
// A small RAII guard that lets you explicitly set whether a key is in the TLS
// exclude set (typically to *remove* it), restoring the prior state on
// destruction.
| class SetExcludeDispatchKeyGuard { |
| public: |
| SetExcludeDispatchKeyGuard(at::DispatchKey k, bool set_excluded) |
| : k(k), old(c10::impl::tls_is_dispatch_key_excluded(k)) { |
| c10::impl::tls_set_dispatch_key_excluded(k, set_excluded); |
| } |
| ~SetExcludeDispatchKeyGuard() { |
| c10::impl::tls_set_dispatch_key_excluded(k, old); |
| } |
| SetExcludeDispatchKeyGuard(const SetExcludeDispatchKeyGuard&) = delete; |
| SetExcludeDispatchKeyGuard operator=(const SetExcludeDispatchKeyGuard&) = |
| delete; |
| SetExcludeDispatchKeyGuard(SetExcludeDispatchKeyGuard&&) = delete; |
| SetExcludeDispatchKeyGuard operator=(SetExcludeDispatchKeyGuard&&) = delete; |
| |
| private: |
| at::DispatchKey k; |
| bool old; |
| }; |
| |
| void initDispatchBindings(PyObject* module) { |
| auto m = py::handle(module).cast<py::module>(); |
| |
| py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle") |
| .def("schema", &c10::OperatorHandle::schema) |
| .def("debug", &c10::OperatorHandle::debug) |
| .def( |
| "redispatch_boxed", |
| [](const py::object& self, |
| c10::DispatchKeySet keyset, |
| py::args args, |
| const py::kwargs& kwargs) { |
| auto& handle = self.cast<c10::OperatorHandle&>(); |
| auto stack = torch::jit::createStackForSchema( |
| handle.schema(), |
| std::move(args), |
| kwargs, |
| /*self=*/c10::nullopt); |
| { |
| pybind11::gil_scoped_release no_gil_guard; |
| handle.redispatchBoxed(keyset, &stack); |
| } |
| return torch::jit::createPyObjectForStack(std::move(stack)); |
| }); |
| |
| m.def("_dispatch_call_boxed", &ophandle_call_boxed); |
| |
| // TODO: figure out how to do chaining |
| py::class_<torch::Library>(m, "_DispatchModule") |
| .def( |
| "reset", |
| [](const py::object& self) { |
| TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); |
| self.cast<torch::Library&>().reset(); |
| return; |
| }, |
| "") |
// Some of these APIs are only for testing and do not work in a multipy
// environment
| .def( |
| "def_", |
| [](py::object self, const char* schema, const char* alias) { |
| TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); |
| self.cast<torch::Library&>().def( |
| torch::schema(schema, parseAliasAnalysisKind(alias))); |
| return self; |
| }, |
| "", |
| py::arg("schema"), |
| py::arg("alias") = "") |
| // Simulated "legacy" def where alias analysis kind is not set. |
| // Ordinarily this can only be exercised from RegisterOperators() API |
| // but I am not going to bind that here |
| .def( |
| "def_legacy", |
| [](py::object self, const char* schema) { |
| TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); |
| self.cast<torch::Library&>().def(torch::jit::parseSchema(schema)); |
| return self; |
| }, |
| "", |
| py::arg("schema")) |
// We can't conveniently turn Python functions into valid functions
// in the dispatcher. So instead we provide a bunch of precanned
// functions for testing purposes. You're NOT intended to actually
// call these functions; they're just here so we can actually register
// something.
//
// Mangling scheme: args_rets. One character per argument/return.
// t = Tensor
| .def( |
| "def_name_t_t", |
| [](py::object self, |
| const char* name, |
| const char* dispatch, |
| const char* debug) { |
| TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); |
| self.cast<torch::Library&>().def( |
| name, dispatch_str(dispatch, [](const at::Tensor& a) { |
| return a; |
| }).debug(debug)); |
| return self; |
| }, |
| "", |
| py::arg("name"), |
| py::arg("dispatch") = "", |
| py::arg("debug") = "default_def_name_t_t") |
| .def( |
| "def_schema_t_t", |
| [](py::object self, |
| const char* schema, |
| const char* dispatch, |
| const char* alias, |
| const char* debug) { |
| TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); |
| self.cast<torch::Library&>().def( |
| torch::schema(schema, parseAliasAnalysisKind(alias)), |
| dispatch_str(dispatch, [](const at::Tensor& a) { |
| return a; |
| }).debug(debug)); |
| return self; |
| }, |
| "", |
| py::arg("name"), |
| py::arg("dispatch") = "", |
| py::arg("alias") = "", |
| py::arg("debug") = "default_def_schema_t_t") |
| // TODO: maybe consider deduplicating the definitions here, it's getting |
| // pretty long |
| .def( |
| "impl_t_t", |
| [](py::object self, |
| const char* name, |
| const char* dispatch, |
| const char* debug) { |
| TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); |
| self.cast<torch::Library&>().impl( |
| name, dispatch_str(dispatch, [](const at::Tensor& a) { |
| return a; |
| }).debug(debug)); |
| return self; |
| }, |
| "", |
| py::arg("name"), |
| py::arg("dispatch") = "", |
| py::arg("debug") = "impl_t_t") |
| .def( |
| "impl_with_aoti_compile", |
| [](const py::object& self, |
| const char* ns, |
| const char* op_name_with_overload, |
| c10::DispatchKey dispatch) { |
| HANDLE_TH_ERRORS |
| std::string reg_op_name = |
| std::string(ns).append("::").append(op_name_with_overload); |
| |
| auto& lib = self.cast<torch::Library&>(); |
| lib.impl( |
| reg_op_name.c_str(), |
| torch::dispatch( |
| dispatch, |
| CppFunction::makeFromBoxedFunctor( |
| std::make_unique< |
| torch::inductor::AOTIPythonKernelHolder>( |
| dispatch, ns, op_name_with_overload))), |
| register_or_verify()); |
| END_HANDLE_TH_ERRORS_PYBIND |
| }, |
| "", |
| py::arg("ns"), |
| py::arg("op_name_with_overload"), |
| py::arg("dispatch")) |
| .def( |
| "impl", |
| [](const py::object& self, |
| const char* name, |
| // TODO: empty string no longer works |
| c10::DispatchKey dispatch, |
| py::object func, |
| bool with_keyset) { |
| HANDLE_TH_ERRORS |
| auto& lib = self.cast<torch::Library&>(); |
| if (func.is(py::module::import("torch.library") |
| .attr("fallthrough_kernel"))) { |
| lib.impl( |
| name, |
| torch::dispatch(dispatch, CppFunction::makeFallthrough()), |
| register_or_verify()); |
| } else { |
| lib.impl( |
| name, |
| torch::dispatch( |
| dispatch, |
| CppFunction::makeFromBoxedFunctor( |
| std::make_unique<PythonKernelHolder>( |
| func, dispatch, with_keyset))), |
| register_or_verify()); |
| python_registrations_[lib._resolve(name)].insert_or_assign( |
| dispatch, |
| std::make_shared<c10::SafePyObject>( |
| func.release().ptr(), getPyInterpreter())); |
| } |
| END_HANDLE_TH_ERRORS_PYBIND |
| }, |
| "", |
| py::arg("name"), |
| py::arg("dispatch"), |
| py::arg("func"), |
| py::arg("with_keyset") = false) |
| .def( |
| "define", |
| [](const py::object& self, |
| const char* schema, |
| const char* alias_analysis, |
| const std::vector<at::Tag>& tags) { |
| auto parsed_schema = |
| torch::schema(schema, parseAliasAnalysisKind(alias_analysis)); |
| self.cast<torch::Library&>().def( |
| std::move(parsed_schema), tags, register_or_verify()); |
| // TODO: this is dumb, had to make a second copy |
| return torch::schema(schema, parseAliasAnalysisKind(alias_analysis)) |
| .name(); |
| }, |
| "", |
| py::arg("schema"), |
| py::arg("alias_analysis") = "", |
| py::arg("tags") = std::vector<at::Tag>()) |
| .def( |
| "fallback_fallthrough", |
| [](py::object self, const char* dispatch) { |
| TORCH_INTERNAL_ASSERT(isMainPyInterpreter()); |
| self.cast<torch::Library&>().fallback( |
| dispatch_str(dispatch, CppFunction::makeFallthrough())); |
| return self; |
| }, |
| "", |
| py::arg("dispatch") = ""); |
| |
| m.def( |
| "_dispatch_library", |
| [](const char* kind, |
| std::string name, |
| const char* dispatch, |
| const char* file, |
| uint32_t linenum) { |
| HANDLE_TH_ERRORS |
| return std::make_unique<torch::Library>( |
| parseKind(kind), |
| std::move(name), |
| std::string(dispatch).empty() |
| ? c10::nullopt |
| : c10::make_optional(c10::parseDispatchKey(dispatch)), |
| "/dev/null", // temporary workaround |
| linenum); |
| END_HANDLE_TH_ERRORS_PYBIND |
| }, |
| "", |
| py::arg("kind"), |
| py::arg("name"), |
| py::arg("dispatch"), |
| py::arg("file") = "/dev/null", |
| py::arg("linenum") = 0); |
| |
| m.def( |
| "_dispatch_find_schema_or_throw", |
| [](const char* name, const char* overload_name) -> c10::OperatorHandle { |
| return c10::Dispatcher::singleton().findSchemaOrThrow( |
| name, overload_name); |
| }); |
| |
| m.def("_dispatch_dump", [](const char* name) -> std::string { |
| auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); |
| if (!op) { |
| return ""; |
| } else { |
| return op->dumpState(); |
| } |
| }); |
| |
| m.def("_dispatch_dump_table", [](const char* name) -> std::string { |
| auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); |
| if (!op) { |
| return ""; |
| } else { |
| return op->dumpComputedTable(); |
| } |
| }); |
| |
| m.def("_dispatch_check_invariants", [](const char* name) { |
| auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); |
| if (!op) { |
| } else { |
| return op->checkInvariants(); |
| } |
| }); |
| |
| m.def("_dispatch_check_all_invariants", []() { |
| c10::Dispatcher::singleton().checkInvariants(); |
| }); |
| |
| m.def("_dispatch_has_kernel", [](const char* name) -> bool { |
| auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); |
| return static_cast<bool>(op); |
| }); |
| |
| m.def( |
| // Returns whether or not a direct kernel registration exists |
| // for this <op_name, dispatch_key> pair. |
| "_dispatch_has_kernel_for_dispatch_key", |
| [](const char* name, c10::DispatchKey dispatch) -> bool { |
| auto op = |
| c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); |
| TORCH_CHECK(op, "operator ", name, " does not exist"); |
| return op->hasKernelForDispatchKey(dispatch); |
| }); |
| |
| m.def( |
// Returns whether or not the kernel for this dispatch key is a
// fallthrough kernel
| "_dispatch_kernel_for_dispatch_key_is_fallthrough", |
| [](const char* name, c10::DispatchKey dispatch) -> bool { |
| auto op = |
| c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); |
| return op->isKernelFallthroughKernel(dispatch); |
| }); |
| |
| m.def( |
| "_dispatch_has_kernel_for_any_dispatch_key", |
| [](const char* name, c10::DispatchKeySet ks) -> bool { |
| auto op = |
| c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); |
| TORCH_CHECK(op, "operator ", name, " does not exist"); |
| return op->hasKernelForAnyDispatchKey(ks); |
| }); |
| |
| m.def( |
// Returns whether or not there is an entry in the runtime-computed
// dispatch table for this <op_name, dispatch_key> pair. For example, if
// "op" has a `CompositeImplicitAutograd` kernel, then
// _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
// true for all backends that are part of the alias set for
// CompositeImplicitAutograd.
| "_dispatch_has_computed_kernel_for_dispatch_key", |
| [](const char* name, const char* dispatch) -> bool { |
| auto op = |
| c10::Dispatcher::singleton().findOp(torch::jit::parseName(name)); |
| TORCH_CHECK(op, "operator ", name, " does not exist"); |
| return op->hasComputedKernelForDispatchKey( |
| c10::parseDispatchKey(dispatch)); |
| }); |
| |
| m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> { |
| auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls(); |
| |
| std::vector<std::string> states; |
| states.reserve(danglingImpls.size()); |
| for (auto& danglingImpl : danglingImpls) { |
| states.emplace_back(danglingImpl.dumpState()); |
| } |
| |
| return states; |
| }); |
| |
| m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> { |
| auto op_names = c10::Dispatcher::singleton().getAllOpNames(); |
| |
| std::vector<std::string> names; |
| names.reserve(op_names.size()); |
| for (auto& op : op_names) { |
| std::stringstream ss; |
| ss << op.name; |
| if (!op.overload_name.empty()) { |
| ss << "." << op.overload_name; |
| } |
| names.emplace_back(ss.str()); |
| } |
| |
| return names; |
| }); |
| |
| m.def( |
| "_dispatch_tls_set_dispatch_key_excluded", |
| [](c10::DispatchKey dispatch_key, bool desired_state) { |
| c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state); |
| }); |
| m.def( |
| "_dispatch_tls_is_dispatch_key_excluded", |
| [](c10::DispatchKey dispatch_key) { |
| return c10::impl::tls_is_dispatch_key_excluded(dispatch_key); |
| }); |
| m.def( |
| "_dispatch_tls_set_dispatch_key_included", |
| [](c10::DispatchKey dispatch_key, bool desired_state) { |
| c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state); |
| }); |
| m.def( |
| "_dispatch_tls_is_dispatch_key_included", |
| [](c10::DispatchKey dispatch_key) { |
| return c10::impl::tls_is_dispatch_key_included(dispatch_key); |
| }); |
| |
| m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) { |
| return at::isTensorSubclassLike(tensor); |
| }); |
| |
| m.def("_dispatch_key_name", [](c10::DispatchKey k) { |
| return c10::toString(k); |
| }); |
| m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; }); |
| m.def("_to_functionality_key", [](c10::DispatchKey k) { |
| return c10::toFunctionalityKey(k); |
| }); |
| // E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of: |
| // AutogradCPU |
| // AutogradCUDA |
| // ... |
| // AutogradPrivateUse3 |
| m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) { |
| std::vector<c10::DispatchKey> keys; |
| if (c10::isPerBackendFunctionalityKey(key)) { |
| auto ks = c10::DispatchKeySet(key) | |
| c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask); |
| for (auto k : ks) { |
| keys.push_back(k); |
| } |
| } else { |
| keys.push_back(key); |
| } |
| return keys; |
| }); |
| m.def("_dispatch_num_backends", []() { return c10::num_backends; }); |
| |
| #define DEF_ONE(n) .value(#n, c10::DispatchKey::n) |
| |
| py::enum_<c10::DispatchKey>(m, "DispatchKey") |
| // clang-format off |
| DEF_ONE(Undefined) |
| DEF_ONE(CompositeExplicitAutogradNonFunctional) |
| DEF_ONE(CompositeExplicitAutograd) |
| DEF_ONE(CompositeImplicitAutogradNestedTensor) |
| DEF_ONE(CompositeImplicitAutograd) |
| // NestedTensor is not a backend key |
| DEF_ONE(AutogradNestedTensor) |
| DEF_ONE(AutogradOther) |
| DEF_ONE(Autograd) |
| DEF_ONE(Conjugate) |
| DEF_ONE(ZeroTensor) |
| DEF_ONE(Negative) |
| DEF_ONE(BackendSelect) |
| DEF_ONE(ADInplaceOrView) |
| DEF_ONE(PythonTLSSnapshot) |
| DEF_ONE(Python) |
| DEF_ONE(FuncTorchDynamicLayerFrontMode) |
| DEF_ONE(FuncTorchDynamicLayerBackMode) |
| DEF_ONE(FuncTorchBatchedDecomposition) |
| DEF_ONE(FuncTorchBatched) |
| DEF_ONE(FuncTorchVmapMode) |
| DEF_ONE(FuncTorchGradWrapper) |
| DEF_ONE(PythonDispatcher) |
| DEF_ONE(PreDispatch) |
| DEF_ONE(Functionalize) |
| DEF_ONE(AutocastCPU) |
| DEF_ONE(AutocastXPU) |
| DEF_ONE(AutocastHPU) |
| DEF_ONE(AutocastIPU) |
| DEF_ONE(AutocastCUDA) |
| DEF_ONE(AutocastPrivateUse1) |
| // clang-format on |
| |
| #define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n) |
| #define DEF_MULTIPLE(fullname, prefix) \ |
| DEF_SINGLE(, fullname) \ |
| DEF_SINGLE(, StartOf##fullname##Backends) \ |
| C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \ |
| DEF_SINGLE(, EndOf##fullname##Backends) |
| |
| // clang-format off |
| C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE) |
| // clang-format on |
| |
| #undef DEF_MULTIPLE |
| #undef DEF_SINGLE |
| ; |
| |
| py::class_<c10::DispatchKeySet>(m, "DispatchKeySet") |
| .def(py::init<c10::DispatchKey>()) |
| .def("__or__", &c10::DispatchKeySet::operator|) |
| .def("__sub__", &c10::DispatchKeySet::operator-) |
| .def("__and__", &c10::DispatchKeySet::operator&) |
| .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId) |
| .def( |
| "remove", |
| [](c10::DispatchKeySet self, c10::DispatchKey k) { |
| return self.remove(k); |
| }) |
| .def( |
| "add", |
| [](c10::DispatchKeySet self, c10::DispatchKey k) { |
| return self.add(k); |
| }) |
| .def("has", &c10::DispatchKeySet::has) |
| .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); }); |
| |
| m.attr("_dispatch_autogradother_backends") = |
| py::cast(c10::autogradother_backends); |
| |
| m.attr("_additional_keys_to_prop_for_wrapper_tensors") = |
| py::cast(at::functorch::kKeysToPropagateToWrapper); |
| |
| m.attr("_after_autograd_keyset") = py::cast(c10::after_autograd_keyset); |
| m.attr("_after_ADInplaceOrView_keyset") = |
| py::cast(c10::after_ADInplaceOrView_keyset); |
| |
| m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) { |
| return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t); |
| }); |
| |
| m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) { |
| return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t); |
| }); |
| |
| m.def("_dispatch_keyset_full", []() { |
| return c10::DispatchKeySet(c10::DispatchKeySet::FULL); |
| }); |
| |
| m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey); |
| |
| m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) { |
| return c10::toString(keyset); |
| }); |
| |
| m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) { |
| return c10::getBackendKeySetFromAutograd(k); |
| }); |
| |
| m.def("_dispatch_keys", [](const at::Tensor& tensor) { |
| auto* impl = tensor.unsafeGetTensorImpl(); |
| return impl->key_set(); |
| }); |
| m.def("_dispatch_tls_local_include_set", []() { |
| return c10::impl::tls_local_dispatch_key_set().included_; |
| }); |
| m.def("_dispatch_tls_local_exclude_set", []() { |
| return c10::impl::tls_local_dispatch_key_set().excluded_; |
| }); |
| m.def("_functionalization_reapply_views_tls", []() { |
| return at::functionalization::impl::getFunctionalizationReapplyViewsTLS(); |
| }); |
| m.def( |
| "_dispatch_is_included_in_alias", |
| [](c10::DispatchKey a, c10::DispatchKey b) { |
| return c10::isIncludedInAlias(a, b); |
| }); |
| |
| // DEPRECATED, please don't use this. Instead use |
| // torch._C._ExcludeDispatchKeyGuard |
| py_context_manager_DEPRECATED< |
| c10::impl::ExcludeDispatchKeyGuard, |
| c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard"); |
| |
| py_context_manager< |
| c10::impl::ForceDispatchKeyGuard, |
| c10::DispatchKeySet, |
| c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard"); |
| py_context_manager<c10::impl::ForceDispatchKeyGuard>( |
| m, "_PreserveDispatchKeyGuard"); |
| py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>( |
| m, "_IncludeDispatchKeyGuard"); |
| py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>( |
| m, "_ExcludeDispatchKeyGuard"); |
| py_context_manager<SetExcludeDispatchKeyGuard, c10::DispatchKey, bool>( |
| m, "_SetExcludeDispatchKeyGuard"); |
| |
| py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>( |
| m, "_AutoDispatchBelowAutograd"); |
| py_context_manager<at::AutoDispatchBelowADInplaceOrView>( |
| m, "_AutoDispatchBelowADInplaceOrView"); |
| |
// Prints out the name of every operator that has a kernel registered to the
// Dispatcher under [dispatch_key]. If no dispatch key is specified, it'll
// print out the name of every operator that the Dispatcher knows of. This can
// be useful to answer questions like "list all operators that do not have a
// CPU kernel".
| m.def( |
| "_dispatch_print_registrations_for_dispatch_key", |
| [](const char* dispatch_key = "") { |
| auto k = std::string(dispatch_key).empty() |
| ? c10::nullopt |
| : c10::make_optional(c10::parseDispatchKey(dispatch_key)); |
| auto op_names = |
| c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); |
| for (auto& op : op_names) { |
| std::cout << op << '\n'; |
| } |
| }, |
| py::arg("dispatch_key") = static_cast<const char*>("")); |
| |
| m.def( |
| "_parse_dispatch_key", |
| [](const char* dispatch_key) -> std::optional<c10::DispatchKey> { |
| try { |
| return c10::parseDispatchKey(dispatch_key); |
| } catch (const c10::Error& err) { |
| return c10::nullopt; |
| } |
| }); |
| |
| m.def( |
| "_dispatch_get_registrations_for_dispatch_key", |
| [](const char* dispatch_key = "") { |
| auto k = std::string(dispatch_key).empty() |
| ? c10::nullopt |
| : c10::make_optional(c10::parseDispatchKey(dispatch_key)); |
| auto op_names = |
| c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k); |
| std::vector<std::string> names; |
| names.reserve(op_names.size()); |
| for (auto& op : op_names) { |
| names.emplace_back( |
| op.name + |
| (op.overload_name.empty() ? "" : "." + op.overload_name)); |
| } |
| return names; |
| }, |
| py::arg("dispatch_key") = static_cast<const char*>("")); |
| m.def( |
| "_dispatch_set_report_error_callback", |
| [](c10::OperatorHandle& handle, py::object callback) { |
| auto obj = callback.release().ptr(); |
| auto callback_obj = |
| std::make_unique<c10::SafePyObject>(obj, getPyInterpreter()); |
| handle.setReportErrorCallback_(std::move(callback_obj)); |
| }); |
| |
| m.def( |
| "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); }); |
| m.def("_dispatch_pystub", [](const char* name, const char* overload) { |
| return c10::Dispatcher::singleton().getPyStub( |
| c10::OperatorName(name, overload)); |
| }); |
| |
| m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) { |
| return at::functionalization::impl::replace_(a, b); |
| }); |
| m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) { |
| at::functionalization::impl::propagate_xla_data(a, b); |
| }); |
| m.def("_commit_update", [](const at::Tensor& a) { |
| return at::functionalization::impl::commit_update(a); |
| }); |
| m.def("_unsafe_reset_storage", [](const at::Tensor& a) { |
| return at::functionalization::impl::unsafe_reset_storage(a); |
| }); |
| |
| m.def("_dispatch_key_for_device", [](const std::string& device_type) { |
| auto device = c10::Device(device_type); |
| TORCH_CHECK( |
| !device.has_index(), |
| "Expected device_type string to not have a device index; got ", |
| device_type); |
| return c10::toString( |
| c10::computeDispatchKey(c10::nullopt, c10::nullopt, device)); |
| }); |
| |
| m.def("_are_functorch_transforms_active", []() { |
| auto include_set = c10::impl::tls_local_dispatch_key_set().included_; |
| return ( |
| include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) || |
| include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode)); |
| }); |
| |
| m.def("_get_nested_int", [](int64_t data, int64_t coeff) { |
| return c10::SymInt(c10::SymNode( |
| c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff))); |
| }); |
| |
| m.def("_get_constant_bool_symnode", [](int64_t data) { |
| return c10::SymNode( |
| c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data)); |
| }); |
| |
| m.def("_non_sym_sizes", [](const at::Tensor& a) { |
| return a.sizes(); // NB: NOT sym_size |
| }); |
| |
| m.def("_set_throw_on_mutable_data_ptr", [](const at::Tensor& t) { |
| if (!t.unsafeGetTensorImpl()->has_storage()) { |
| // If the Tensor doesn't have a storage, then accessing .data_ptr() |
| // will already raise an error. |
| return; |
| } |
| // Otherwise, set (on the StorageImpl) that accessing (mutable) data_ptr |
| // will throw. |
| t.unsafeGetTensorImpl() |
| ->storage() |
| .unsafeGetStorageImpl() |
| ->set_throw_on_mutable_data_ptr(); |
| }); |
| |
| // Invariant: you must ONLY call this with FakeTensors. |
| m.def("_set_warn_deprecated_on_mutable_data_ptr", [](const at::Tensor& t) { |
| if (!t.unsafeGetTensorImpl()->has_storage()) { |
| // If the Tensor doesn't have a storage, then accessing .data_ptr() |
| // will already raise an error. |
| return; |
| } |
| t.unsafeGetTensorImpl() |
| ->storage() |
| .unsafeGetStorageImpl() |
| ->set_warn_deprecated_on_mutable_data_ptr(); |
| }); |
| |
| m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors); |
| m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors); |
| |
| using c10::impl::TorchDispatchModeKey; |
| py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey") |
| .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL) |
| .value("PROXY", TorchDispatchModeKey::PROXY) |
| .value("FAKE", TorchDispatchModeKey::FAKE); |
| } |
| |
| // TODO: dedupe with the kernel |
| void python_op_registration_trampoline_impl( |
| const c10::OperatorHandle& op, |
| c10::DispatchKey key, |
| c10::DispatchKeySet keyset, |
| torch::jit::Stack* stack, |
| bool with_keyset) { |
| auto arguments = torch::jit::pop(*stack, op.schema().arguments().size()); |
| py::gil_scoped_acquire g; |
| auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); |
| const auto& func = python_registrations_[op.operator_name()][key]; |
| TORCH_INTERNAL_ASSERT(func != nullptr); |
| auto* pyobj = func->ptr(getPyInterpreter()); |
| TORCH_INTERNAL_ASSERT(pyobj != nullptr); |
| auto callable = py::reinterpret_borrow<py::object>(pyobj); |
| auto obj = with_keyset |
| ? callable(keyset, *args_kwargs.first, **args_kwargs.second) |
| : callable(*args_kwargs.first, **args_kwargs.second); |
| if (!obj) { |
| throw python_error(); |
| } |
| pushPyOutToStack(op, stack, obj, "PythonKernelHolder"); |
| } |
| |
| } // namespace torch::impl::dispatch |