| #include <torch/csrc/python_headers.h> |
| |
| #include <ATen/PythonTorchFunctionTLS.h> |
| #include <ATen/SavedTensorHooks.h> |
| #include <ATen/autocast_mode.h> |
| #include <ATen/core/PythonFallbackKernel.h> |
| #include <ATen/record_function.h> |
| #include <c10/core/DeviceType.h> |
| #include <c10/core/InferenceMode.h> |
| #include <c10/core/ScalarType.h> |
| #include <c10/core/impl/PythonDispatcherTLS.h> |
| #include <torch/csrc/Exceptions.h> |
| #include <torch/csrc/autograd/VariableTypeUtils.h> |
| #include <torch/csrc/autograd/autograd.h> |
| #include <torch/csrc/autograd/function.h> |
| #include <torch/csrc/autograd/grad_mode.h> |
| #include <torch/csrc/autograd/profiler.h> |
| #include <torch/csrc/autograd/profiler_python.h> |
| #include <torch/csrc/autograd/python_function.h> |
| #include <torch/csrc/autograd/python_saved_variable_hooks.h> |
| #include <torch/csrc/autograd/python_variable.h> |
| #include <torch/csrc/autograd/record_function_ops.h> |
| #include <torch/csrc/autograd/saved_variable.h> |
| #include <torch/csrc/autograd/utils/python_arg_parsing.h> |
| #include <torch/csrc/autograd/utils/wrap_outputs.h> |
| #include <torch/csrc/jit/python/pybind_utils.h> |
| #include <torch/csrc/profiler/collection.h> |
| #include <torch/csrc/profiler/kineto_shim.h> |
| #include <torch/csrc/utils.h> |
| #include <torch/csrc/utils/disable_torch_function.h> |
| #include <torch/csrc/utils/pybind.h> |
| #include <torch/csrc/utils/pycfunction_helpers.h> |
| #include <torch/csrc/utils/python_torch_function_mode.h> |
| |
| #include <set> |
| #include <unordered_set> |
| #include <utility> |
| |
| namespace { |
| |
| struct DisableFuncTorch { |
| DisableFuncTorch() |
| : front_guard_(c10::DispatchKey::FuncTorchDynamicLayerFrontMode), |
| back_guard_(c10::DispatchKey::FuncTorchDynamicLayerBackMode) {} |
| c10::impl::ExcludeDispatchKeyGuard front_guard_; |
| c10::impl::ExcludeDispatchKeyGuard back_guard_; |
| }; |
| |
| struct MultithreadingEnabled { |
| MultithreadingEnabled(bool enabled) |
| : old_(c10::AutogradState::get_tls_state().get_multithreading_enabled()) { |
| c10::AutogradState::get_tls_state().set_multithreading_enabled(enabled); |
| } |
| ~MultithreadingEnabled() { |
| c10::AutogradState::get_tls_state().set_multithreading_enabled(old_); |
| } |
| bool old_; |
| }; |
| |
| struct ViewReplayEnabled { |
| ViewReplayEnabled(bool enabled) |
| : old_(c10::AutogradState::get_tls_state().get_view_replay_enabled()) { |
| c10::AutogradState::get_tls_state().set_view_replay_enabled(enabled); |
| } |
| ~ViewReplayEnabled() { |
| c10::AutogradState::get_tls_state().set_view_replay_enabled(old_); |
| } |
| bool old_; |
| }; |
| |
| struct DisableAutocast { |
| c10::impl::ExcludeDispatchKeyGuard guard_{c10::autocast_dispatch_keyset}; |
| }; |
| |
| struct EnableTorchFunction { |
| EnableTorchFunction() |
| : old_(at::impl::PythonTorchFunctionTLS::get_disabled_state()) { |
| at::impl::PythonTorchFunctionTLS::set_disabled_state( |
| at::impl::TorchFunctionDisabledState::ENABLED); |
| } |
| ~EnableTorchFunction() { |
| at::impl::PythonTorchFunctionTLS::set_disabled_state(old_); |
| } |
| at::impl::TorchFunctionDisabledState old_; |
| }; |
| |
| struct EnablePythonDispatcher { |
| EnablePythonDispatcher() : old_(c10::impl::PythonDispatcherTLS::get_state()) { |
| c10::impl::PythonDispatcherTLS::set_state(getPyInterpreter()); |
| } |
| ~EnablePythonDispatcher() { |
| c10::impl::PythonDispatcherTLS::set_state(old_); |
| } |
| c10::impl::PyInterpreter* old_; |
| }; |
| |
| } // namespace |
| |
| PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) { |
| using namespace torch::autograd::profiler; |
| using namespace torch::profiler::impl; |
| auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch._tensor")); |
| if (!tensor_module) |
| return nullptr; |
| |
| // NOTE: "leaks" THPVariableClass |
| THPVariableClass = PyObject_GetAttrString(tensor_module, "Tensor"); |
| if (!THPVariableClass) |
| return nullptr; |
| |
| auto autograd_module = THPObjectPtr(PyImport_ImportModule("torch.autograd")); |
| if (!autograd_module) |
| return nullptr; |
| |
| // NOTE: "leaks" Function |
| THPFunctionClass = PyObject_GetAttrString(autograd_module, "Function"); |
| if (!THPFunctionClass) |
| return nullptr; |
| |
| auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C")); |
| if (!torch_C_module) |
| return nullptr; |
| auto _C_m = py::handle(torch_C_module).cast<py::module>(); |
| auto m = _C_m.def_submodule("_autograd", "autograd bindings"); |
| |
| auto parameter_module = |
| THPObjectPtr(PyImport_ImportModule("torch.nn.parameter")); |
| if (!parameter_module) |
| return nullptr; |
| |
| // NOTE: "leaks" ParameterClass |
| ParameterClass = PyObject_GetAttrString(parameter_module, "Parameter"); |
| if (!ParameterClass) |
| return nullptr; |
| |
| py::class_<LegacyEvent>(m, "ProfilerEvent") |
| .def("kind", &LegacyEvent::kindStr) |
| .def("name", [](const LegacyEvent& e) { return e.name(); }) |
| .def("thread_id", &LegacyEvent::threadId) |
| .def("fwd_thread_id", &LegacyEvent::fwdThreadId) |
| .def("device", &LegacyEvent::device) |
| .def("cpu_elapsed_us", &LegacyEvent::cpuElapsedUs) |
| .def("cuda_elapsed_us", &LegacyEvent::cudaElapsedUs) |
| .def("has_cuda", &LegacyEvent::hasCuda) |
| .def("shapes", &LegacyEvent::shapes) |
| .def("cpu_memory_usage", &LegacyEvent::cpuMemoryUsage) |
| .def("cuda_memory_usage", &LegacyEvent::cudaMemoryUsage) |
| .def("handle", &LegacyEvent::handle) |
| .def("node_id", &LegacyEvent::nodeId) |
| .def("is_remote", &LegacyEvent::isRemote) |
| .def("sequence_nr", &LegacyEvent::sequenceNr) |
| .def("stack", &LegacyEvent::stack) |
| .def("scope", &LegacyEvent::scope) |
| .def("correlation_id", &LegacyEvent::correlationId) |
| .def("start_us", &LegacyEvent::cpuUs) |
| .def("flops", &LegacyEvent::flops) |
| .def("is_async", &LegacyEvent::isAsync); |
| |
| py::enum_<c10::DeviceType>(m, "DeviceType") |
| .value("CPU", c10::DeviceType::CPU) |
| .value("CUDA", c10::DeviceType::CUDA) |
| .value("MKLDNN", c10::DeviceType::MKLDNN) |
| .value("OPENGL", c10::DeviceType::OPENGL) |
| .value("OPENCL", c10::DeviceType::OPENCL) |
| .value("IDEEP", c10::DeviceType::IDEEP) |
| .value("HIP", c10::DeviceType::HIP) |
| .value("FPGA", c10::DeviceType::FPGA) |
| .value("ORT", c10::DeviceType::ORT) |
| .value("XLA", c10::DeviceType::XLA) |
| .value("Vulkan", c10::DeviceType::Vulkan) |
| .value("Metal", c10::DeviceType::Metal) |
| .value("XPU", c10::DeviceType::XPU) |
| .value("MPS", c10::DeviceType::MPS) |
| .value("Meta", c10::DeviceType::Meta) |
| .value("HPU", c10::DeviceType::HPU) |
| .value("VE", c10::DeviceType::VE) |
| .value("Lazy", c10::DeviceType::Lazy) |
| .value("IPU", c10::DeviceType::IPU); |
| |
| py::class_<KinetoEvent>(m, "_KinetoEvent") |
| // name of the event |
| .def("name", [](const KinetoEvent& e) { return e.name(); }) |
| // PyTorch thread id of the start callback |
| .def( |
| "start_thread_id", |
| [](const KinetoEvent& e) { return e.startThreadId(); }) |
| // PyTorch thread id of the end callback |
| .def( |
| "end_thread_id", [](const KinetoEvent& e) { return e.endThreadId(); }) |
| // for events of scope BACKWARD_FUNCTION - PyTorch thread id |
| // of the corresponding forward op |
| .def( |
| "fwd_thread_id", [](const KinetoEvent& e) { return e.fwdThreadId(); }) |
| // together with fwd_thread_id, used to uniquely identify |
| // the forward op |
| .def("sequence_nr", [](const KinetoEvent& e) { return e.sequenceNr(); }) |
| // absolute start time (since unix epoch) in us |
| .def("start_us", [](const KinetoEvent& e) { return e.startUs(); }) |
| // duration in us |
| .def("duration_us", [](const KinetoEvent& e) { return e.durationUs(); }) |
| // used for correlation between high-level PyTorch events |
| // and low-level device events |
| .def( |
| "correlation_id", |
| [](const KinetoEvent& e) { return e.correlationId(); }) |
| // shapes of input tensors |
| .def("shapes", [](const KinetoEvent& e) { return e.shapes().vec(); }) |
| .def("dtypes", [](const KinetoEvent& e) { return e.dtypes().vec(); }) |
| // stack traces of the PyTorch CPU events |
| .def("stack", [](const KinetoEvent& e) { return e.stack().vec(); }) |
| // type of the RecordFunction that generated a PyTorch CPU event |
| // (op, torchscript function, user label, etc) |
| .def("scope", [](const KinetoEvent& e) { return e.scope(); }) |
| // device number, for CPU - process id |
| .def("device_index", [](const KinetoEvent& e) { return e.deviceIndex(); }) |
| // for CUDA - stream id, for CPU - start thread id |
| .def( |
| "device_resource_id", |
| [](const KinetoEvent& e) { return e.deviceResourceId(); }) |
| // device type |
| .def("device_type", [](const KinetoEvent& e) { return e.deviceType(); }) |
| // correlation id of a linked event |
| .def( |
| "linked_correlation_id", |
| [](const KinetoEvent& e) { return e.linkedCorrelationId(); }) |
| // compute flops |
| .def("flops", [](const KinetoEvent& e) { return e.flops(); }) |
| // Whether this is async event or not |
| .def("is_async", [](const KinetoEvent& e) { return e.isAsync(); }) |
| .def("cuda_elapsed_us", &KinetoEvent::cudaElapsedUs) |
| .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); }); |
| |
| m.def("_soft_assert_raises", &setSoftAssertRaises); |
| |
| py::class_<ProfilerResult>(m, "_ProfilerResult") |
| .def("trace_start_us", &ProfilerResult::trace_start_us) |
| .def("events", &ProfilerResult::events) |
| .def("experimental_event_tree", &ProfilerResult::event_tree) |
| #ifdef USE_KINETO |
| .def("save", &ProfilerResult::save) |
| #endif // USE_KINETO |
| ; |
| |
| m.def( |
| "_enable_profiler", |
| &enableProfiler, |
| py::arg("config"), |
| py::arg("activities"), |
| py::arg("scopes") = std::unordered_set<at::RecordScope>()); |
| m.def("_disable_profiler", disableProfiler); |
| m.def("_prepare_profiler", prepareProfiler); |
| m.def("_add_metadata_json", addMetadataJson); // Only if `USE_KINETO` is set |
| m.def("_kineto_step", profilerStep); // Only if `USE_KINETO` is set |
| m.def("kineto_available", []() { return torch::profiler::kKinetoAvailable; }); |
| |
| // NOTICE: These record functions are not torch operators and may not show up |
| // in TorchScript tracing, FX transforms, or operator serialization. For these |
| // use cases, please use `torch.profiler.record_function`. |
| // Creates a new profiling scope using RecordFunction and invokes its starting |
| // callbacks. |
| m.def( |
| "_record_function_with_args_enter", |
| [](const std::string& name, py::args args) { |
| using torch::autograd::profiler::PythonRecordFunction; |
| auto python_rec = c10::make_intrusive<PythonRecordFunction>( |
| at::RecordScope::USER_SCOPE); |
| auto* rec = &python_rec->record; |
| if (rec->isActive()) { |
| if (rec->needsInputs()) { |
| auto iv_inputs = std::vector<c10::IValue>(); |
| for (const auto& arg : args) { |
| iv_inputs.push_back(torch::jit::toTypeInferredIValue(arg)); |
| } |
| rec->before( |
| name, |
| c10::ArrayRef<const c10::IValue>( |
| iv_inputs.data(), iv_inputs.size())); |
| } else { |
| rec->before(name); |
| } |
| } |
| return torch::jit::toPyObject(std::move(python_rec)); |
| }); |
| |
| // Ends the profiling scope created with record_function_with_param_enter. |
| m.def("_record_function_with_args_exit", [](const py::object& obj) { |
| using torch::autograd::profiler::PythonRecordFunction; |
| auto python_record = torch::jit::toCustomClass<PythonRecordFunction>(obj); |
| |
| // We don't actually need to do anything with handle just need to persist |
| // the lifetime until now. |
| python_record->record.end(); |
| }); |
| |
| m.def("_supported_activities", []() { |
| std::set<ActivityType> activities{ActivityType::CPU}; |
| #if defined(USE_KINETO) && \ |
| (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER)) |
| if (at::getNumGPUs() > 0) { |
| activities.insert(ActivityType::CUDA); |
| } |
| #elif defined(USE_KINETO) |
| if (at::hasXPU()) { |
| activities.insert(ActivityType::XPU); |
| } |
| #endif |
| return activities; |
| }); |
| |
| m.def("_unsafe_set_version_counter", [](at::Tensor t, int64_t i) { |
| auto vc = torch::autograd::impl::version_counter(t); |
| vc.set_version(i); |
| }); |
| |
| m.def("_enable_profiler_legacy", enableProfilerLegacy); |
| py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions") |
| .def(py::init<bool, bool>()); |
| m.def( |
| "_disable_profiler_legacy", |
| disableProfilerLegacy, |
| py::arg("profiler_disable_options") = ProfilerDisableOptions()); |
| m.def("_profiler_enabled", profilerEnabled); |
| m.def("_profiler_type", torch::profiler::impl::profilerType); |
| m.def("_enable_record_function", [](bool enable) { |
| at::enableRecordFunction(enable); |
| }); |
| m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) { |
| auto cb = |
| at::RecordFunctionCallback(nullptr).needsInputs(true).samplingProb( |
| sampling_prob); |
| if (is_global) { |
| at::addGlobalCallback(cb); |
| } else { |
| at::addThreadLocalCallback(cb); |
| } |
| }); |
| m.def("_clear_callbacks", []() { at::clearCallbacks(); }); |
| m.def( |
| "_saved_tensors_hooks_is_enabled", |
| at::SavedTensorDefaultHooks::is_enabled); |
| m.def("_saved_tensors_hooks_enable", at::SavedTensorDefaultHooks::enable); |
| m.def("_saved_tensors_hooks_disable", at::SavedTensorDefaultHooks::disable); |
| m.def( |
| "_saved_tensors_hooks_get_disabled_error_message", |
| at::SavedTensorDefaultHooks::get_disabled_error_message); |
| m.def( |
| "_push_saved_tensors_default_hooks", |
| [](py::function& pack_hook, py::function& unpack_hook) { |
| torch::autograd::PyDefaultSavedVariableHooks::push_hooks( |
| pack_hook, unpack_hook); |
| }); |
| m.def("_pop_saved_tensors_default_hooks", []() { |
| torch::autograd::PyDefaultSavedVariableHooks::pop_hooks(); |
| }); |
| |
| _C_m.def( |
| "_register_py_class_for_device", |
| [](const std::string& device, py::object python_type_class) { |
| auto cls = python_type_class.ptr(); |
| registerPythonTensorClass(device, cls); |
| }); |
| |
| _C_m.def("_activate_cuda_trace", []() { activateCUDATrace(); }); |
| |
| py::class_<c10::InferenceMode>(_C_m, "_InferenceMode").def(py::init<bool>()); |
| |
| py::class_<at::impl::RestorePythonTLSSnapshot>( |
| _C_m, "_RestorePythonTLSSnapshot") |
| .def(py::init<>()); |
| |
| py::class_<torch::DisableTorchDispatch>(_C_m, "_DisableTorchDispatch") |
| .def(py::init<>()); |
| py::class_<EnableTorchFunction>(_C_m, "_EnableTorchFunction") |
| .def(py::init<>()); |
| py::class_<EnablePythonDispatcher>(_C_m, "_EnablePythonDispatcher") |
| .def(py::init<>()); |
| py::class_<c10::impl::DisablePythonDispatcher>( |
| _C_m, "_DisablePythonDispatcher") |
| .def(py::init<>()); |
| py::class_<DisableFuncTorch>(_C_m, "_DisableFuncTorch").def(py::init<>()); |
| py::class_<MultithreadingEnabled>(_C_m, "_MultithreadingEnabled") |
| .def(py::init<bool>()); |
| py::class_<DisableAutocast>(std::move(_C_m), "_DisableAutocast") |
| .def(py::init<>()); |
| py::class_<ViewReplayEnabled>(_C_m, "_ViewReplayEnabled") |
| .def(py::init<bool>()); |
| py::class_<torch::autograd::SavedVariable>(std::move(m), "SavedTensor") |
| .def(py::init([]() -> torch::autograd::SavedVariable { |
| TORCH_CHECK( |
| false, |
| "Trying to create a SavedTensor object from Python is forbidden."); |
| })) |
| .def( |
| "register_hooks", |
| [](torch::autograd::SavedVariable& s, |
| py::function& pack_hook, |
| py::function& unpack_hook) { |
| // Because we use a py::object, pybind will increment the refcount |
| // of the hook functions for us |
| s.register_hooks( |
| std::make_unique<torch::autograd::PySavedVariableHooks>( |
| pack_hook, unpack_hook)); |
| }); |
| |
| torch::autograd::profiler::python_tracer::init(); |
| Py_RETURN_TRUE; |
| } |
| |
| namespace torch { |
| namespace autograd { |
| |
| static PyObject* set_autocast_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (!PyBool_Check(arg)) { |
| throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); |
| } |
| at::autocast::set_enabled(arg == Py_True); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (at::autocast::is_enabled()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (at::autocast::is_enabled() || at::autocast::is_cpu_enabled() || |
| at::autocast::is_xpu_enabled()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (!PyBool_Check(arg)) { |
| throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); |
| } |
| at::autocast::set_cpu_enabled(arg == Py_True); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* is_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (at::autocast::is_cpu_enabled()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* set_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (!THPDtype_Check(arg)) { |
| throw TypeError( |
| "dtype must be a torch.dtype (got %s)", Py_TYPE(arg)->tp_name); |
| } |
| at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type; |
| at::autocast::set_autocast_gpu_dtype(targetType); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* set_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (!THPDtype_Check(arg)) { |
| throw TypeError( |
| "dtype must be a torch.dtype (got %s)", Py_TYPE(arg)->tp_name); |
| } |
| at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type; |
| at::autocast::set_autocast_cpu_dtype(targetType); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype(); |
| auto dtype = (PyObject*)torch::getTHPDtype(current_dtype); |
| Py_INCREF(dtype); |
| return dtype; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype(); |
| auto dtype = (PyObject*)torch::getTHPDtype(current_dtype); |
| Py_INCREF(dtype); |
| return dtype; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* clear_autocast_cache(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| at::autocast::clear_cache(); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* autocast_increment_nesting(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| return THPUtils_packInt64(at::autocast::increment_nesting()); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* autocast_decrement_nesting(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| return THPUtils_packInt64(at::autocast::decrement_nesting()); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* is_autocast_cache_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (at::autocast::is_autocast_cache_enabled()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* set_autocast_cache_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (!PyBool_Check(arg)) { |
| throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); |
| } |
| at::autocast::set_autocast_cache_enabled(arg == Py_True); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* set_grad_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (!PyBool_Check(arg)) { |
| throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); |
| } |
| GradMode::set_enabled(arg == Py_True); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* is_grad_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (GradMode::is_enabled()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* set_fwd_grad_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (!PyBool_Check(arg)) { |
| throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name); |
| } |
| c10::AutogradState::get_tls_state().set_fw_grad_mode(arg == Py_True); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* is_fwd_grad_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (c10::AutogradState::get_tls_state().get_fw_grad_mode()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (c10::InferenceMode::is_enabled()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* set_anomaly_mode_enabled( |
| PyObject* _unused, |
| PyObject* args, |
| PyObject* kwargs) { |
| HANDLE_TH_ERRORS |
| static PythonArgParser parser({ |
| "set_anomaly_enabled(bool enabled, bool check_nan=True)", |
| }); |
| ParsedArgs<2> parsed_args; |
| auto r = parser.parse(args, kwargs, parsed_args); |
| AnomalyMode::set_enabled(r.toBool(0), r.toBool(1)); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* is_anomaly_mode_enabled(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (AnomalyMode::is_enabled()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* is_anomaly_check_nan_enabled( |
| PyObject* _unused, |
| PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (AnomalyMode::should_check_nan()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* python_enter_dual_level(PyObject* _unused, PyObject* arg) { |
| HANDLE_TH_ERRORS |
| // It is unlikely that the depth of forward nesting will overflow int64_t so |
| // we just static cast here. |
| return utils::wrap(static_cast<int64_t>(forward_ad::enter_dual_level())); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* python_exit_dual_level( |
| PyObject* _unused, |
| PyObject* args, |
| PyObject* kwargs) { |
| HANDLE_TH_ERRORS |
| static PythonArgParser parser({"exit_dual_level(int64_t level)"}); |
| |
| ParsedArgs<1> parsed_args; |
| auto _r = parser.parse(args, kwargs, parsed_args); |
| |
| auto idx = _r.toInt64(0); |
| // Make sure the given index is valid before casting it |
| TORCH_CHECK(idx >= 0, "Dual level must be a positive number."); |
| forward_ad::exit_dual_level(static_cast<uint64_t>(idx)); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* is_torch_function_mode_enabled( |
| PyObject* _unused, |
| PyObject* _unused2) { |
| HANDLE_TH_ERRORS |
| if (at::impl::torch_function_mode_enabled()) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* push_on_torch_function_stack( |
| PyObject* _unused, |
| PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (arg != Py_None) { |
| Py_INCREF(arg); |
| at::impl::PythonTorchFunctionTLS::push_onto_stack( |
| std::make_shared<c10::SafePyObject>(arg, getPyInterpreter())); |
| } |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* pop_torch_function_stack( |
| PyObject* _unused, |
| PyObject* _unused2) { |
| HANDLE_TH_ERRORS |
| const auto& mode = at::impl::PythonTorchFunctionTLS::pop_stack(); |
| auto* r = mode->ptr(getPyInterpreter()); |
| Py_INCREF(r); |
| return r; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* get_function_stack_at( |
| PyObject* _unused, |
| PyObject* args, |
| PyObject* kwargs) { |
| HANDLE_TH_ERRORS |
| static PythonArgParser parser({"get_stack_at(int64_t level)"}); |
| |
| ParsedArgs<1> parsed_args; |
| auto _r = parser.parse(args, kwargs, parsed_args); |
| |
| auto idx = _r.toInt64(0); |
| const auto& mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx); |
| auto* r = mode->ptr(getPyInterpreter()); |
| Py_INCREF(r); |
| return r; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* len_torch_function_stack( |
| PyObject* _unused, |
| PyObject* _unused2) { |
| HANDLE_TH_ERRORS |
| const auto len = at::impl::PythonTorchFunctionTLS::stack_len(); |
| return utils::wrap(static_cast<int64_t>(len)); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* push_on_torch_dispatch_stack( |
| PyObject* _unused, |
| PyObject* arg) { |
| HANDLE_TH_ERRORS |
| if (arg != Py_None) { |
| Py_INCREF(arg); |
| c10::impl::TorchDispatchModeTLS::push_onto_stack( |
| std::make_shared<c10::SafePyObject>(arg, getPyInterpreter())); |
| } |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* pop_torch_dispatch_stack( |
| PyObject* _unused, |
| PyObject* _unused2) { |
| HANDLE_TH_ERRORS |
| const auto& mode = c10::impl::TorchDispatchModeTLS::pop_stack(); |
| auto* r = mode->ptr(getPyInterpreter()); |
| Py_INCREF(r); |
| return r; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* get_dispatch_stack_at( |
| PyObject* _unused, |
| PyObject* args, |
| PyObject* kwargs) { |
| HANDLE_TH_ERRORS |
| static PythonArgParser parser({"get_stack_at(int64_t level)"}); |
| |
| ParsedArgs<1> parsed_args; |
| auto _r = parser.parse(args, kwargs, parsed_args); |
| |
| auto idx = _r.toInt64(0); |
| const auto& mode = c10::impl::TorchDispatchModeTLS::get_stack_at(idx); |
| auto* r = mode->ptr(getPyInterpreter()); |
| Py_INCREF(r); |
| return r; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static PyObject* len_torch_dispatch_stack( |
| PyObject* _unused, |
| PyObject* _unused2) { |
| HANDLE_TH_ERRORS |
| const auto len = c10::impl::TorchDispatchModeTLS::stack_len(); |
| return utils::wrap(static_cast<int64_t>(len)); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| PyObject* THPModule_increment_version(PyObject* _unused, PyObject* tensor) { |
| HANDLE_TH_ERRORS |
| THPUtils_assert( |
| THPVariable_Check(tensor), "increment_version expect a Tensor as input"); |
| torch::autograd::increment_version((THPVariable_Unpack(tensor))); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| // autograd methods on torch._C |
| static PyMethodDef methods[] = { // NOLINT |
| {"_set_grad_enabled", set_grad_enabled, METH_O, nullptr}, |
| {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr}, |
| {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr}, |
| {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr}, |
| {"is_inference_mode_enabled", |
| is_inference_mode_enabled, |
| METH_NOARGS, |
| nullptr}, |
| {"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr}, |
| {"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr}, |
| {"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr}, |
| {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr}, |
| {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr}, |
| {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr}, |
| {"set_autocast_cpu_dtype", set_autocast_cpu_dtype, METH_O, nullptr}, |
| {"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr}, |
| {"set_autocast_gpu_dtype", set_autocast_gpu_dtype, METH_O, nullptr}, |
| {"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr}, |
| {"autocast_increment_nesting", |
| autocast_increment_nesting, |
| METH_NOARGS, |
| nullptr}, |
| {"autocast_decrement_nesting", |
| autocast_decrement_nesting, |
| METH_NOARGS, |
| nullptr}, |
| {"is_autocast_cache_enabled", |
| is_autocast_cache_enabled, |
| METH_NOARGS, |
| nullptr}, |
| {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr}, |
| {"_increment_version", THPModule_increment_version, METH_O, nullptr}, |
| {"set_anomaly_enabled", |
| castPyCFunctionWithKeywords(set_anomaly_mode_enabled), |
| METH_VARARGS | METH_KEYWORDS, |
| nullptr}, |
| {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr}, |
| {"is_anomaly_check_nan_enabled", |
| is_anomaly_check_nan_enabled, |
| METH_NOARGS, |
| nullptr}, |
| {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr}, |
| {"_exit_dual_level", |
| castPyCFunctionWithKeywords(python_exit_dual_level), |
| METH_VARARGS | METH_KEYWORDS, |
| nullptr}, |
| {"_is_torch_function_mode_enabled", |
| is_torch_function_mode_enabled, |
| METH_NOARGS, |
| nullptr}, |
| {"_push_on_torch_function_stack", |
| push_on_torch_function_stack, |
| METH_O, |
| nullptr}, |
| {"_pop_torch_function_stack", |
| pop_torch_function_stack, |
| METH_NOARGS, |
| nullptr}, |
| {"_get_function_stack_at", |
| castPyCFunctionWithKeywords(get_function_stack_at), |
| METH_VARARGS | METH_KEYWORDS, |
| nullptr}, |
| {"_len_torch_function_stack", |
| len_torch_function_stack, |
| METH_NOARGS, |
| nullptr}, |
| {"_push_on_torch_dispatch_stack", |
| push_on_torch_dispatch_stack, |
| METH_O, |
| nullptr}, |
| {"_pop_torch_dispatch_stack", |
| pop_torch_dispatch_stack, |
| METH_NOARGS, |
| nullptr}, |
| {"_get_dispatch_stack_at", |
| castPyCFunctionWithKeywords(get_dispatch_stack_at), |
| METH_VARARGS | METH_KEYWORDS, |
| nullptr}, |
| {"_len_torch_dispatch_stack", |
| len_torch_dispatch_stack, |
| METH_NOARGS, |
| nullptr}, |
| {nullptr, nullptr, 0, nullptr}}; |
| |
| PyMethodDef* python_functions() { |
| return methods; |
| } |
| |
| } // namespace autograd |
| } // namespace torch |