| #include <ATen/core/PythonFallbackKernel.h> |
| #include <ATen/core/PythonOpRegistrationTrampoline.h> |
| #include <torch/csrc/PyInterpreter.h> |
| #include <torch/csrc/THP.h> |
| #include <torch/csrc/autograd/generated/VariableType.h> |
| #include <torch/csrc/utils/python_arg_parser.h> |
| #include <torch/csrc/utils/python_dispatch.h> |
| |
| #include <string> |
| |
| using namespace torch; |
| using namespace at; |
| using namespace c10; |
| |
| namespace { |
| |
| // NB: This is a macro and not a template function (like it was before) |
| // because passing in constexpr char* as template argument breaks some |
| // versions of MSVC that are being used internally at Meta. |
| // MSVC 14.16.27023 (vs2017_15.9) |
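| // |
| // As an illustration of what the macro below does: |
| // CONCRETE_GPU_TRACE(at::kCUDA, "EventRecordCallbacks", event, stream) |
| // imports torch.cuda, looks up _gpu_trace.EventRecordCallbacks.fire_callbacks, |
| // and calls it with (event, stream); any Python-side failure is logged |
| // rather than rethrown. |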
| #define CONCRETE_GPU_TRACE(device_type, func_name, ...) \ |
| at::impl::MaybeSetTLSOnEntryGuard guard; \ |
| if (Py_IsInitialized()) { \ |
| pybind11::gil_scoped_acquire gil; \ |
| try { \ |
| /* Masquerade HIP as CUDA because HIP uses the `torch.cuda` module. */ \ |
| if (device_type == at::kHIP) { \ |
| device_type = at::kCUDA; \ |
| } \ |
| std::string module_name = "torch." + DeviceTypeName(device_type, true); \ |
| py::module mod = py::module::import(module_name.c_str()); \ |
| py::object hook = \ |
| mod.attr("_gpu_trace").attr(func_name).attr("fire_callbacks"); \ |
| hook(__VA_ARGS__); \ |
| } catch (const std::exception& e) { \ |
| LOG(ERROR) << device_type \ |
| << " trace hook execution failed: " << e.what(); \ |
| } \ |
| } |
| |
| struct ConcretePyInterpreterVTable final |
| : public c10::impl::PyInterpreterVTable { |
| std::string name() const override; |
| |
| void decref(PyObject* pyobj, bool has_pyobj_slot) const override; |
| |
| // TODO: Need to make this work for StorageImpl too. I imagine I'll want to |
| // operate upon a PyObjectSlot rather than a TensorImpl |
| c10::intrusive_ptr<c10::TensorImpl> detach( |
| const c10::TensorImpl* self) const override; |
| |
| void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack) |
| const override; |
| void reportErrorCallback(PyObject* callback, DispatchKey key) const override; |
| void python_dispatcher( |
| const c10::OperatorHandle& op, |
| c10::DispatchKeySet, |
| torch::jit::Stack* stack) const override; |
| // NB: python_op_registration_trampoline_impl is defined in python_dispatch.cpp |
| void python_op_registration_trampoline( |
| const c10::OperatorHandle& op, |
| c10::DispatchKey key, |
| c10::DispatchKeySet keyset, |
| torch::jit::Stack* stack, |
| bool with_keyset) const override { |
| torch::impl::dispatch::python_op_registration_trampoline_impl( |
| op, key, keyset, stack, with_keyset); |
| } |
| void throw_abstract_impl_not_imported_error( |
| std::string opname, |
| const char* pymodule, |
| const char* context) const override { |
| py::gil_scoped_acquire gil; |
| pybind11::module::import("torch._utils_internal") |
| .attr("throw_abstract_impl_not_imported_error")( |
| opname, pymodule, context); |
| } |
| |
| bool is_contiguous(const c10::TensorImpl* self, at::MemoryFormat) |
| const override; |
| bool is_strides_like(const c10::TensorImpl* self, at::MemoryFormat) |
| const override; |
| bool is_non_overlapping_and_dense(const c10::TensorImpl* self) const override; |
| c10::Device device(const c10::TensorImpl* self) const override; |
| int64_t dim(const c10::TensorImpl* self) const override; |
| c10::IntArrayRef strides(const c10::TensorImpl* self) const override; |
| c10::IntArrayRef sizes(const c10::TensorImpl* self) const override; |
| c10::SymIntArrayRef sym_sizes(const c10::TensorImpl* self) const override; |
| c10::Layout layout(const c10::TensorImpl* self) const override; |
| int64_t numel(const c10::TensorImpl* self) const override; |
| c10::SymInt sym_numel(const c10::TensorImpl* self) const override; |
| c10::SymIntArrayRef sym_strides(const c10::TensorImpl* self) const override; |
| c10::SymInt sym_storage_offset(const c10::TensorImpl* self) const override; |
| |
| void trace_gpu_event_creation(at::DeviceType device_type, uintptr_t event) |
| const override { |
| CONCRETE_GPU_TRACE(device_type, "EventCreationCallbacks", event); |
| } |
| void trace_gpu_event_deletion(at::DeviceType device_type, uintptr_t event) |
| const override { |
| CONCRETE_GPU_TRACE(device_type, "EventDeletionCallbacks", event); |
| } |
| void trace_gpu_event_record( |
| at::DeviceType device_type, |
| uintptr_t event, |
| uintptr_t stream) const override { |
| CONCRETE_GPU_TRACE(device_type, "EventRecordCallbacks", event, stream); |
| } |
| void trace_gpu_event_wait( |
| at::DeviceType device_type, |
| uintptr_t event, |
| uintptr_t stream) const override { |
| CONCRETE_GPU_TRACE(device_type, "EventWaitCallbacks", event, stream); |
| } |
| void trace_gpu_memory_allocation(at::DeviceType device_type, uintptr_t ptr) |
| const override { |
| CONCRETE_GPU_TRACE(device_type, "MemoryAllocationCallbacks", ptr); |
| } |
| void trace_gpu_memory_deallocation(at::DeviceType device_type, uintptr_t ptr) |
| const override { |
| CONCRETE_GPU_TRACE(device_type, "MemoryDeallocationCallbacks", ptr); |
| } |
| void trace_gpu_stream_creation(at::DeviceType device_type, uintptr_t stream) |
| const override { |
| CONCRETE_GPU_TRACE(device_type, "StreamCreationCallbacks", stream); |
| } |
| void trace_gpu_device_synchronization( |
| at::DeviceType device_type) const override { |
| CONCRETE_GPU_TRACE(device_type, "DeviceSynchronizationCallbacks"); |
| } |
| void trace_gpu_stream_synchronization( |
| at::DeviceType device_type, |
| uintptr_t stream) const override { |
| CONCRETE_GPU_TRACE(device_type, "StreamSynchronizationCallbacks", stream); |
| } |
| void trace_gpu_event_synchronization( |
| at::DeviceType device_type, |
| uintptr_t event) const override { |
| CONCRETE_GPU_TRACE(device_type, "EventSynchronizationCallbacks", event); |
| } |
| |
| void reset_backward_hooks(const c10::TensorImpl* self) const override; |
| |
| static ConcretePyInterpreterVTable* instance() { |
| static ConcretePyInterpreterVTable s; |
| return &s; |
| } |
| }; |
| |
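| // Holds the c10::impl::PyInterpreter for this Python interpreter and |
| // registers it with at::impl::PythonOpRegistrationTrampoline, which also |
| // reports whether this is the main interpreter. See the destructor note |
| // below for why the PyInterpreter itself is intentionally leaked. |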
| class PyInterpreterHolder { |
| public: |
| PyInterpreterHolder() |
| : impl_(new c10::impl::PyInterpreter( |
| ConcretePyInterpreterVTable::instance())), |
| is_main_interpreter_( |
| at::impl::PythonOpRegistrationTrampoline::registerInterpreter( |
| impl_)) {} |
| // NB: intentionally leaks the PyInterpreter, as there may still be live |
| // references to it in objects that are not destructed while Python is |
| // being cleaned up. |
| ~PyInterpreterHolder() { |
| impl_->disarm(); |
| } |
| c10::impl::PyInterpreter* get() const noexcept { |
| return impl_; |
| } |
| bool is_main_interpreter() const noexcept { |
| return is_main_interpreter_; |
| } |
| |
| private: |
| c10::impl::PyInterpreter* impl_; |
| bool is_main_interpreter_; |
| }; |
| |
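| // Invoke __torch_dispatch__ on the Python object that backs a TensorImpl. |
| // The metadata queries below (detach, sizes, strides, device, layout, ...) |
| // use this to route their calls through a tensor subclass's Python |
| // implementation. extra_args, if provided, are appended after self. |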
| py::object torchDispatchFromTensorImpl( |
| const c10::TensorImpl* self, |
| const char* func_name, |
| PyObject* torch_api_function, |
| const char* module_name, |
| // WARNING: extra_args MUST NOT contain Tensor args |
| c10::SmallVector<py::object, 1> extra_args = {}) { |
| if (torch_api_function == nullptr) { |
| throw python_error(); |
| } |
| TORCH_CHECK( |
| PyGILState_Check(), |
| "GIL must be held before you call torchDispatchFromTensorImpl"); |
| |
| std::vector<PyObject*> overloaded_args; |
| // TODO: there should be a shorter way to spell this |
| // TODO: fix the constness of target |
| at::Tensor self_t = at::Tensor( |
| c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>:: |
| // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
| unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self))); |
| auto self_p = |
| py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t))); |
| // NB: this may not be a python tensor if you got here from a mode! |
| // TORCH_INTERNAL_ASSERT(isPythonTensor(self_t)); |
| append_overloaded_tensor(&overloaded_args, self_p.ptr()); |
| auto args = py::reinterpret_steal<py::object>( |
| PyTuple_New(static_cast<Py_ssize_t>(1 + extra_args.size()))); |
| PyTuple_SET_ITEM(args.ptr(), 0, self_p.release().ptr()); |
| int64_t i = 1; |
| for (auto& a : extra_args) { |
| if (a.ptr() == nullptr) |
| throw python_error(); |
| PyTuple_SET_ITEM(args.ptr(), i, std::move(a).release().ptr()); |
| i++; |
| } |
| |
| py::dict kwargs; |
| |
| return py::reinterpret_steal<py::object>( |
| handle_torch_function_no_python_arg_parser( |
| overloaded_args, |
| args.ptr(), |
| kwargs.ptr(), |
| func_name, |
| torch_api_function, |
| module_name, |
| TorchFunctionName::TorchDispatch)); |
| } |
| |
| // NOTE [PyInterpreter::decref takes a `has_pyobj_slot` arg] |
| // Before calling PyInterpreter::decref, we must statically know if the |
| // pyobj has a PyObjectSlot or not. |
| // - If it has a PyObjectSlot, we need to be careful about PyObject resurrection |
| // - If it does not have a PyObjectSlot, we can freely decref |
| // One alternative to this is using PyObject_IsInstance |
| // to get at this information. However, we don't want to risk an incorrect |
| // `__instancecheck__` changing the semantics here. |
| void ConcretePyInterpreterVTable::decref(PyObject* pyobj, bool has_pyobj_slot) |
| const { |
| // Leak the pyobj if not initialized. This can happen if we are running |
| // exit handlers that are destructing tensors with residual (owned) |
| // PyObjects stored in them. |
| if (!Py_IsInitialized()) |
| return; |
| |
| pybind11::gil_scoped_acquire gil; |
| // Two possibilities: |
| // 1. We are decref-ing an object that has a PyObjectSlot, like a Tensor or |
| // Storage. Then we must be careful about PyObject resurrection (see |
| // THPVariable_clear). |
| // 2. We are decref-ing some other Python object. We don't do |
| // PyObject resurrection on non-Tensors, so we just carry on as usual |
| if (has_pyobj_slot && Py_REFCNT(pyobj) > 1) { |
| if (THPVariable_Check(pyobj)) { |
| // It's still alive! This can happen if a weak ref resurrected |
| // the PyObject without flipping ownership. At this point it is |
| // too late to rescue the object, so just stub out the PyObject |
| // so that it fails on subsequent uses. Don't raise an error here; |
| // you're probably in a destructor. |
| TORCH_WARN( |
| "Deallocating Tensor that still has live PyObject references. " |
| "This probably happened because you took out a weak reference to " |
| "Tensor and didn't call _fix_weakref() after dereferencing it. " |
| "Subsequent accesses to this tensor via the PyObject will now fail."); |
| ((THPVariable*)pyobj)->cdata = |
| c10::MaybeOwned<torch::autograd::Variable>(); |
| } else if (THPStorage_Check(pyobj)) { |
| TORCH_WARN( |
| "Deallocating UntypedStorage that still has live PyObject references. " |
| "This probably happened because you took out a weak reference to " |
| "UntypedStorage and didn't call _fix_weakref() after dereferencing it. " |
| "Subsequent accesses to this storage via the PyObject will now fail."); |
| ((THPStorage*)pyobj)->cdata = c10::MaybeOwned<c10::Storage>(); |
| } |
| } |
| Py_DECREF(pyobj); |
| } |
| |
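| // Resolve the torch.ops.<ns>.<name>.<overload> OpOverload for an operator, |
| // caching the result on the OperatorHandle via getPythonOp. |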
| py::handle getTorchApiFunction(const c10::OperatorHandle& op) { |
| return op.getPythonOp(getPyInterpreter(), [&]() -> PyObject* { |
| // Parse the name into namespace and name (no overload_name) |
| // TODO: put this into the library |
| const auto& schema = op.schema(); |
| const auto& qualified_name = op.operator_name().name; |
| const auto& overload_name = schema.overload_name(); |
| auto pos = qualified_name.find("::"); |
| TORCH_INTERNAL_ASSERT(pos != std::string::npos, qualified_name); |
| // Make me some null terminated strings |
| std::string ns_str = qualified_name.substr(0, pos); |
| const char* ns = ns_str.c_str(); |
| const char* func_name = qualified_name.c_str() + pos + strlen("::"); |
| |
| py::handle torch_api_function = |
| py::module::import("torch").attr("ops").attr(ns).attr(func_name); |
| if (overload_name.empty()) { |
| return torch_api_function.attr("default").ptr(); |
| } else { |
| return torch_api_function.attr(overload_name.c_str()).ptr(); |
| } |
| }); |
| } |
| |
| bool isPythonTensor(const at::Tensor& tensor) { |
| return tensor.unsafeGetTensorImpl()->key_set().has(c10::DispatchKey::Python); |
| } |
| |
| void ConcretePyInterpreterVTable::reportErrorCallback( |
| PyObject* callback, |
| DispatchKey key) const { |
| py::gil_scoped_acquire g; |
| auto func = py::reinterpret_borrow<py::object>(callback); |
| // Not all DispatchKeys are pybind'ed into Python and we do not have infra |
| // to ensure this, so just pass a string back to Python. |
| func(c10::toString(key)); |
| } |
| |
| void ConcretePyInterpreterVTable::dispatch( |
| const c10::OperatorHandle& op, |
| torch::jit::Stack* stack) const { |
| const auto& schema = op.schema(); |
| const auto num_arguments = schema.arguments().size(); |
| auto arguments = torch::jit::pop(*stack, num_arguments); |
| |
| // The plan: convert all the arguments back into PyObjects, |
| // extracting out the tensor handles, then call |
| // handle_torch_function_no_python_arg_parser |
| // NB: at the point arguments are pushed to the stack, ALL defaults |
| // are already present |
| |
| py::gil_scoped_acquire g; |
| |
| std::vector<PyObject*> overloaded_args; |
| py::handle torch_api_function_overload = getTorchApiFunction(op); |
| |
| // Find overloaded tensors |
| for (const auto idx : c10::irange(arguments.size())) { |
| const auto& ivalue = arguments[idx]; |
| if (ivalue.isTensor()) { |
| const auto& tensor = ivalue.toTensor(); |
| if (isPythonTensor(tensor)) { |
| append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr()); |
| } |
| } else if (ivalue.isList()) { |
| const auto& list = ivalue.toListRef(); |
| for (const auto jdx : c10::irange(list.size())) { |
| const auto& nv = list[jdx]; |
| if (nv.isTensor()) { |
| const auto& tensor = nv.toTensor(); |
| if (isPythonTensor(tensor)) { |
| append_overloaded_tensor(&overloaded_args, py::cast(tensor).ptr()); |
| } |
| } |
| } |
| } |
| } |
| |
| auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); |
| auto args = std::move(args_kwargs.first); |
| auto kwargs = std::move(args_kwargs.second); |
| |
| PyObject* obj = handle_torch_function_no_python_arg_parser( |
| overloaded_args, |
| args.ptr(), |
| kwargs.ptr(), |
| nullptr, |
| torch_api_function_overload.ptr(), |
| nullptr, |
| TorchFunctionName::TorchDispatch); |
| pushPyOutToStack( |
| op, stack, py::reinterpret_steal<py::object>(obj), "__torch_dispatch__"); |
| } |
| |
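| // Python dispatcher entry point: consult the _dispatch_cache stored on the |
| // OpOverload, keyed by the highest-priority dispatch key. If the cached |
| // entry is itself a DispatchKey, call the C++ kernel registered for that |
| // key; otherwise the entry is a Python callable, which is invoked with the |
| // operator's arguments and whose result is pushed back onto the stack. |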
| void ConcretePyInterpreterVTable::python_dispatcher( |
| const c10::OperatorHandle& op, |
| c10::DispatchKeySet ks, |
| torch::jit::Stack* stack) const { |
| py::gil_scoped_acquire g; |
| py::handle torch_api_function_overload = getTorchApiFunction(op); |
| // TODO: if necessary, can optimize to cache the cache lookup |
| // TODO: if necessary, can optimize OpOverload to have slots |
| auto cache = py::dict(torch_api_function_overload.attr("_dispatch_cache")); |
| if (cache.ptr() == nullptr) { |
| throw python_error(); |
| } |
| |
| c10::DispatchKey k = ks.highestPriorityTypeId(); |
| // TODO: allow this to be non-owning |
| auto handler = py::reinterpret_borrow<py::object>( |
| PyDict_GetItem(cache.ptr(), py::cast(k).ptr())); |
| if (handler.ptr() == nullptr) { |
| // Slow path |
| handler = torch_api_function_overload.attr("_get_dispatch")(k); |
| } |
| if (py::isinstance<c10::DispatchKey>(handler)) { |
| // NB: not redispatch, as that will permanently remove the python |
| // dispatcher for subsequent redispatches |
| op.callBoxedForDispatchKey(py::cast<c10::DispatchKey>(handler), *stack); |
| return; |
| } |
| |
| const auto& schema = op.schema(); |
| const auto num_arguments = schema.arguments().size(); |
| auto arguments = torch::jit::pop(*stack, num_arguments); |
| |
| auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments); |
| auto args = std::move(args_kwargs.first); |
| auto kwargs = std::move(args_kwargs.second); |
| |
| py::object obj = py::reinterpret_steal<py::object>( |
| PyObject_Call(handler.ptr(), args.ptr(), kwargs.ptr())); |
| |
| if (obj.ptr() == nullptr) { |
| throw python_error(); |
| } |
| |
| pushPyOutToStack(op, stack, std::move(obj), "Python dispatcher"); |
| } |
| |
| c10::intrusive_ptr<c10::TensorImpl> ConcretePyInterpreterVTable::detach( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "detach", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("detach") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| TORCH_CHECK( |
| THPVariable_Check(out.ptr()), |
| "detach returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected Tensor"); |
| const at::Tensor& res_t = THPVariable_Unpack(out.ptr()); |
| return res_t.getIntrusivePtr(); |
| } |
| |
| bool ConcretePyInterpreterVTable::is_contiguous( |
| const c10::TensorImpl* self, |
| at::MemoryFormat memory_format) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| py::object out; |
| if (memory_format == at::MemoryFormat::Contiguous) { |
| // For backwards compatibility |
| out = torchDispatchFromTensorImpl( |
| self, |
| "is_contiguous", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("is_contiguous") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| } else { |
| out = torchDispatchFromTensorImpl( |
| self, |
| "is_contiguous", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("is_contiguous") |
| .attr("memory_format") |
| .ptr(), |
| "torch.ops.aten", |
| {py::cast(memory_format)}); |
| } |
| |
| if (out.is_none()) { |
| return self->is_contiguous_default(memory_format); |
| } |
| |
| TORCH_CHECK( |
| PyBool_Check(out.ptr()), |
| "is_contiguous returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected bool"); |
| |
| return PyObject_IsTrue(out.ptr()); |
| } |
| |
| bool ConcretePyInterpreterVTable::is_strides_like( |
| const c10::TensorImpl* self, |
| at::MemoryFormat memory_format) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "is_strides_like", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| // NB: intentionally suffixed with _format to avoid |
| // triggering matches against "_like" suffix |
| .attr("is_strides_like_format") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten", |
| {py::cast(memory_format)}); |
| |
| if (out.is_none()) { |
| return self->is_strides_like_default(memory_format); |
| } |
| |
| TORCH_CHECK( |
| PyBool_Check(out.ptr()), |
| "is_strides_like_format returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected bool"); |
| |
| return PyObject_IsTrue(out.ptr()); |
| } |
| |
| bool ConcretePyInterpreterVTable::is_non_overlapping_and_dense( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "is_non_overlapping_and_dense", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("is_non_overlapping_and_dense") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| return self->is_non_overlapping_and_dense_default(); |
| } |
| |
| TORCH_CHECK( |
| PyBool_Check(out.ptr()), |
| "is_non_overlapping_and_dense returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected bool"); |
| |
| return PyObject_IsTrue(out.ptr()); |
| } |
| |
| int64_t ConcretePyInterpreterVTable::dim(const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "dim", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("dim") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| TORCH_CHECK( |
| PyLong_Check(out.ptr()), |
| "dim returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected int"); |
| |
| return THPUtils_unpackLong(out.ptr()); |
| } |
| |
| c10::Device ConcretePyInterpreterVTable::device( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "device", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("prim") |
| .attr("device") |
| .attr("default") |
| .ptr(), |
| "torch.ops.prim"); |
| |
| return toDevice(out.ptr()); |
| } |
| |
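| // Store a py::capsule as an attribute on the tensor's PyObject so that the |
| // heap buffer owned by the capsule lives exactly as long as the Python |
| // tensor object (see Note [Tensor Subclass custom size/stride caching |
| // strategy] below). |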
| static void set_tensor_attr_with_capsule( |
| const c10::TensorImpl* tensor, |
| py::capsule& capsule, |
| const char* attr_name) { |
| std::optional<PyObject*> mb_obj = tensor->pyobj_slot()->check_pyobj( |
| getPyInterpreter(), /*ignore_hermetic_tls=*/false); |
| TORCH_CHECK( |
| mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); |
| auto obj = mb_obj.value(); |
| py::handle(obj).attr(attr_name) = capsule; |
| } |
| |
| // Note [Tensor Subclass custom size/stride caching strategy] |
| // Tensor subclasses can use __torch_dispatch__ to override size/stride calls. |
| // However, this presents a problem: |
| // (1) When you return a custom (maybe symbolic) size/stride |
| // from python, we need to stash this fresh vector of ints/symints |
| // somewhere so that it has the same lifetime as the tensor. |
| // (2) If the subclass experiences a metadata mutation, |
| // this stashed vector is no longer valid, so we need to allocate a fresh |
| // buffer to store the new sizes the next time someone asks for them. |
| // |
| // We handle this in the same way that `TensorImpl::sizes_default()` |
| // handles its buffer: we simply reallocate the buffer whenever |
| // the number of dimensions changes due to a resize. |
| // Notably, we do *not* reallocate the buffer if the values changed, |
| // but the number of dimensions stayed the same (e.g. `.transpose_()`). |
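| // |
| // As a concrete illustration of the scheme implemented below: a 3-d |
| // subclass tensor caches its sizes in an int64_t buffer of length 5 (the |
| // smallvector floor), stored on the PyObject as "_sizes_capsule" with |
| // "_sizes_capsule_len" == 3. A later `.transpose_()` only rewrites values |
| // in place; reallocation happens only when the buffer is missing or the |
| // dim count no longer fits the existing buffer. |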
| template <typename T> |
| static c10::ArrayRef<T> get_set_cached_attr( |
| const c10::TensorImpl* tensor, |
| const char* base_attr_name, |
| const py::object& obj) { |
| std::optional<PyObject*> mb_obj = |
| tensor->pyobj_slot()->check_pyobj(getPyInterpreter()); |
| TORCH_CHECK( |
| mb_obj.has_value(), "Tensor subclass's PyInterpreter has no value"); |
| auto tensor_obj = mb_obj.value(); |
| auto buffer_len_attr_name = std::string(base_attr_name) + std::string("_len"); |
| |
| bool is_buffer_allocated = false; |
| size_t curr_size = 0; |
| if (PyObject_HasAttrString(tensor_obj, buffer_len_attr_name.c_str())) { |
| auto len_pyobj = py::handle(tensor_obj).attr(buffer_len_attr_name.c_str()); |
| curr_size = py::cast<size_t>(len_pyobj); |
| is_buffer_allocated = true; |
| } |
| |
| size_t new_size = py::len(obj); |
| |
| // We do the smallvector optimization here: any time the new_size is <=5, |
| // we always allocate our buffer to size 5, so that if the next resize |
| // is also to <=5 elements, we don't need to reallocate. |
| // Note: I tried removing this optimization and tripped ASAN |
| // in a batchnorm kernel here: |
| // https://pipelinesghubeus21.actions.githubusercontent.com/mBh68xKhi8LyM7tp3vECvYXNFvuV4gyVGgmYCteuEZP9JH92QN/_apis/pipelines/1/runs/3373307/signedlogcontent/790?urlExpires=2023-09-15T21%3A13%3A51.4327798Z&urlSigningMethod=HMACV1&urlSignature=tDeX7ZqaARVU5NNwyr5yYqqkWq3A2j4z8FFdqYwGr0Q%3D |
| // We should fix this instead. |
| bool needs_resize = false; |
| // We need to resize if: |
| // (1) we haven't allocated our buffer at all yet |
| // (2) Our buffer size is different from the new size |
| // (note: we use the small vector optimization, where our buffer |
| // is always allocated to at least size 5, and any resizes |
| // within the <= 5 regime do not require a reallocation). |
| auto is_smallvector = curr_size <= 5; |
| needs_resize = !is_buffer_allocated || (is_smallvector && new_size > 5) || |
| (!is_smallvector && curr_size != new_size); |
| if (needs_resize) { |
| // If our current buffer is not the right size (either because we haven't |
| // allocated it yet, or there was a metadata mutation that changed the |
| // number of dims of the tensor), allocate a fresh buffer. Note that this |
| // will trash the previous buffer if there already was one, invalidating any |
| // existing SymIntArrayRef's from an old .sym_size() call. |
| auto new_buffer_size = new_size; |
| if (new_size <= 5) { |
| // This is the smallvector optimization |
| new_buffer_size = 5; |
| } |
| T* ptr = new T[new_buffer_size]; |
| auto capsule = |
| py::capsule(ptr, [](void* p) { delete[] reinterpret_cast<T*>(p); }); |
| int64_t idx = 0; |
| for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) { |
| ptr[idx] = py::cast<T>(*it); |
| } |
| // Set the buffer |
| set_tensor_attr_with_capsule(tensor, capsule, base_attr_name); |
| // Set the len buffer |
| py::handle(tensor_obj).attr(buffer_len_attr_name.c_str()) = new_size; |
| } else { |
| TORCH_INTERNAL_ASSERT(PyObject_HasAttrString(tensor_obj, base_attr_name)); |
| auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name); |
| void* buffer_pycapsule = |
| PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr); |
| auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule); |
| |
| // Overwrite the buffer with our new values, but only if any of them changed |
| // (due to a metadata mutation). |
| // This is technically not thread safe, because the update happens lazily. |
| // The original metadata mutation call on the tensor might have been thread |
| // safe (e.g. a .resize_() call), but we won't actually mutate the size |
| // buffer until the first call to .sizes() which the user might not access |
| // in a thread-safe way. For now we are not explicitly locking, but maybe we |
| // should. |
| int64_t idx = 0; |
| // Quick sanity assert that our buffer size is large enough |
| // to compare against all the elements in the new buffer. |
| size_t curr_buffer_size = 5; |
| if (curr_buffer_size < curr_size) { |
| curr_buffer_size = curr_size; |
| } |
| TORCH_INTERNAL_ASSERT(curr_buffer_size >= new_size); |
| for (auto it = obj.begin(); it != obj.end(); ++it, ++idx) { |
| auto actual_val = py::cast<T>(*it); |
| if constexpr (std::is_same_v<T, c10::SymInt>) { |
| // If our SymInts are symbolic, we do *not* do an equality check on the |
| // SymInts; we just check whether the nodes are the same, because we |
| // don't want to introduce any guards here. |
| if (!curr_buffer[idx].is_same(actual_val)) { |
| curr_buffer[idx] = actual_val; |
| } |
| } else { |
| if (curr_buffer[idx] != actual_val) { |
| curr_buffer[idx] = actual_val; |
| } |
| } |
| } |
| } |
| |
| // The correct data is now stored at the buffer - read and return it. |
| auto curr_buffer_pyobj = py::handle(tensor_obj).attr(base_attr_name); |
| void* buffer_pycapsule = |
| PyCapsule_GetPointer(curr_buffer_pyobj.ptr(), nullptr); |
| auto curr_buffer = reinterpret_cast<T*>(buffer_pycapsule); |
| return c10::ArrayRef<T>(curr_buffer, new_size); |
| } |
| |
| c10::IntArrayRef ConcretePyInterpreterVTable::strides( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "stride", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("stride") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| TORCH_CHECK( |
| !self->has_symbolic_sizes_strides(), |
| "Cannot call strides on a tensor with symbolic shapes/strides"); |
| return self->strides_default(); |
| } |
| TORCH_CHECK( |
| py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out), |
| "strides must be a list or a tuple"); |
| auto updated_strides = |
| get_set_cached_attr<int64_t>(self, "_strides_capsule", out); |
| return updated_strides; |
| } |
| |
| c10::IntArrayRef ConcretePyInterpreterVTable::sizes( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| HANDLE_TH_ERRORS |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "size", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("size") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| if (out.is_none()) { |
| TORCH_CHECK( |
| !self->has_symbolic_sizes_strides(), |
| "Cannot call sizes on a tensor with symbolic shapes/strides"); |
| return self->sizes_default(); |
| } |
| TORCH_CHECK( |
| py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out), |
| "sizes must be a list or a tuple"); |
| |
| auto updated_sizes = |
| get_set_cached_attr<int64_t>(self, "_sizes_capsule", out); |
| return updated_sizes; |
| END_HANDLE_TH_ERRORS_PYBIND |
| } |
| |
| c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_sizes( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| HANDLE_TH_ERRORS |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "sym_size", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("sym_size") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| return self->sym_sizes_default(); |
| } |
| TORCH_CHECK( |
| py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out), |
| "sym_size must be a list or a tuple"); |
| |
| // See Note [Tensor Subclass custom size/stride caching strategy] |
| auto updated_sym_sizes = |
| get_set_cached_attr<c10::SymInt>(self, "_sym_sizes_capsule", out); |
| return updated_sym_sizes; |
| END_HANDLE_TH_ERRORS_PYBIND |
| } |
| |
| c10::Layout ConcretePyInterpreterVTable::layout( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "layout", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("prim") |
| .attr("layout") |
| .attr("default") |
| .ptr(), |
| "torch.ops.prim"); |
| |
| TORCH_CHECK( |
| THPLayout_Check(out.ptr()) || PyLong_Check(out.ptr()), |
| "layout returned invalid type ", |
| py::detail::get_fully_qualified_tp_name(Py_TYPE(out.ptr())), |
| ", expected Layout"); |
| |
| if (THPLayout_Check(out.ptr())) { |
| return toLayout(out.ptr()); |
| } else { |
| return c10::Layout(py::cast<int64_t>(out)); |
| } |
| } |
| |
| int64_t ConcretePyInterpreterVTable::numel(const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "numel", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("numel") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| TORCH_CHECK( |
| !self->has_symbolic_sizes_strides(), |
| "Cannot call numel on a tensor with symbolic shapes/strides"); |
| return self->numel_default(); |
| } |
| return py::cast<int64_t>(out); |
| } |
| |
| c10::SymInt ConcretePyInterpreterVTable::sym_numel( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "sym_numel", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("sym_numel") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| return self->sym_numel_default(); |
| } |
| return torch::is_symint(out) ? out.cast<c10::SymInt>() |
| : c10::SymInt{py::cast<int64_t>(out)}; |
| } |
| |
| c10::SymInt ConcretePyInterpreterVTable::sym_storage_offset( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "sym_storage_offset", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("sym_storage_offset") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| return self->sym_storage_offset_default(); |
| } |
| return torch::is_symint(out) ? out.cast<c10::SymInt>() |
| : c10::SymInt{py::cast<int64_t>(out)}; |
| } |
| |
| c10::SymIntArrayRef ConcretePyInterpreterVTable::sym_strides( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| HANDLE_TH_ERRORS |
| auto out = torchDispatchFromTensorImpl( |
| self, |
| "sym_stride", |
| py::module::import("torch") |
| .attr("ops") |
| .attr("aten") |
| .attr("sym_stride") |
| .attr("default") |
| .ptr(), |
| "torch.ops.aten"); |
| |
| if (out.is_none()) { |
| return self->sym_strides_default(); |
| } |
| // We need to squeeze SymIntNodes and ints into `SymInt`s, |
| // since that is the format `sym_strides()` is stored in. |
| TORCH_CHECK( |
| py::isinstance<py::tuple>(out) || py::isinstance<py::list>(out), |
| "sym_strides must be a list or a tuple"); |
| |
| auto updated_sym_strides = |
| get_set_cached_attr<c10::SymInt>(self, "_sym_strides_capsule", out); |
| return updated_sym_strides; |
| END_HANDLE_TH_ERRORS_PYBIND |
| } |
| |
| void ConcretePyInterpreterVTable::reset_backward_hooks( |
| const c10::TensorImpl* self) const { |
| pybind11::gil_scoped_acquire gil; |
| at::impl::MaybeSetTLSOnEntryGuard guard; |
| HANDLE_TH_ERRORS |
| Tensor self_t = |
| Tensor(c10::intrusive_ptr<c10::TensorImpl, c10::UndefinedTensorImpl>:: |
| // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) |
| unsafe_reclaim_from_nonowning(const_cast<c10::TensorImpl*>(self))); |
| auto self_p = |
| py::reinterpret_steal<py::object>(THPVariable_Wrap(std::move(self_t))); |
| PyObject_SetAttrString(self_p.ptr(), "_backward_hooks", Py_None); |
| END_HANDLE_TH_ERRORS_PYBIND |
| } |
| |
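| // Constructed during static initialization; exposed to the rest of the |
| // codebase through getPyInterpreter() and isMainPyInterpreter() below. |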
| PyInterpreterHolder self_interpreter; |
| |
| } // anonymous namespace |
| |
| c10::impl::PyInterpreter* getPyInterpreter() { |
| return self_interpreter.get(); |
| } |
| |
| bool isMainPyInterpreter() { |
| return self_interpreter.is_main_interpreter(); |
| } |
| |
| std::string ConcretePyInterpreterVTable::name() const { |
| std::stringstream ss; |
| ss << getPyInterpreter(); |
| return ss.str(); |
| } |