#define PY_SSIZE_T_CLEAN
#include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/extension.h>
#include <sstream>
namespace {
struct LocalState {
  // TLS state that affects how operators are dispatched
c10::impl::LocalDispatchKeySet dispatch_modifier;
bool grad_mode_enabled;
at::DispatchKeySet apply(at::DispatchKeySet ks) const {
return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
}
LocalState()
: dispatch_modifier(c10::impl::tls_local_dispatch_key_set()),
grad_mode_enabled(at::GradMode::is_enabled()) {}
};
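// A sketch of how LocalState is used by the guards below (see
// TensorCheck::check): the TLS dispatch state is snapshotted once per guard
// evaluation and folded into each tensor's key set, e.g.:
//
//   LocalState state;
//   uint64_t key = state.apply(tensor.key_set()).raw_repr();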
class TensorCheck {
public:
TensorCheck(
const LocalState& state,
PyTypeObject* pt,
const at::Tensor& v,
std::vector<std::optional<int64_t>> dynamic_dims_sizes,
std::vector<std::optional<int64_t>> dynamic_dims_strides)
: pytype(pt),
dispatch_key_(state.apply(v.key_set()).raw_repr()),
dtype_(v.dtype().toScalarType()),
device_index_(v.device().index()),
requires_grad_(state.grad_mode_enabled && v.requires_grad()),
sizes_(std::move(dynamic_dims_sizes)),
strides_(std::move(dynamic_dims_strides)) {
// TODO(voz): In cases where sizes_ and strides_ are fully dynamic, should
// we just treat this as optional?
    dim_ = static_cast<int64_t>(sizes_.size());
}
  // See note in guards.py [Note - On Export Tensor Guards]
  // The logic here must be kept in sync with its Python counterpart.
bool check(const LocalState& state, const at::Tensor& v) {
if (dispatch_key_ != state.apply(v.key_set()).raw_repr() ||
dtype_ != v.dtype().toScalarType() ||
device_index_ != v.device().index() ||
requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) {
return false;
}
auto ndim = v.ndimension();
if (ndim != dim_) {
return false;
}
const auto& sizes = v.sizes();
const auto& strides = v.strides();
for (auto i : c10::irange(ndim)) {
auto known_size = sizes_[i];
auto known_stride = strides_[i];
if (known_size.has_value()) {
if (known_size.value() != sizes[i]) {
return false;
}
}
if (known_stride.has_value()) {
if (known_stride.value() != strides[i]) {
return false;
}
}
}
return true;
}
std::string check_verbose(
const LocalState& state,
const at::Tensor& v,
std::string tensor_name) {
std::stringstream fail_reason;
fail_reason << "tensor '" << tensor_name << "' ";
if (dispatch_key_ != state.apply(v.key_set()).raw_repr()) {
// return fmt::format("tensor dispatch key mismatch. expected {}, actual
// {}", dispatch_key_, state.apply(v.key_set()).raw_repr());
fail_reason << "dispatch key set mismatch. expected "
<< c10::DispatchKeySet(
c10::DispatchKeySet::RAW, dispatch_key_)
<< ", actual " << state.apply(v.key_set());
return fail_reason.str();
} else if (dtype_ != v.dtype().toScalarType()) {
// return fmt::format("tensor dtype mismatch. expected {}, actual {}",
// dtype_, v.dtype().toScalarType());
fail_reason << "dtype mismatch. expected " << dtype_ << ", actual "
<< v.dtype().toScalarType();
return fail_reason.str();
    } else if (device_index_ != v.device().index()) {
      fail_reason << "device index mismatch. expected " << device_index_
                  << ", actual " << v.device().index();
      return fail_reason.str();
} else if (
requires_grad_ != (state.grad_mode_enabled && v.requires_grad())) {
// return fmt::format("tensor requires_grad mismatch. expected {}",
// requires_grad_);
fail_reason << "requires_grad mismatch. expected requires_grad="
<< requires_grad_;
return fail_reason.str();
}
auto ndim = v.ndimension();
if (ndim != dim_) {
// return fmt::format("tensor rank mismatch. expected {}, actual {}",
// sizes_.size(), ndim);
fail_reason << "rank mismatch. expected " << sizes_.size() << ", actual "
<< ndim;
return fail_reason.str();
}
const auto& sizes = v.sizes();
const auto& strides = v.strides();
for (auto i : c10::irange(ndim)) {
auto known_size = sizes_[i];
auto known_stride = strides_[i];
if (known_size.has_value() && (known_size.value() != sizes[i])) {
fail_reason << "size mismatch at index " << i << ". expected "
<< known_size.value() << ", actual " << sizes[i];
return fail_reason.str();
}
if (known_stride.has_value() && known_stride.value() != strides[i]) {
fail_reason << "stride mismatch at index " << i << ". expected "
<< known_stride.value() << ", actual " << strides[i];
return fail_reason.str();
}
}
return "";
}
PyTypeObject* pytype;
private:
uint64_t dispatch_key_; // DispatchKeySet includes device/layout
at::ScalarType dtype_;
  // Note(voz): While dispatch_key_ is largely representative of a device
  // (dispatch keys are more granular and device-specific), it does not
  // necessarily capture device indices correctly, hence the separate field.
at::DeviceIndex device_index_;
bool requires_grad_;
// NB: These are unset if dynamic shapes is enabled.
std::vector<std::optional<int64_t>> sizes_;
std::vector<std::optional<int64_t>> strides_;
// Not strictly required for dense tensors, but nested tensors need it.
int64_t dim_;
};
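// A minimal usage sketch for TensorCheck, mirroring TensorGuards_init and
// TensorGuards_check below (`obj` and `tensor` are placeholders;
// wrapIntegersInOptional is defined later in this file):
//
//   LocalState state;
//   TensorCheck check(
//       state,
//       Py_TYPE(obj),
//       tensor,
//       wrapIntegersInOptional(tensor.sizes()),
//       wrapIntegersInOptional(tensor.strides()));
//   bool still_valid = check.check(state, tensor);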
typedef std::vector<TensorCheck> ChecksList;
typedef struct {
  PyObject_HEAD
ChecksList* checks;
} TensorGuards;
static void TensorGuards_dealloc(TensorGuards* self) {
if (self->checks != NULL) {
delete self->checks;
self->checks = NULL;
}
Py_TYPE(self)->tp_free((PyObject*)self);
}
static PyObject* TensorGuards_new(
PyTypeObject* type,
PyObject* args,
PyObject* kwds) {
TensorGuards* self = (TensorGuards*)type->tp_alloc(type, 0);
if (self != NULL) {
self->checks = new ChecksList();
}
return (PyObject*)self;
}
static std::vector<std::optional<int64_t>> wrapIntegersInOptional(
const c10::IntArrayRef& intArray) {
std::vector<std::optional<int64_t>> optVec(intArray.size());
std::transform(
intArray.begin(), intArray.end(), optVec.begin(), [](int64_t value) {
return std::make_optional(value);
});
return optVec;
}
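// For example, a fully static shape {2, 3} becomes {optional(2), optional(3)},
// i.e. every dimension is treated as known and will be guarded on.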
static std::vector<std::optional<int64_t>> pyListToVecOptInt(PyObject* pyList) {
std::vector<std::optional<int64_t>> vec;
Py_ssize_t size = PyList_Size(pyList);
for (Py_ssize_t i = 0; i < size; i++) {
PyObject* item = PyList_GetItem(pyList, i);
if (item == Py_None) {
vec.push_back(std::nullopt);
} else {
int64_t value = PyLong_AsLongLong(item);
if (value == -1 && PyErr_Occurred()) {
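        // Report the failure both as a Python exception and as a C++ error;
        // TORCH_CHECK throws, so conversion stops at the first bad item.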
PyErr_SetString(
PyExc_TypeError,
"Size or stride list item is not a valid integer.");
TORCH_CHECK(false, "Size or stride list item is not a valid integer.");
}
vec.push_back(value);
}
}
return vec;
}
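// For example, the Python list [2, None, 8] becomes {2, std::nullopt, 8};
// None marks a dynamic dimension whose value must not be guarded on.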
static std::vector<std::vector<std::optional<int64_t>>> get_dynamic_dims(
PyObject* dynamic_dims_py) {
std::vector<std::vector<std::optional<int64_t>>> per_tensor_dynamic_dims;
if (dynamic_dims_py != Py_None) {
Py_ssize_t size = PyList_Size(dynamic_dims_py);
for (Py_ssize_t i = 0; i < size; i++) {
PyObject* py_list = PyList_GetItem(dynamic_dims_py, i);
std::vector<std::optional<int64_t>> vec = pyListToVecOptInt(py_list);
per_tensor_dynamic_dims.push_back(std::move(vec));
}
}
return per_tensor_dynamic_dims;
}
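// For example, with two tensors, dynamic_dims_py = [[2, None], [None, 8]]
// yields {{2, nullopt}, {nullopt, 8}}. When dynamic_dims_py is None, the
// result is empty and callers fall back to fully static sizes/strides.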
static int TensorGuards_init(
TensorGuards* self,
PyObject* args,
PyObject* kwds) {
if (!PyTuple_CheckExact(args)) {
PyErr_SetString(PyExc_TypeError, "expected tuple()");
return -1;
}
  if (kwds == NULL) {
    PyErr_SetString(PyExc_TypeError, "missing keyword arguments");
    return -1;
  }
  // Top level structure is List[List[Union[int, None]]]
  PyObject* dynamic_dims_sizes_py =
      PyDict_GetItemString(kwds, "dynamic_dims_sizes");
if (dynamic_dims_sizes_py == NULL) {
PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_sizes=...");
return -1;
}
PyObject* dynamic_dims_strides_py =
PyDict_GetItemString(kwds, "dynamic_dims_strides");
if (dynamic_dims_strides_py == NULL) {
PyErr_SetString(PyExc_TypeError, "missing dynamic_dims_strides=...");
return -1;
}
  // dynamic_dims_sizes_py/dynamic_dims_strides_py are None when
  // dynamic_shapes=False; this is an optimization that avoids needlessly
  // invoking .size()/.stride() from Python.
std::vector<std::vector<std::optional<int64_t>>>
per_tensor_dynamic_dims_sizes = get_dynamic_dims(dynamic_dims_sizes_py);
std::vector<std::vector<std::optional<int64_t>>>
per_tensor_dynamic_dims_strides =
get_dynamic_dims(dynamic_dims_strides_py);
auto& checks = *self->checks;
auto len = PyTuple_GET_SIZE(args);
checks.reserve(len);
LocalState state;
for (auto i : c10::irange(len)) {
PyObject* item = PyTuple_GET_ITEM(args, i);
if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
PyErr_SetString(PyExc_TypeError, "expected Tensor()");
return -1;
}
auto tensor = THPVariable_Unpack(item);
    std::vector<std::optional<int64_t>> tensor_dims_size =
        per_tensor_dynamic_dims_sizes.empty()
        ? wrapIntegersInOptional(tensor.sizes())
        : per_tensor_dynamic_dims_sizes[i];
    std::vector<std::optional<int64_t>> tensor_dims_stride =
        per_tensor_dynamic_dims_strides.empty()
        ? wrapIntegersInOptional(tensor.strides())
        : per_tensor_dynamic_dims_strides[i];
checks.emplace_back(
state,
Py_TYPE(item),
std::move(tensor),
std::move(tensor_dims_size),
std::move(tensor_dims_stride));
}
return 0;
}
PyObject* TensorGuards_check(TensorGuards* self, PyObject* args) {
if (!PyTuple_CheckExact(args)) {
PyErr_SetString(PyExc_TypeError, "expected tuple()");
return NULL;
}
auto& checks = *self->checks;
auto len = PyTuple_GET_SIZE(args);
if (static_cast<decltype(len)>(checks.size()) != len) {
PyErr_SetString(PyExc_TypeError, "wrong length");
return NULL;
}
LocalState state;
for (auto i : c10::irange(len)) {
PyObject* item = PyTuple_GET_ITEM(args, i);
if (Py_TYPE(item) != checks[i].pytype) {
Py_RETURN_FALSE;
}
if (!checks[i].check(state, THPVariable_Unpack(item))) {
Py_RETURN_FALSE;
}
}
Py_RETURN_TRUE;
}
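// Hypothetical Python usage (a sketch; argument names are taken from
// TensorGuards_init above, and the module is exposed as
// torch._C._dynamo.guards):
//
//   guards = TensorGuards(
//       t1, t2, dynamic_dims_sizes=None, dynamic_dims_strides=None)
//   guards.check(t1, t2)  # True iff both tensors still match their guards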
PyObject* TensorGuards_check_verbose(
TensorGuards* self,
PyObject* args,
PyObject* kwargs) {
if (!PyTuple_CheckExact(args)) {
PyErr_SetString(PyExc_TypeError, "expected tuple()");
return NULL;
}
auto& checks = *self->checks;
auto len = PyTuple_GET_SIZE(args);
if (static_cast<decltype(len)>(checks.size()) != len) {
PyErr_SetString(PyExc_TypeError, "wrong length");
return NULL;
}
  if (kwargs == NULL) {
    PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg");
    return NULL;
  }
  PyObject* tensor_check_names_py =
      PyDict_GetItemString(kwargs, "tensor_check_names");
if (tensor_check_names_py == NULL) {
PyErr_SetString(PyExc_TypeError, "missing tensor_check_names kwarg");
return NULL;
}
if (!PyList_Check(tensor_check_names_py)) {
PyErr_SetString(PyExc_TypeError, "tensor_check_names kwarg must be a list");
return NULL;
}
auto names_size = PyList_Size(tensor_check_names_py);
if (names_size != static_cast<decltype(names_size)>(checks.size())) {
PyErr_SetString(
PyExc_TypeError,
"tensor_check_names should be the same size as # tensors");
return NULL;
}
std::vector<std::string> tensor_check_names;
tensor_check_names.reserve(names_size);
for (auto i : c10::irange(names_size)) {
PyObject* value = PyList_GetItem(tensor_check_names_py, i);
if (!PyUnicode_Check(value)) {
PyErr_SetString(
PyExc_TypeError, "tensor_check_names must only contain strings");
return NULL;
}
tensor_check_names.emplace_back(PyUnicode_AsUTF8(value));
}
LocalState state;
for (auto i : c10::irange(len)) {
PyObject* item = PyTuple_GET_ITEM(args, i);
if (Py_TYPE(item) != checks[i].pytype) {
      std::stringstream fail_reason;
      PyObject* type = PyObject_Type(item);
      PyObject* type_str = type ? PyObject_Str(type) : NULL;
      fail_reason << "expected type of '" << tensor_check_names[i]
                  << "' to be a tensor type, ";
      if (!type_str) {
        fail_reason << "but found a different type";
      } else {
        fail_reason << "but found " << PyUnicode_AsUTF8(type_str);
      }
      // PyObject_Type and PyObject_Str return new references; release them.
      Py_XDECREF(type_str);
      Py_XDECREF(type);
      return Py_BuildValue("s", fail_reason.str().c_str());
}
std::string fail_reason = checks[i].check_verbose(
state, THPVariable_Unpack(item), tensor_check_names[i]);
if (fail_reason.length() > 0) {
return Py_BuildValue("s", fail_reason.c_str());
}
}
Py_RETURN_TRUE;
}
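// Continuing the sketch above: check_verbose returns True on success, or a
// string describing the first failed check, e.g.:
//
//   guards.check_verbose(t1, t2, tensor_check_names=["t1", "t2"])
//   # -> "tensor 't1' size mismatch at index 0. expected 2, actual 3"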
static PyMethodDef TensorGuards_methods[] = {
{"check", (PyCFunction)TensorGuards_check, METH_VARARGS, ""},
{"check_verbose",
(PyCFunction)(void*)TensorGuards_check_verbose,
METH_VARARGS | METH_KEYWORDS,
"verbose fail reasons for failed checks"},
{NULL} /* Sentinel */
};
static PyTypeObject TensorGuardsType = {
// NOLINTNEXTLINE
PyVarObject_HEAD_INIT(NULL, 0)};
static PyObject* check_type_id(PyObject* dummy, PyObject* args) {
// faster `lambda obj, expected: id(type(obj)) == expected`
PyObject* obj;
unsigned long long expected;
if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
return NULL;
}
if (Py_TYPE(obj) == (void*)expected) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
}
static PyObject* check_obj_id(PyObject* dummy, PyObject* args) {
// faster `lambda obj, expected: id(obj) == expected`
PyObject* obj;
unsigned long long expected;
if (!PyArg_ParseTuple(args, "OK", &obj, &expected)) {
return NULL;
}
if (obj == (void*)expected) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
}
static PyObject* assert_size_stride(PyObject* dummy, PyObject* args) {
/*
Assert that a given tensor has a given size/stride, but ignore strides
of size==1 dimensions. Implemented in C++ as this is on the hot path.
*/
PyObject* item;
PyObject* size;
PyObject* stride;
if (!PyArg_ParseTuple(args, "OOO", &item, &size, &stride)) {
return NULL;
}
if (!THPVariable_CheckExact(item) && !THPVariable_Check(item)) {
PyErr_SetString(PyExc_TypeError, "expected Tensor()");
return NULL;
}
if (!PyTuple_CheckExact(size) || !PyTuple_CheckExact(stride)) {
PyErr_SetString(PyExc_TypeError, "expected tuple()");
return NULL;
}
at::Tensor tensor = THPVariable_Unpack(item);
int64_t ndim = tensor.ndimension();
if (PyTuple_GET_SIZE(size) != ndim || PyTuple_GET_SIZE(stride) != ndim) {
PyErr_SetString(PyExc_AssertionError, "wrong number of dimensions");
return NULL;
}
for (auto i : c10::irange(ndim)) {
int64_t want_size = THPUtils_unpackLong(PyTuple_GET_ITEM(size, i));
int64_t want_stride = THPUtils_unpackLong(PyTuple_GET_ITEM(stride, i));
int64_t actual_size = tensor.size(i);
int64_t actual_stride = tensor.stride(i);
if (want_size != actual_size ||
// ignore stride differences when size is 1
(want_stride != actual_stride && actual_size > 1)) {
std::stringstream msg;
msg << "expected size " << actual_size << "==" << want_size << ", stride "
<< actual_stride << "==" << want_stride << " at dim=" << i;
PyErr_SetString(PyExc_AssertionError, msg.str().c_str());
return NULL;
}
}
Py_RETURN_TRUE;
}
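// Hypothetical Python usage (a sketch): pass the expected sizes and strides
// as tuples; raises AssertionError on mismatch, returns True otherwise:
//
//   assert_size_stride(t, (2, 3), (3, 1))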
static PyMethodDef _methods[] = {
{"check_type_id", check_type_id, METH_VARARGS, NULL},
{"check_obj_id", check_obj_id, METH_VARARGS, NULL},
{"assert_size_stride", assert_size_stride, METH_VARARGS, NULL},
{NULL, NULL, 0, NULL}};
static struct PyModuleDef _module = {
PyModuleDef_HEAD_INIT,
"torch._C._dynamo.guards",
"Module containing checks on tensors",
-1,
_methods};
} // namespace
PyObject* torch_c_dynamo_guards_init() {
// initialize TensorGuardsType
TensorGuardsType.tp_name = "torch._C._dynamo.guards.TensorGuards";
TensorGuardsType.tp_basicsize = sizeof(TensorGuards);
TensorGuardsType.tp_itemsize = 0;
TensorGuardsType.tp_dealloc = (destructor)TensorGuards_dealloc;
TensorGuardsType.tp_flags = Py_TPFLAGS_DEFAULT;
TensorGuardsType.tp_doc = "Check properties of a torch.Tensor";
TensorGuardsType.tp_methods = TensorGuards_methods;
TensorGuardsType.tp_init = (initproc)TensorGuards_init;
TensorGuardsType.tp_new = TensorGuards_new;
  if (PyType_Ready(&TensorGuardsType) < 0) {
    return NULL;
  }
  PyObject* m = PyModule_Create(&_module);
  if (m == NULL) {
    return NULL;
  }
Py_INCREF(&TensorGuardsType);
if (PyModule_AddObject(m, "TensorGuards", (PyObject*)&TensorGuardsType) < 0) {
Py_DECREF(&TensorGuardsType);
Py_DECREF(m);
return NULL;
}
return m;
}