#define PY_SSIZE_T_CLEAN
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/extension.h>
#include <sstream>

namespace {

struct LocalState {
  // Thread-local state that changes how operators dispatch
  c10::impl::LocalDispatchKeySet dispatch_modifier;
  bool grad_mode_enabled;

  at::DispatchKeySet apply(at::DispatchKeySet ks) const {
    return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
  }

  LocalState()
      : dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
        grad_mode_enabled(at::GradMode::is_enabled()) {}
};
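
// Illustrative sketch only (nothing in this file runs it): apply() folds the
// thread-local include/exclude sets into a tensor's key set, so two threads
// with different TLS can observe different effective key sets for the same
// tensor. Assuming some tensor `t`:
//
//   LocalState state;
//   // equals (t.key_set() | tls.included_) - tls.excluded_ for this thread:
//   at::DispatchKeySet effective = state.apply(t.key_set());
//   uint64_t raw = effective.raw_repr(); // what TensorCheck stores below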

class TensorCheck {
 public:
  TensorCheck(
      const LocalState& state,
      PyTypeObject* pt,
      const at::Tensor& v,
      std::vector<std::optional<int64_t>> dynamic_dims_sizes,
      std::vector<std::optional<int64_t>> dynamic_dims_strides)
      : pytype(pt),
        dispatch_key_(state.apply(v.key_set()).raw_repr()),
        dtype_(v.dtype().toScalarType()),
        device_index_(v.device().index()),
        requires_grad_(state.grad_mode_enabled && v.requires_grad()),
        sizes_(std::move(dynamic_dims_sizes)),
        strides_(std::move(dynamic_dims_strides)) {
    // TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
    // we just treat this as optional?
    dim_ = sizes_.size();
  }

  // See note in guards.py [Note - On Export Tensor Guards]
  // Logic parallel to here must be maintained in Python
  bool check(const LocalState& state, const at::Tensor& v) {
    if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
        dtype_ != v.dtype().toScalarType() ||
        device_index_ != v.device().index() ||
        requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) {
      return false;
    }
    auto ndim = v.ndimension();
    if (ndim != dim_) {
      return false;
    }
    const auto& sizes = v.sizes();
    const auto& strides = v.strides();
    for (auto i : c10::irange(ndim)) {
      auto known_size = sizes_[i];
      auto known_stride = strides_[i];
      if (known_size.has_value()) {
        if (known_size.value() != sizes[i]) {
          return false;
        }
      }
      if (known_stride.has_value()) {
        if (known_stride.value() != strides[i]) {
          return false;
        }
      }
    }
    return true;
  }
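
  // For intuition (a sketch, not exercised here): with static shapes, a guard
  // built from a float32 CUDA tensor x of shape (2, 3) fails check() after,
  // e.g., x.to(torch.float64) (dtype), x.cpu() (dispatch keys and device
  // index), x.as_strided(x.size(), (1, 2)) (strides), or x.unsqueeze(0)
  // (rank). Only dimensions whose recorded size/stride is nullopt, i.e.
  // marked dynamic, are skipped.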

  std::string check_verbose(
      const LocalState& state,
      const at::Tensor& v,
      std::string tensor_name) {
    std::stringstream fail_reason;
    fail_reason << "tensor '" << tensor_name << "' ";
    if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
      fail_reason << "dispatch key set mismatch. expected "
                  << c10::DispatchKeySet(
                         c10::DispatchKeySet::RAW, dispatch_key_)
                  << ", actual " << state.apply(v.key_set());
      return fail_reason.str();
    } else if (dtype_ != v.dtype().toScalarType()) {
      fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
                  << v.dtype().toScalarType();
      return fail_reason.str();
    } else if (device_index_ != v.device().index()) {
      fail_reason << "device index mismatch. expected " << device_index_
                  << ", actual " << v.device().index();
      return fail_reason.str();
    } else if (
        requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) {
      fail_reason << "requires_grad mismatch. expected requires_grad="
                  << requires_grad_;
      return fail_reason.str();
    }
    auto ndim = v.ndimension();
    if (ndim != dim_) {
      fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
                  << ndim;
      return fail_reason.str();
    }
    const auto& sizes = v.sizes();
    const auto& strides = v.strides();
    for (auto i : c10::irange(ndim)) {
      auto known_size = sizes_[i];
      auto known_stride = strides_[i];
      if (known_size.has_value() && (known_size.value() != sizes[i])) {
        fail_reason << "size mismatch at index " << i << ". expected "
                    << known_size.value() << ", actual " << sizes[i];
        return fail_reason.str();
      }
      if (known_stride.has_value() && known_stride.value() != strides[i]) {
        fail_reason << "stride mismatch at index " << i << ". expected "
                    << known_stride.value() << ", actual " << strides[i];
        return fail_reason.str();
      }
    }
    return "";
  }
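
  // Example of the verbose output (assuming a guard built from a float32
  // tensor named "x" that is later passed in as float64): check_verbose
  // returns a string like
  //   tensor 'x' dtype mismatch. expected Float, actual Double
  // and the empty string means every recorded property still matches.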

  PyTypeObject* pytype;

 private:
  uint64_t dispatch_key_; // DispatchKeySet includes device/layout
  at::ScalarType dtype_;
  // Note(voz): While dispatch_key_ is largely representative of a device
  // (keys are granular and device-specific), it does not reliably capture
  // the device index, so we store that separately.
  at::DeviceIndex device_index_;
  bool requires_grad_;
  // NB: These are unset if dynamic shapes is enabled.
  std::vector<std::optional<int64_t>> sizes_;
  std::vector<std::optional<int64_t>> strides_;
  // Not strictly required for dense tensors, but nested tensors need it.
  int64_t dim_;
};

typedef std::vector<TensorCheck> ChecksList;

typedef struct {
  PyObject_HEAD;
  ChecksList* checks;
} TensorGuards;
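
// A sketch of how this type is driven from Python (the real call sites live
// in torch/_dynamo, not shown here; assuming t1 is float32):
//
//   from torch._C._dynamo.guards import TensorGuards
//   guards = TensorGuards(
//       t1, t2, dynamic_dims_sizes=None, dynamic_dims_strides=None)
//   guards.check(t1, t2)            # True while properties still match
//   guards.check(t1.double(), t2)   # False (dtype changed)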

static void TensorGuards_dealloc(TensorGuards* self) {
  if (self->checks != NULL) {
    delete self->checks;
    self->checks = NULL;
  }
  Py_TYPE(self)->tp_free((PyObject*)self);
}

static PyObject* TensorGuards_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwds) {
  TensorGuards* self = (TensorGuards*)type->tp_alloc(type, 0);
  if (self != NULL) {
    self->checks = new ChecksList();
  }
  return (PyObject*)self;
}

static std::vector<std::optional<int64_t>> wrapIntegersInOptional(
    const c10::IntArrayRef& intArray) {
  std::vector<std::optional<int64_t>> optVec(intArray.size());
  std::transform(
      intArray.begin(), intArray.end(), optVec.begin(), [](int64_t value) {
        return std::make_optional(value);
      });
  return optVec;
}
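
// E.g. a tensor with sizes (2, 3) yields {optional(2), optional(3)}: every
// dimension is treated as static when no dynamic-dim info was provided.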

static std::vector<std::optional<int64_t>> pyListToVecOptInt(PyObject* pyList) {
  std::vector<std::optional<int64_t>> vec;
  Py_ssize_t size = PyList_Size(pyList);
  for (Py_ssize_t i = 0; i < size; i++) {
    PyObject* item = PyList_GetItem(pyList, i);
    if (item == Py_None) {
      vec.push_back(std::nullopt);
    } else {
      int64_t value = PyLong_AsLongLong(item);
      if (value == -1 && PyErr_Occurred()) {
        PyErr_SetString(
            PyExc_TypeError,
            "Size or stride list item is not a valid integer.");
        TORCH_CHECK(false, "Size or stride list item is not a valid integer.");
      }
      vec.push_back(value);
    }
  }
  return vec;
}
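
// E.g. the Python list [2, None, 4] becomes {optional(2), nullopt,
// optional(4)}, where None marks a dimension whose size/stride is dynamic
// and therefore skipped by TensorCheck::check above.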

static std::vector<std::vector<std::optional<int64_t>>> get_dynamic_dims(
    PyObject* dynamic_dims_py) {
  std::vector<std::vector<std::optional<int64_t>>> per_tensor_dynamic_dims;
  if (dynamic_dims_py != Py_None) {
    Py_ssize_t size = PyList_Size(dynamic_dims_py);
    for (Py_ssize_t i = 0; i < size; i++) {
      PyObject* py_list = PyList_GetItem(dynamic_dims_py, i);
      std::vector<std::optional<int64_t>> vec = pyListToVecOptInt(py_list);
      per_tensor_dynamic_dims.push_back(std::move(vec));
    }
  }
  return per_tensor_dynamic_dims;
}

static int TensorGuards_init(
    TensorGuards* self,
    PyObject* args,
    PyObject* kwds) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return -1;
  }
  // Top level structure is List[List[Union[int, None]]]
  PyObject* dynamic_dims_sizes_py =
      kwds ? PyDict_GetItemString(kwds, "dynamic_dims_sizes") : NULL;
  if (dynamic_dims_sizes_py == NULL) {
    PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_sizes=...");
    return -1;
  }
  PyObject* dynamic_dims_strides_py =
      PyDict_GetItemString(kwds, "dynamic_dims_strides");
  if (dynamic_dims_strides_py == NULL) {
    PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_strides=...");
    return -1;
  }

  // dynamic_dims_sizes_py/strides_py are None when dynamic_shapes=False - this
  // is an optimization to avoid invoking .size()/.stride() in Python needlessly
  std::vector<std::vector<std::optional<int64_t>>>
      per_tensor_dynamic_dims_sizes = get_dynamic_dims(dynamic_dims_sizes_py);
  std::vector<std::vector<std::optional<int64_t>>>
      per_tensor_dynamic_dims_strides =
          get_dynamic_dims(dynamic_dims_strides_py);

  auto& checks = *self->checks;
  auto len = PyTuple_GET_SIZE(args);
  checks.reserve(len);
  LocalState state;
  for (auto i : c10::irange(len)) {
    PyObject* item = PyTuple_GET_ITEM(args, i);
    if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
      PyErr_SetString(PyExc_TypeError, "expected Tensor()");
      return -1;
    }
    auto tensor = THPVariable_Unpack(item);
    std::vector<std::optional<int64_t>> tensor_dims_size =
        per_tensor_dynamic_dims_sizes.empty()
        ? wrapIntegersInOptional(tensor.sizes())
        : per_tensor_dynamic_dims_sizes[i];
    std::vector<std::optional<int64_t>> tensor_dims_stride =
        per_tensor_dynamic_dims_strides.empty()
        ? wrapIntegersInOptional(tensor.strides())
        : per_tensor_dynamic_dims_strides[i];
    checks.emplace_back(
        state,
        Py_TYPE(item),
        std::move(tensor),
        std::move(tensor_dims_size),
        std::move(tensor_dims_stride));
  }
  return 0;
}
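
// For illustration, the kwargs for two tensors of shapes (?, 3) and (4, 5),
// where only dim 0 of the first is dynamic, would look like (a sketch of the
// expected layout, not a real call site):
//
//   TensorGuards(
//       t1, t2,
//       dynamic_dims_sizes=[[None, 3], [4, 5]],
//       dynamic_dims_strides=[[None, 1], [5, 1]])
//
// With dynamic_shapes=False both kwargs are simply None and the concrete
// sizes/strides are captured from the tensors themselves.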

PyObject* TensorGuards_check(TensorGuards* self, PyObject* args) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return NULL;
  }
  auto& checks = *self->checks;
  auto len = PyTuple_GET_SIZE(args);

  if (static_cast<decltype(len)>(checks.size()) != len) {
    PyErr_SetString(PyExc_TypeError, "wrong length");
    return NULL;
  }

  LocalState state;

  for (auto i : c10::irange(len)) {
    PyObject* item = PyTuple_GET_ITEM(args, i);
    if (Py_TYPE(item) != checks[i].pytype) {
      Py_RETURN_FALSE;
    }
    if (!checks[i].check(state, THPVariable_Unpack(item))) {
      Py_RETURN_FALSE;
    }
  }

  Py_RETURN_TRUE;
}

PyObject* TensorGuards_check_verbose(
    TensorGuards* self,
    PyObject* args,
    PyObject* kwargs) {
  if (!PyTuple_CheckExact(args)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return NULL;
  }
  auto& checks = *self->checks;
  auto len = PyTuple_GET_SIZE(args);

  if (static_cast<decltype(len)>(checks.size()) != len) {
    PyErr_SetString(PyExc_TypeError, "wrong length");
    return NULL;
  }

  PyObject* tensor_check_names_py =
      kwargs ? PyDict_GetItemString(kwargs, "tensor_check_names") : NULL;
  if (tensor_check_names_py == NULL) {
    PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg");
    return NULL;
  }

  if (!PyList_Check(tensor_check_names_py)) {
    PyErr_SetString(PyExc_TypeError, "tensor_check_names kwarg must be a list");
    return NULL;
  }

  auto names_size = PyList_Size(tensor_check_names_py);
  if (names_size != static_cast<decltype(names_size)>(checks.size())) {
    PyErr_SetString(
        PyExc_TypeError,
        "tensor_check_names should be the same size as # tensors");
    return NULL;
  }

  std::vector<std::string> tensor_check_names;
  tensor_check_names.reserve(names_size);
  for (auto i : c10::irange(names_size)) {
    PyObject* value = PyList_GetItem(tensor_check_names_py, i);
    if (!PyUnicode_Check(value)) {
      PyErr_SetString(
          PyExc_TypeError, "tensor_check_names must only contain strings");
      return NULL;
    }
    tensor_check_names.emplace_back(PyUnicode_AsUTF8(value));
  }

  LocalState state;
  for (auto i : c10::irange(len)) {
    PyObject* item = PyTuple_GET_ITEM(args, i);
    if (Py_TYPE(item) != checks[i].pytype) {
      std::stringstream fail_reason;
      PyObject* type_obj = PyObject_Type(item);
      PyObject* type_str = type_obj ? PyObject_Str(type_obj) : NULL;
      fail_reason << "expected type of '" << tensor_check_names[i]
                  << "' to be a tensor type, ";
      if (!type_str) {
        fail_reason << "but found a different type";
      } else {
        fail_reason << "but found " << PyUnicode_AsUTF8(type_str);
      }
      Py_XDECREF(type_str);
      Py_XDECREF(type_obj);
      return Py_BuildValue("s", fail_reason.str().c_str());
    }
    std::string fail_reason = checks[i].check_verbose(
        state, THPVariable_Unpack(item), tensor_check_names[i]);
    if (fail_reason.length() > 0) {
      return Py_BuildValue("s", fail_reason.c_str());
    }
  }

  Py_RETURN_TRUE;
}
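
// Sketch of the Python-side contract: check_verbose returns True when all
// guards pass, otherwise a human-readable string naming the first failure.
//
//   result = guards.check_verbose(t1, t2, tensor_check_names=["t1", "t2"])
//   if result is not True:
//       print(result)  # e.g. "tensor 't1' size mismatch at index 0. ..."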

static PyMethodDef TensorGuards_methods[] = {
    {"check", (PyCFunction)TensorGuards_check, METH_VARARGS, ""},
    {"check_verbose",
     (PyCFunction)(void*)TensorGuards_check_verbose,
     METH_VARARGS | METH_KEYWORDS,
     "verbose fail reasons for failed checks"},
    {NULL} /* Sentinel */
};

static PyTypeObject TensorGuardsType = {
    // NOLINTNEXTLINE
    PyVarObject_HEAD_INIT(NULL, 0)};

static PyObject* check_type_id(PyObject* dummy, PyObject* args) {
  // faster `lambda obj, expected: id(type(obj)) == expected`
  PyObject* obj;
  unsigned long long expected;
  if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
    return NULL;
  }
  if (Py_TYPE(obj) == (void*)expected) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}

static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
  // faster `lambda obj, expected: id(obj) == expected`
  PyObject* obj;
  unsigned long long expected;
  if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
    return NULL;
  }
  if (obj == (void*)expected) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
}
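
// These rely on CPython's id() returning the object's address, so the pointer
// comparison matches the Python-level identity check. A usage sketch (shown
// only to illustrate the calling convention):
//
//   from torch._C._dynamo.guards import check_type_id, check_obj_id
//   check_type_id(obj, id(type(obj)))  # True
//   check_obj_id(obj, id(obj))         # True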

static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
  /*
  Assert that a given tensor has a given size/stride, but ignore strides
  of size==1 dimensions. Implemented in C++ as this is on the hot path.
  */
  PyObject* item;
  PyObject* size;
  PyObject* stride;
  if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
    return NULL;
  }
  if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
    PyErr_SetString(PyExc_TypeError, "expected Tensor()");
    return NULL;
  }
  if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
    PyErr_SetString(PyExc_TypeError, "expected tuple()");
    return NULL;
  }
  at::Tensor tensor = THPVariable_Unpack(item);
  int64_t ndim = tensor.ndimension();
  if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
    PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
    return NULL;
  }
  for (auto i : c10::irange(ndim)) {
    int64_t want_size = THPUtils_unpackLong(PyTuple_GET_ITEM(size, i));
    int64_t want_stride = THPUtils_unpackLong(PyTuple_GET_ITEM(stride, i));
    int64_t actual_size = tensor.size(i);
    int64_t actual_stride = tensor.stride(i);
    if (want_size != actual_size ||
        // ignore stride differences when size is 1
        (want_stride != actual_stride && actual_size > 1)) {
      std::stringstream msg;
      msg << "expected size " << actual_size << "==" << want_size << ", stride "
          << actual_stride << "==" << want_stride << " at dim=" << i;
      PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
      return NULL;
    }
  }
  Py_RETURN_TRUE;
}
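
// The size==1 carve-out exists because the stride of a length-1 dimension
// never affects memory layout. A sketch, assuming a contiguous tensor:
//
//   x = torch.empty(1, 3)                    # strides (3, 1)
//   assert_size_stride(x, (1, 3), (99, 1))   # passes: dim 0 has size 1
//   assert_size_stride(x, (1, 3), (3, 2))    # raises AssertionError at dim=1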

static PyMethodDef _methods[] = {
    {"check_type_id", check_type_id, METH_VARARGS, NULL},
    {"check_obj_id", check_obj_id, METH_VARARGS, NULL},
    {"assert_size_stride", assert_size_stride, METH_VARARGS, NULL},
    {NULL, NULL, 0, NULL}};

static struct PyModuleDef _module = {
    PyModuleDef_HEAD_INIT,
    "torch._C._dynamo.guards",
    "Module containing checks on tensors",
    -1,
    _methods};

} // namespace

PyObject* torch_c_dynamo_guards_init() {
  // initialize TensorGuardsType
  TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards";
  TensorGuardsType.tp_basicsize = sizeof(TensorGuards);
  TensorGuardsType.tp_itemsize = 0;
  TensorGuardsType.tp_dealloc = (destructor)TensorGuards_dealloc;
  TensorGuardsType.tp_flags = Py_TPFLAGS_DEFAULT;
  TensorGuardsType.tp_doc = "Check properties of a torch.Tensor";
  TensorGuardsType.tp_methods = TensorGuards_methods;
  TensorGuardsType.tp_init = (initproc)TensorGuards_init;
  TensorGuardsType.tp_new = TensorGuards_new;

  PyObject* m;
  if (PyType_Ready(&TensorGuardsType) < 0)
    return NULL;

  m = PyModule_Create(&_module);
  if (m == NULL)
    return NULL;

  Py_INCREF(&TensorGuardsType);
  if (PyModule_AddObject(m, "TensorGuards", (PyObject*)&TensorGuardsType) < 0) {
    Py_DECREF(&TensorGuardsType);
    Py_DECREF(m);
    return NULL;
  }

  return m;
}