| #include <ATen/core/PythonFallbackKernel.h> |
| #include <ATen/core/PythonOpRegistrationTrampoline.h> |
| #include <torch/csrc/PyInterpreter.h> |
| #include <torch/csrc/THP.h> |
| #include <torch/csrc/autograd/generated/VariableType.h> |
| #include <torch/csrc/utils/python_arg_parser.h> |
| #include <torch/csrc/utils/python_dispatch.h> |
| |
| #include <string> |
| |
| using namespace torch; |
| using namespace at; |
| using namespace c10; |
| |
| namespace { |
| |
// NB: This is a macro and not a template function (like it was before)
// because passing in constexpr char* as template argument breaks some
// versions of MSVC that are being used internally at Meta.
// MSVC 14.16.27023 (vs2017_15.9)
//
// Looks up `torch.utils._cuda_trace.<func_name>.fire_callbacks` and calls it
// with the remaining macro arguments. Nothing is called when the Python
// interpreter is not initialized. Any std::exception raised by the import or
// by the hook is logged and swallowed rather than propagated.
//
// The expansion is wrapped in `do { ... } while (0)` so that it behaves as a
// single statement (safe inside an unbraced if/else) and so that the locals it
// declares (`guard`, `gil`, `mod`, `hook`) never leak into, or collide with,
// the caller's scope. Invoke it with a trailing semicolon.
#define CONCRETE_TRACE_CUDA(func_name, ...) \
  do { \
    at::impl::MaybeSetTLSOnEntryGuard guard; \
    if (Py_IsInitialized()) { \
      pybind11::gil_scoped_acquire gil; \
      try { \
        py::module mod = py::module::import("torch.utils._cuda_trace"); \
        py::object hook = mod.attr(func_name).attr("fire_callbacks"); \
        hook(__VA_ARGS__); \
      } catch (const std::exception& e) { \
        LOG(ERROR) << "CUDA trace hook execution failed: " << e.what(); \
      } \
    } \
  } while (0)
| |
| struct ConcretePyInterpreterVTable final |
| : public c10::impl::PyInterpreterVTable { |
| std::string name() const override; |
| |
| void decref(PyObject* pyobj, bool is_tensor) const override; |
| |
| // TODO: Need to make this work for StorageImpl too. I imagine I'll want to |
| // operate upon a PyObjectSlot rather than a TensorImpl |
| c10::intrusive_ptr<c10::TensorImpl> detach( |
| const c10::TensorImpl* self) const override; |
| |
| void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack) |
| const override; |
| void python_dispatcher( |
| const c10::OperatorHandle& op, |
| c10::DispatchKeySet, |
| torch::jit::Stack* stack) const override; |
| // NB: this is defined in python_dispatch.cpp |
| void python_op_registration_trampoline( |
| const c10::OperatorHandle& op, |
| c10::DispatchKey key, |
| torch::jit::Stack* stack) const override { |
| torch::impl::dispatch::python_op_registration_trampoline_impl( |
| op, key, stack); |
| } |
| |
| bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat) |
| const override; |
| bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat) |
| const override; |
| bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override; |
| c10::Device device(const c10::TensorImpl* self) const override; |
| int64_t dim(const c10::TensorImpl* self) const override; |
| c10::IntArrayRef strides(const c10::TensorImpl* self) const override; |
| c10::IntArrayRef sizes(const c10::TensorImpl* self) const override; |
| c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override; |
| c10::Layout layout(const c10::TensorImpl* self) const override; |
| c10::SymInt sym_numel(const c10::TensorImpl* self) const override; |
| c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override; |
| c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override; |
| |
| void trace_gpu_event_creation(uintptr_t event) const override { |
| CONCRETE_TRACE_CUDA("CUDAEventCreationCallbacks", event); |
| } |
| void trace_gpu_event_deletion(uintptr_t event) const override { |
| CONCRETE_TRACE_CUDA("CUDAEventDeletionCallbacks", event); |
| } |
| void trace_gpu_event_record(uintptr_t event, uintptr_t stream) |
| const override { |
| CONCRETE_TRACE_CUDA("CUDAEventRecordCallbacks", event, stream); |
| } |
| void trace_gpu_event_wait(uintptr_t event, uintptr_t stream) const override { |
| CONCRETE_TRACE_CUDA("CUDAEventWaitCallbacks", event, stream); |
| } |
| void trace_gpu_memory_allocation(uintptr_t ptr) const override { |
| CONCRETE_TRACE_CUDA("CUDAMemoryAllocationCallbacks", ptr); |
| } |
| void trace_gpu_memory_deallocation(uintptr_t ptr) const override { |
| CONCRETE_TRACE_CUDA("CUDAMemoryDeallocationCallbacks", ptr); |
| } |
| void trace_gpu_stream_creation(uintptr_t stream) const override { |
| CONCRETE_TRACE_CUDA("CUDAStreamCreationCallbacks", stream); |
| } |
| void trace_gpu_device_synchronization() const override { |
| CONCRETE_TRACE_CUDA("CUDADeviceSynchronizationCallbacks"); |
| } |
| void trace_gpu_stream_synchronization(uintptr_t stream) const override { |
| CONCRETE_TRACE_CUDA("CUDAStreamSynchronizationCallbacks", stream); |
| } |
| void trace_gpu_event_synchronization(uintptr_t event) const override { |
| CONCRETE_TRACE_CUDA("CUDAEventSynchronizationCallbacks", event); |
| } |
| |
| void reset_backward_hooks(const c10::TensorImpl* self) const override; |
| |
| static ConcretePyInterpreterVTable* instance() { |
| static ConcretePyInterpreterVTable s; |
| return &s; |
| } |
| }; |
| |
| class PyInterpreterHolder { |
| public: |
| PyInterpreterHolder() |
| : impl_(new c10::impl::PyInterpreter( |
| ConcretePyInterpreterVTable::instance())) { |
| is_main_interpreter_ = |
| at::impl::PythonOpRegistrationTrampoline::registerInterpreter(impl_); |
| } |
| // NB: intentionally leaks the PyInterpreter, as there may still be |
| // references to it that are live, living in objects that aren't being |
| // destructed while Python is being cleaned up. |
| ~PyInterpreterHolder() { |
| impl_->disarm(); |
| } |
| c10::impl::PyInterpreter* get() const noexcept { |
| return impl_; |
| } |
| bool is_main_interpreter() const noexcept { |
| return is_main_interpreter_; |
| } |
| |
| private: |
| c10::impl::PyInterpreter* impl_; |
| bool is_main_interpreter_; |
| }; |
| |
| py::object torchDispatchFromTensorImpl( |
| const c10::TensorImpl* self, |
| const char* func_name, |
| PyObject* torch_api_function, |
| const char* module_name, |
| // WARNING: MUST NOT BE TENSOR ARGS |
| c10::SmallVector<py::object, 1> extra_args = {}) { |
| if (torch_api_function == nullptr) { |
| throw python_error(); |
| } |
| TORCH_CHECK( |
| PyGILState_Check(), |
| "GIL must be held before you call parseIValuesToPyArgsKwargs"); |
| |
| std::vector<py::handle> overloaded_args; |
| // TODO: there should be a shorter way to spell this |
| // TODO: fix the constness of target |
| at::Tensor self_t = at::Tensor( |
| c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>:: |
| unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self))); |
| auto self_p = |
| py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t))); |
| // NB: this may not be a python tensor if you got here from a mode! |
| // TORCH_INTERNAL_ASSERT(isPythonTensor(self_t)); |
| append_overloaded_tensor(&overloaded_args, self_p.ptr()); |
| auto args = |
| py::reinterpret_steal<py::object>(PyTuple_New(1 + extra_args.size())); |
| PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr()); |
| int64_t i = 1; |
| for (auto& a : extra_args) { |
| if (a.ptr() == nullptr) |
| throw python_error(); |
| PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr()); |
| i++; |
| } |
| |
| py::dict kwargs; |
| |
| return py::reinterpret_steal<py::object>( |
| handle_torch_function_no_python_arg_parser( |
| overloaded_args, |
| args.ptr(), |
| kwargs.ptr(), |
| func_name, |
| torch_api_function, |
| module_name, |
| TorchFunctionName::TorchDispatch)); |
| } |
| |
| // NOTE [PyInterpreter::decref takes an `is_tensor` arg] |
| // Before calling PyInterpreter::decref, we must statically know if the |
| // pyobj is a Tensor or not. |
| // - If it is a tensor, we need to be careful about PyObject resurrection |
| // - If it is not a tensor, we can freely decref |
| // One alternative to this is using PyObject_IsInstance |
| // to get at this information. However, we don't want to risk an incorrect |
| // `__instancecheck__` changing the semantics here. |
| void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool is_tensor) |
| const { |
| // Leak the pyobj if not initialized. This can happen if we are running |
| // exit handlers that are destructing tensors with residual (owned) |
| // PyObjects stored in them. |
| if (!Py_IsInitialized()) |
| return; |
| |
| pybind11::gil_scoped_acquire gil; |
| // Two possibilities: |
| // 1. We are decref-ing a tensor. Then we must be careful about |
| // PyObject resurrection (this only applies to Tensors, see |
| // THPVariable_clear). |
| // 2. We are decref-ing some other Python object. We don't do |
| // PyObject resurrection on non-Tensors, so we just carry on as usual |
| if (is_tensor && Py_REFCNT(pyobj) > 1) { |
| // It's still alive! This can happen if a weak ref resurrected |
| // the PyObject without flipping ownership. At this point it is |
| // too late to rescue the object, so just stub out the PyObject |
| // so that it fails on subsequent uses. Don't raise an error here; |
| // you're probably in a destructor. |
| TORCH_WARN( |
| "Deallocating Tensor that still has live PyObject references. " |
| "This probably happened because you took out a weak reference to " |
| "Tensor and didn't call _fix_weakref() after dereferencing it. " |
| "Subsequent accesses to this tensor via the PyObject will now fail."); |
| ((THPVariable*)pyobj)->cdata = c10::MaybeOwned<torch::autograd::Variable>(); |
| } |
| Py_DECREF(pyobj); |
| }; |
| |
| py::handle getTorchApiFunction(const c10::OperatorHandle& op) { |
| return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* { |
| // Parse the name into namespace and name (no overload_name) |
| // TODO: put this into the library |
| const auto& schema = op.schema(); |
| const auto& qualified_name = op.operator_name().name; |
| const auto& overload_name = schema.overload_name(); |
| auto pos = qualified_name.find("::"); |
| TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name); |
| // Make me some null terminated strings |
| std::string ns_str = qualified_name.substr(0, pos); |
| const char* ns = ns_str.c_str(); |
| const char* func_name = qualified_name.c_str() + pos + strlen("::"); |
| |
| py::handle torch_api_function = |
| py::module::import("torch").attr("ops").attr(ns).attr(func_name); |
| if (overload_name.empty()) { |
| return torch_api_function.attr("default").ptr(); |
| } else { |
| return torch_api_function.attr(overload_name.c_str()).ptr(); |
| } |
| }); |
| } |
| |
| bool isPythonTensor(const at::Tensor& tensor) { |
| return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python); |
| } |
| |
| void ConcretePyInterpreterVTable::dispatch( |
| const c10::OperatorHandle& op, |
| torch::jit::Stack* stack) const { |
| const auto& schema = op.schema(); |
| const auto num_arguments = schema.arguments().size(); |
| auto arguments = torch::jit::pop(*stack, num_arguments); |
| |
| // The plan: convert all the arguments back into PyObjects, |
| // extracting out the tensor handles, then call |
| // handle_torch_function_no_python_arg_parser |
| // NB: at the point arguments are pushed to the stack, ALL defaults |
| // are already present |
| |
| py::gil_scoped_acquire g; |
| |
| std::vector<py::handle> overloaded_args; |
| py::handle torch_api_function_overload = getTorchApiFunction(op); |
| |
| // Find overloaded tensors |
| for (const auto idx : c10::irange(arguments.size())) { |
| const auto& ivalue = arguments[idx]; |
| if (ivalue.isTensor()) { |
| const auto& tensor = ivalue.toTensor(); |
| if (isPythonTensor(tensor)) { |
| append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr()); |
| } |
| } else if (ivalue.isList()) { |
| const auto& list = ivalue.toListRef(); |
| for (const auto jdx : c10::irange(list.size())) { |
| const auto& nv = list[jdx]; |
| if (nv.isTensor()) { |
| const auto& tensor = nv.toTensor(); |
| if (isPythonTensor(tensor)) { |
| append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr()); |
| } |
| } |
| } |
| } |
| } |
| |
| auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); |
| auto args = std::move(args_kwargs.first); |
| auto kwargs = std::move(args_kwargs.second); |
| |
| PyObject* obj = handle_torch_function_no_python_arg_parser( |
| overloaded_args, |
| args.ptr(), |
| kwargs.ptr(), |
| nullptr, |
| torch_api_function_overload.ptr(), |
| nullptr, |
| TorchFunctionName::TorchDispatch); |
| pushPyOutToStack( |
| op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__"); |
| } |
| |
| void ConcretePyInterpreterVTable::python_dispatcher( |
| const c10::OperatorHandle& op, |
| c10::DispatchKeySet ks, |
| torch::jit::Stack* stack) const { |
| py::gil_scoped_acquire g; |
| py::handle torch_api_function_overload = getTorchApiFunction(op); |
| // TODO: if necessary, can optimize to cache the cache lookup |
| // TODO: if necessary, can optimize OpOverload to have slots |
| auto cache = py::dict(torch_api_function_overload.attr("_dispatch_cache")); |
| if (cache.ptr() == nullptr) { |
| throw python_error(); |
| } |
| |
| c10::DispatchKey k = ks.highestPriorityTypeId(); |
| // TODO: allow this to be non-owning |
| auto handler = py::reinterpret_borrow<py::object>( |
| PyDict_GetItem(cache.ptr(), py::cast(k).ptr())); |
| if (handler.ptr() == nullptr) { |
| // Slow path |
| handler = torch_api_function_overload.attr("_get_dispatch")(k); |
| } |
| if (py::isinstance<c10::DispatchKey>(handler)) { |
| // NB: not redispatch, as that will permanently remove the python |
| // dispatcher for subsequent redispatches |
| op.callBoxedForDispatchKey(py::cast<c10::DispatchKey>(handler), *stack); |
| return; |
| } |
| |
| const auto& schema = op.schema(); |
| const auto num_arguments = schema.arguments().size(); |
| auto arguments = torch::jit::pop(*stack, num_arguments); |
| |
| auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); |
| auto args = std::move(args_kwargs.first); |
| auto kwargs = std::move(args_kwargs.second); |
| |
| py::object obj = py::reinterpret_steal<py::object>( |
| PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr())); |
| |
| if (obj.ptr() == nullptr) { |
| throw python_error(); |
| } |
| |
| pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher"); |
| } |
| |
| c10::intrusive_ptr<c10::TensorImpl> ConcretePyInterpreterVTable::detach( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "detach", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("detach") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| TORCH_CHECK( |
| THPVariable_Check(out.ptr()), |
| "detach returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected Tensor"); |
| const at::Tensor& res_t = THPVariable_Unpack(out.ptr()); |
| return res_t.getIntrusivePtr(); |
| } |
| |
| bool ConcretePyInterpreterVTable::is_contiguous( |
| const c10::TensorImpl* self, |
| at::MemoryFormat memory_format) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| py::object out; |
| if (memory_format == at::MemoryFormat::Contiguous) { |
| // For backwards compatibility |
| out = torchDispatchFromTensorImpl( |
| self, |
| "is_contiguous", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("is_contiguous") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| } else { |
| out = torchDispatchFromTensorImpl( |
| self, |
| "is_contiguous", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("is_contiguous") |
| .attr("memory_format") |
| .ptr(), |
| "torch.ops.aten", |
| {py::cast(memory_format)}); |
| } |
| |
| if (out.is_none()) { |
| return self->is_contiguous_default(memory_format); |
| } |
| |
| TORCH_CHECK( |
| PyBool_Check(out.ptr()), |
| "is_contiguous returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected bool"); |
| |
| return PyObject_IsTrue(out.ptr()); |
| } |
| |
| bool ConcretePyInterpreterVTable::is_strides_like( |
| const c10::TensorImpl* self, |
| at::MemoryFormat memory_format) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "is_strides_like", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| // NB: intentionally suffixed with _format to avoid |
| // triggering matches against "_like" suffix |
| .attr("is_strides_like_format") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten", |
| {py::cast(memory_format)}); |
| |
| if (out.is_none()) { |
| return self->is_strides_like_default(memory_format); |
| } |
| |
| TORCH_CHECK( |
| PyBool_Check(out.ptr()), |
| "is_strides_like_format returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected bool"); |
| |
| return PyObject_IsTrue(out.ptr()); |
| } |
| |
| bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "is_non_overlapping_and_dense", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("is_non_overlapping_and_dense") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| return self->is_non_overlapping_and_dense_default(); |
| } |
| |
| TORCH_CHECK( |
| PyBool_Check(out.ptr()), |
| "is_non_overlapping_and_dense returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected bool"); |
| |
| return PyObject_IsTrue(out.ptr()); |
| } |
| |
| int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "dim", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("dim") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| TORCH_CHECK( |
| PyLong_Check(out.ptr()), |
| "dim returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected int"); |
| |
| return THPUtils_unpackLong(out.ptr()); |
| } |
| |
| c10::Device ConcretePyInterpreterVTable::device( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "device", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("prim") |
| .attr("device") |
| .attr("default") |
| .ptr(), |
| "torch.ops.prim"); |
| |
| return toDevice(out.ptr()); |
| } |
| |
| c10::IntArrayRef ConcretePyInterpreterVTable::strides( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "stride", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("stride") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| TORCH_CHECK( |
| !self->has_symbolic_sizes_strides(), |
| "Cannot call strides on a tensor with symbolic shapes/strides"); |
| return self->strides_default(); |
| } |
| |
| py::object values = py::reinterpret_steal<py::object>(out.ptr()); |
| |
| c10::optional<PyObject*> mb_obj = |
| self->pyobj_slot()->check_pyobj(getPyInterpreter()); |
| TORCH_CHECK( |
| mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); |
| PyObject* subclass = *mb_obj; |
| Py_INCREF(subclass); |
| py::object sub = py::reinterpret_steal<py::object>(subclass); |
| |
| py::object os = py::module_::import("torch").attr("overrides"); |
| py::function get_buffer = |
| py::reinterpret_borrow<py::function>(os.attr("get_buffer")); |
| auto buffer = get_buffer(sub, values, "stride"); |
| auto result = THPUtils_unpackLongs(buffer.ptr()); |
| int64_t* start = (int64_t*)result[0]; |
| int64_t len = result[1]; |
| |
| return c10::IntArrayRef(start, len); |
| } |
| |
| static std::vector<int64_t> values_from_buffer( |
| const c10::TensorImpl* self, |
| py::handle values) { |
| c10::TensorImpl* ptr = const_cast<c10::TensorImpl*>(self); |
| c10::optional<PyObject*> mb_obj = |
| ptr->pyobj_slot()->check_pyobj(getPyInterpreter()); |
| TORCH_CHECK( |
| mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); |
| |
| py::object os = py::module_::import("torch").attr("overrides"); |
| py::function get_buffer = |
| py::reinterpret_borrow<py::function>(os.attr("get_buffer")); |
| auto buffer = get_buffer(py::handle(*mb_obj), values, "size"); |
| auto result = THPUtils_unpackLongs(buffer.ptr()); |
| return result; |
| } |
| |
| c10::IntArrayRef ConcretePyInterpreterVTable::sizes( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "size", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("size") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| TORCH_CHECK( |
| !self->has_symbolic_sizes_strides(), |
| "Cannot call sizes on a tensor with symbolic shapes/strides"); |
| return self->sizes_default(); |
| } |
| |
| py::object values = py::reinterpret_steal<py::object>(out.ptr()); |
| auto result = values_from_buffer(self, values); |
| int64_t* start = (int64_t*)result[0]; |
| int64_t len = result[1]; |
| |
| return c10::IntArrayRef(start, len); |
| } |
| |
| c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| HANDLE_TH_ERRORS |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "sym_size", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("sym_size") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| return self->sym_sizes_default(); |
| } |
| // We need to squeeze SymIntNodes and ints into `SymInts` |
| // since it's a format `sym_sizes()` are stored in |
| TORCH_CHECK( |
| py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out), |
| "Symshape must be a list or a tuple"); |
| py::list symints; |
| for (auto it = out.begin(); it != out.end(); it++) { |
| auto elm = *it; |
| auto si = py::cast<c10::SymInt>(elm); |
| // TODO: the buffer will need to be made owning later |
| symints.append(si.as_int_unchecked()); |
| } |
| |
| auto result = values_from_buffer(self, symints); |
| c10::SymInt* start = (c10::SymInt*)result[0]; |
| int64_t len = result[1]; |
| |
| return c10::SymIntArrayRef(start, len); |
| END_HANDLE_TH_ERRORS_PYBIND |
| } |
| |
| c10::Layout ConcretePyInterpreterVTable::layout( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "layout", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("prim") |
| .attr("layout") |
| .attr("default") |
| .ptr(), |
| "torch.ops.prim"); |
| |
| TORCH_CHECK( |
| THPLayout_Check(out.ptr()), |
| "layout returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected Layout"); |
| |
| return toLayout(out.ptr()); |
| } |
| |
| c10::SymInt ConcretePyInterpreterVTable::sym_numel( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "sym_numel", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("sym_numel") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| TORCH_CHECK( |
| !self->has_symbolic_sizes_strides(), |
| "Cannot call numel on a tensor with symbolic shapes/strides"); |
| return self->sym_numel_default(); |
| } |
| return torch::is_symint(out) ? out.cast<c10::SymInt>() |
| : c10::SymInt{py::cast<int64_t>(out)}; |
| } |
| |
| c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "sym_storage_offset", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("sym_storage_offset") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| return self->sym_storage_offset_default(); |
| } |
| return torch::is_symint(out) ? out.cast<c10::SymInt>() |
| : c10::SymInt{py::cast<int64_t>(out)}; |
| } |
| |
| c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| HANDLE_TH_ERRORS |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "sym_stride", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("sym_stride") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| return self->sym_strides_default(); |
| } |
| // We need to squeeze SymIntNodes and ints into `SymInts` |
| // since it's a format `sym_strides()` are stored in |
| TORCH_CHECK( |
| py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out), |
| "Symshape must be a list or a tuple"); |
| py::list symints; |
| for (auto it = out.begin(); it != out.end(); it++) { |
| auto elm = *it; |
| auto si = torch::is_symint(elm) ? elm.cast<c10::SymInt>() |
| : c10::SymInt{py::cast<int64_t>(elm)}; |
| symints.append(si.as_int_unchecked()); |
| } |
| |
| auto result = values_from_buffer(self, symints); |
| c10::SymInt* start = (c10::SymInt*)result[0]; |
| int64_t len = result[1]; |
| |
| return c10::SymIntArrayRef(start, len); |
| END_HANDLE_TH_ERRORS_PYBIND |
| } |
| |
// The file-wide interpreter holder. Its constructor creates and registers the
// PyInterpreter; getPyInterpreter() and isMainPyInterpreter() read from it.
PyInterpreterHolder self_interpreter;
| |
| void ConcretePyInterpreterVTable::reset_backward_hooks( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| HANDLE_TH_ERRORS |
| Tensor self_t = Tensor( |
| c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>:: |
| unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self))); |
| auto self_p = |
| py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t))); |
| PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None); |
| END_HANDLE_TH_ERRORS_PYBIND |
| } |
| |
| } // anonymous namespace |
| |
| c10::impl::PyInterpreter* getPyInterpreter() { |
| return self_interpreter.get(); |
| } |
| |
| bool isMainPyInterpreter() { |
| return self_interpreter.is_main_interpreter(); |
| } |
| |
| std::string ConcretePyInterpreterVTable::name() const { |
| std::stringstream ss; |
| ss << getPyInterpreter(); |
| return ss.str(); |
| } |