| #include <torch/csrc/Device.h> |
| |
| #include <torch/csrc/Exceptions.h> |
| #include <torch/csrc/utils/object_ptr.h> |
| #include <torch/csrc/utils/pybind.h> |
| #include <torch/csrc/utils/python_arg_parser.h> |
| #include <torch/csrc/utils/python_numbers.h> |
| #include <torch/csrc/utils/python_strings.h> |
| |
| #include <ATen/Device.h> |
| #include <c10/util/Exception.h> |
| |
| #include <structmember.h> |
| #include <limits> |
| #include <sstream> |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| PyObject* THPUpperModuleOfDevice = nullptr; |
| |
| PyObject* THPDevice_New(const at::Device& device) { |
| auto type = (PyTypeObject*)&THPDeviceType; |
| auto self = THPObjectPtr{type->tp_alloc(type, 0)}; |
| if (!self) |
| throw python_error(); |
| auto self_ = reinterpret_cast<THPDevice*>(self.get()); |
| self_->device = device; |
| return self.release(); |
| } |
| |
| PyObject* THPDevice_repr(THPDevice* self) { |
| std::ostringstream oss; |
| oss << "device(type=\'" << self->device.type() << "\'"; |
| if (self->device.has_index()) { |
| // `self->device.index()` returns uint8_t which is treated as ascii while |
| // printing, hence casting it to uint16_t. |
| // https://stackoverflow.com/questions/19562103/uint8-t-cant-be-printed-with-cout |
| oss << ", index=" << static_cast<uint16_t>(self->device.index()); |
| } |
| oss << ")"; |
| return THPUtils_packString(oss.str().c_str()); |
| } |
| |
| PyObject* THPDevice_str(THPDevice* self) { |
| std::ostringstream oss; |
| oss << self->device; |
| return THPUtils_packString(oss.str().c_str()); |
| } |
| |
| PyObject* THPDevice_pynew( |
| PyTypeObject* type, |
| PyObject* args, |
| PyObject* kwargs) { |
| HANDLE_TH_ERRORS |
| static torch::PythonArgParser parser( |
| {"device(Device device)", |
| "device(c10::string_view type, int64_t? index=-1)"}); |
| torch::ParsedArgs<2> parsed_args; |
| auto r = parser.parse(args, kwargs, parsed_args); |
| if (r.has_torch_function()) { |
| return handle_torch_function( |
| r, nullptr, args, kwargs, THPUpperModuleOfDevice, "torch"); |
| } |
| if (r.idx == 0) { |
| auto device = r.device(0); |
| return THPDevice_New(device); |
| } else if (r.idx == 1) { |
| auto as_device = r.device(0); // this works, because device can take strings |
| if (as_device.has_index()) { |
| auto device_type = r.string(0); |
| throw std::runtime_error( |
| "type (string) must not include an index because index " |
| "was passed explicitly: " + |
| device_type); |
| } |
| int64_t device_index = -1; |
| if (!r.isNone(1)) { |
| device_index = r.toInt64(1); |
| // -1 is allowed in ATen/C++, to mean the default device, but not in |
| // Python. |
| TORCH_CHECK(device_index >= 0, "Device index must not be negative"); |
| } |
| at::Device device( |
| as_device.type(), static_cast<c10::DeviceIndex>(device_index)); |
| return THPDevice_New(device); |
| } |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| PyObject* THPDevice_type(THPDevice* self, PyObject* noargs) { |
| HANDLE_TH_ERRORS |
| std::ostringstream oss; |
| oss << self->device.type(); |
| return THPUtils_packString(oss.str().c_str()); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| PyObject* THPDevice_index(THPDevice* self, PyObject* noargs) { |
| HANDLE_TH_ERRORS |
| if (self->device.has_index()) { |
| return THPUtils_packInt64(self->device.index()); |
| } else { |
| Py_RETURN_NONE; |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| static Py_ssize_t THPDevice_hash(THPDevice* self) { |
| HANDLE_TH_ERRORS |
| return static_cast<Py_ssize_t>( |
| std::hash<at::Device>{}(self->device) % |
| std::numeric_limits<Py_ssize_t>::max()); |
| END_HANDLE_TH_ERRORS_RET(-1) |
| } |
| |
| PyObject* THPDevice_rc(PyObject* a, PyObject* b, int op) { |
| HANDLE_TH_ERRORS |
| if (!THPDevice_Check(a) || !THPDevice_Check(b)) { |
| // Py_RETURN_NOTIMPLEMENTED not in python 2. |
| Py_INCREF(Py_NotImplemented); |
| return Py_NotImplemented; |
| } |
| THPDevice* da = reinterpret_cast<THPDevice*>(a); |
| THPDevice* db = reinterpret_cast<THPDevice*>(b); |
| |
| switch (op) { |
| case Py_EQ: |
| if (da->device == db->device) { |
| Py_RETURN_TRUE; |
| } else { |
| Py_RETURN_FALSE; |
| } |
| case Py_NE: |
| if (da->device == db->device) { |
| Py_RETURN_FALSE; |
| } else { |
| Py_RETURN_TRUE; |
| } |
| case Py_LT: |
| case Py_LE: |
| case Py_GT: |
| case Py_GE: |
| throw torch::TypeError("comparison not implemented"); |
| default: |
| throw torch::TypeError("unexpected comparison op"); |
| } |
| END_HANDLE_TH_ERRORS |
| } |
| |
| PyObject* THPDevice_reduce(PyObject* _self, PyObject* noargs) { |
| HANDLE_TH_ERRORS |
| auto self = (THPDevice*)_self; |
| auto ret = THPObjectPtr{PyTuple_New(2)}; |
| if (!ret) |
| throw python_error(); |
| |
| py::object torch_module = py::module::import("torch"); |
| py::object torch_device = torch_module.attr("device"); |
| PyTuple_SET_ITEM(ret.get(), 0, torch_device.release().ptr()); |
| |
| THPObjectPtr args; |
| std::ostringstream oss; |
| oss << self->device.type(); |
| if (self->device.has_index()) { |
| args = THPObjectPtr{Py_BuildValue( |
| "(si)", oss.str().c_str(), static_cast<int>(self->device.index()))}; |
| } else { |
| args = THPObjectPtr{Py_BuildValue("(s)", oss.str().c_str())}; |
| } |
| if (!args) |
| throw python_error(); |
| PyTuple_SET_ITEM(ret.get(), 1, args.release()); |
| |
| return ret.release(); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| PyObject* THPDevice_enter(PyObject* self, PyObject* noargs) { |
| HANDLE_TH_ERRORS |
| py::object mode = py::module::import("torch.utils._device") |
| .attr("DeviceContext")(py::handle(self)); |
| at::impl::PythonTorchFunctionTLS::push_onto_stack( |
| std::make_shared<c10::SafePyObject>( |
| mode.release().ptr(), getPyInterpreter())); |
| // So that with torch.device('cuda') as dev: works |
| Py_INCREF(self); |
| return self; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| PyObject* THPDevice_exit(PyObject* self, PyObject* unused) { |
| HANDLE_TH_ERRORS |
| at::impl::PythonTorchFunctionTLS::pop_stack(); |
| Py_RETURN_NONE; |
| END_HANDLE_TH_ERRORS |
| } |
| |
| PyObject* THPDevice_call(PyObject* self, PyObject* args, PyObject* kwargs) { |
| HANDLE_TH_ERRORS |
| py::object deco = |
| py::module::import("torch.utils._device").attr("device_decorator"); |
| return deco(py::handle(self), *py::handle(args), **py::handle(kwargs)) |
| .release() |
| .ptr(); |
| END_HANDLE_TH_ERRORS |
| } |
| |
| typedef PyObject* (*getter)(PyObject*, void*); |
| |
| // NB: If you edit these properties/methods, update torch/_C/__init__.pyi.in |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
| static struct PyGetSetDef THPDevice_properties[] = { |
| {"type", (getter)THPDevice_type, nullptr, nullptr, nullptr}, |
| {"index", (getter)THPDevice_index, nullptr, nullptr, nullptr}, |
| {nullptr}}; |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays) |
| static PyMethodDef THPDevice_methods[] = { |
| {"__reduce__", THPDevice_reduce, METH_NOARGS, nullptr}, |
| {"__enter__", THPDevice_enter, METH_NOARGS, nullptr}, |
| {"__exit__", THPDevice_exit, METH_VARARGS, nullptr}, |
| {nullptr} /* Sentinel */ |
| }; |
| |
| PyTypeObject THPDeviceType = { |
| PyVarObject_HEAD_INIT(nullptr, 0) "torch.device", /* tp_name */ |
| sizeof(THPDevice), /* tp_basicsize */ |
| 0, /* tp_itemsize */ |
| nullptr, /* tp_dealloc */ |
| 0, /* tp_vectorcall_offset */ |
| nullptr, /* tp_getattr */ |
| nullptr, /* tp_setattr */ |
| nullptr, /* tp_reserved */ |
| (reprfunc)THPDevice_repr, /* tp_repr */ |
| nullptr, /* tp_as_number */ |
| nullptr, /* tp_as_sequence */ |
| nullptr, /* tp_as_mapping */ |
| (hashfunc)THPDevice_hash, /* tp_hash */ |
| // TODO: We're not sure if this is a good idea or not, because making |
| // torch.device callable means that it will start returning true |
| // for callable() queries, and that is unexpected. We can always add |
| // this later, so for now, don't actually implement this |
| // THPDevice_call, /* tp_call */ |
| nullptr, /* tp_call */ |
| (reprfunc)THPDevice_str, /* tp_str */ |
| nullptr, /* tp_getattro */ |
| nullptr, /* tp_setattro */ |
| nullptr, /* tp_as_buffer */ |
| Py_TPFLAGS_DEFAULT, /* tp_flags */ |
| nullptr, /* tp_doc */ |
| nullptr, /* tp_traverse */ |
| nullptr, /* tp_clear */ |
| (richcmpfunc)THPDevice_rc, /* tp_richcompare */ |
| 0, /* tp_weaklistoffset */ |
| nullptr, /* tp_iter */ |
| nullptr, /* tp_iternext */ |
| THPDevice_methods, /* tp_methods */ |
| nullptr, /* tp_members */ |
| THPDevice_properties, /* tp_getset */ |
| nullptr, /* tp_base */ |
| nullptr, /* tp_dict */ |
| nullptr, /* tp_descr_get */ |
| nullptr, /* tp_descr_set */ |
| 0, /* tp_dictoffset */ |
| nullptr, /* tp_init */ |
| nullptr, /* tp_alloc */ |
| THPDevice_pynew, /* tp_new */ |
| }; |
| |
| void THPDevice_init(PyObject* module) { |
| if (PyType_Ready(&THPDeviceType) < 0) { |
| throw python_error(); |
| } |
| Py_INCREF(&THPDeviceType); |
| THPUpperModuleOfDevice = module; |
| if (PyModule_AddObject(module, "device", (PyObject*)&THPDeviceType) != 0) { |
| throw python_error(); |
| } |
| } |