| //===- IRModules.cpp - IR Submodules of pybind module ---------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "IRModules.h" |
| |
| #include "Globals.h" |
| #include "PybindUtils.h" |
| |
| #include "mlir-c/Bindings/Python/Interop.h" |
| #include "mlir-c/BuiltinAttributes.h" |
| #include "mlir-c/BuiltinTypes.h" |
| #include "mlir-c/Registration.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include <pybind11/stl.h> |
| |
| namespace py = pybind11; |
| using namespace mlir; |
| using namespace mlir::python; |
| |
| using llvm::SmallVector; |
| using llvm::StringRef; |
| using llvm::Twine; |
| |
| //------------------------------------------------------------------------------ |
| // Docstrings (trivial, non-duplicated docstrings are included inline). |
| //------------------------------------------------------------------------------ |
| |
// Docstring text describing how a type is parsed from its assembly form.
static constexpr char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring text describing creation of a file/line/column location.
static constexpr char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Docstring text describing how a module is parsed from a string.
static constexpr char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";

// Docstring text describing the arguments accepted when creating an operation.
static constexpr char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
name: Operation name (e.g. "dialect.operation").
results: Sequence of Type representing op result types.
attributes: Dict of str:Attribute.
successors: List of Block for the operation's successors.
regions: Number of regions to create.
location: A Location object (defaults to resolve from context manager).
ip: An InsertionPoint (defaults to resolve from context manager or set to
False to disable insertion, even with an insertion point set in the
context manager).
Returns:
A new "detached" Operation object. Detached operations can be added
to blocks, which causes them to become "attached."
)";
| |
// Docstring text describing the keyword arguments of the operation printer.
// The last keyword was previously misspelled as "use_local_Scope"; keyword
// names are spelled in snake_case like every other entry in this list.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
file: The file like object to write to. Defaults to sys.stdout.
binary: Whether to write bytes (True) or str (False). Defaults to False.
large_elements_limit: Whether to elide elements attributes above this
number of elements. Defaults to None (no limit).
enable_debug_info: Whether to print debug/location information. Defaults
to False.
pretty_debug_info: Whether to format debug information for easier reading
by a human (warning: the result is unparseable).
print_generic_op_form: Whether to print the generic assembly forms of all
ops. Defaults to False.
use_local_scope: Whether to print in a way that is more optimized for
multi-threaded access but may not be consistent with how the overall
module prints.
)";
| |
// Docstring text describing retrieval of an operation's assembly as an object.
static constexpr char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
binary: Whether to return a bytes (True) or str (False) object. Defaults to
False.
... others ...: See the print() method for common keyword arguments for
configuring the printout.
Returns:
Either a bytes or str object, depending on the setting of the 'binary'
argument.
)";

// Docstring text for the default string conversion of an operation.
static constexpr char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Docstring text shared by debug-dump bindings.
static constexpr char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring text for appending a block to a block list.
static constexpr char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
The created block.
)";

// Docstring text for the string conversion of a value.
static constexpr char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
| |
| //------------------------------------------------------------------------------ |
| // Utilities. |
| //------------------------------------------------------------------------------ |
| |
| /// Checks whether the given type is an integer or float type. |
| static int mlirTypeIsAIntegerOrFloat(MlirType type) { |
| return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || |
| mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); |
| } |
| |
| static py::object |
| createCustomDialectWrapper(const std::string &dialectNamespace, |
| py::object dialectDescriptor) { |
| auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); |
| if (!dialectClass) { |
| // Use the base class. |
| return py::cast(PyDialect(std::move(dialectDescriptor))); |
| } |
| |
| // Create the custom implementation. |
| return (*dialectClass)(std::move(dialectDescriptor)); |
| } |
| |
| static MlirStringRef toMlirStringRef(const std::string &s) { |
| return mlirStringRefCreate(s.data(), s.size()); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // Collections. |
| //------------------------------------------------------------------------------ |
| |
| namespace { |
| |
| class PyRegionIterator { |
| public: |
| PyRegionIterator(PyOperationRef operation) |
| : operation(std::move(operation)) {} |
| |
| PyRegionIterator &dunderIter() { return *this; } |
| |
| PyRegion dunderNext() { |
| operation->checkValid(); |
| if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { |
| throw py::stop_iteration(); |
| } |
| MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); |
| return PyRegion(operation, region); |
| } |
| |
| static void bind(py::module &m) { |
| py::class_<PyRegionIterator>(m, "RegionIterator") |
| .def("__iter__", &PyRegionIterator::dunderIter) |
| .def("__next__", &PyRegionIterator::dunderNext); |
| } |
| |
| private: |
| PyOperationRef operation; |
| int nextIndex = 0; |
| }; |
| |
| /// Regions of an op are fixed length and indexed numerically so are represented |
| /// with a sequence-like container. |
| class PyRegionList { |
| public: |
| PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} |
| |
| intptr_t dunderLen() { |
| operation->checkValid(); |
| return mlirOperationGetNumRegions(operation->get()); |
| } |
| |
| PyRegion dunderGetItem(intptr_t index) { |
| // dunderLen checks validity. |
| if (index < 0 || index >= dunderLen()) { |
| throw SetPyError(PyExc_IndexError, |
| "attempt to access out of bounds region"); |
| } |
| MlirRegion region = mlirOperationGetRegion(operation->get(), index); |
| return PyRegion(operation, region); |
| } |
| |
| static void bind(py::module &m) { |
| py::class_<PyRegionList>(m, "ReqionSequence") |
| .def("__len__", &PyRegionList::dunderLen) |
| .def("__getitem__", &PyRegionList::dunderGetItem); |
| } |
| |
| private: |
| PyOperationRef operation; |
| }; |
| |
| class PyBlockIterator { |
| public: |
| PyBlockIterator(PyOperationRef operation, MlirBlock next) |
| : operation(std::move(operation)), next(next) {} |
| |
| PyBlockIterator &dunderIter() { return *this; } |
| |
| PyBlock dunderNext() { |
| operation->checkValid(); |
| if (mlirBlockIsNull(next)) { |
| throw py::stop_iteration(); |
| } |
| |
| PyBlock returnBlock(operation, next); |
| next = mlirBlockGetNextInRegion(next); |
| return returnBlock; |
| } |
| |
| static void bind(py::module &m) { |
| py::class_<PyBlockIterator>(m, "BlockIterator") |
| .def("__iter__", &PyBlockIterator::dunderIter) |
| .def("__next__", &PyBlockIterator::dunderNext); |
| } |
| |
| private: |
| PyOperationRef operation; |
| MlirBlock next; |
| }; |
| |
| /// Blocks are exposed by the C-API as a forward-only linked list. In Python, |
| /// we present them as a more full-featured list-like container but optimize |
| /// it for forward iteration. Blocks are always owned by a region. |
| class PyBlockList { |
| public: |
| PyBlockList(PyOperationRef operation, MlirRegion region) |
| : operation(std::move(operation)), region(region) {} |
| |
| PyBlockIterator dunderIter() { |
| operation->checkValid(); |
| return PyBlockIterator(operation, mlirRegionGetFirstBlock(region)); |
| } |
| |
| intptr_t dunderLen() { |
| operation->checkValid(); |
| intptr_t count = 0; |
| MlirBlock block = mlirRegionGetFirstBlock(region); |
| while (!mlirBlockIsNull(block)) { |
| count += 1; |
| block = mlirBlockGetNextInRegion(block); |
| } |
| return count; |
| } |
| |
| PyBlock dunderGetItem(intptr_t index) { |
| operation->checkValid(); |
| if (index < 0) { |
| throw SetPyError(PyExc_IndexError, |
| "attempt to access out of bounds block"); |
| } |
| MlirBlock block = mlirRegionGetFirstBlock(region); |
| while (!mlirBlockIsNull(block)) { |
| if (index == 0) { |
| return PyBlock(operation, block); |
| } |
| block = mlirBlockGetNextInRegion(block); |
| index -= 1; |
| } |
| throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); |
| } |
| |
| PyBlock appendBlock(py::args pyArgTypes) { |
| operation->checkValid(); |
| llvm::SmallVector<MlirType, 4> argTypes; |
| argTypes.reserve(pyArgTypes.size()); |
| for (auto &pyArg : pyArgTypes) { |
| argTypes.push_back(pyArg.cast<PyType &>()); |
| } |
| |
| MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); |
| mlirRegionAppendOwnedBlock(region, block); |
| return PyBlock(operation, block); |
| } |
| |
| static void bind(py::module &m) { |
| py::class_<PyBlockList>(m, "BlockList") |
| .def("__getitem__", &PyBlockList::dunderGetItem) |
| .def("__iter__", &PyBlockList::dunderIter) |
| .def("__len__", &PyBlockList::dunderLen) |
| .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); |
| } |
| |
| private: |
| PyOperationRef operation; |
| MlirRegion region; |
| }; |
| |
| class PyOperationIterator { |
| public: |
| PyOperationIterator(PyOperationRef parentOperation, MlirOperation next) |
| : parentOperation(std::move(parentOperation)), next(next) {} |
| |
| PyOperationIterator &dunderIter() { return *this; } |
| |
| py::object dunderNext() { |
| parentOperation->checkValid(); |
| if (mlirOperationIsNull(next)) { |
| throw py::stop_iteration(); |
| } |
| |
| PyOperationRef returnOperation = |
| PyOperation::forOperation(parentOperation->getContext(), next); |
| next = mlirOperationGetNextInBlock(next); |
| return returnOperation->createOpView(); |
| } |
| |
| static void bind(py::module &m) { |
| py::class_<PyOperationIterator>(m, "OperationIterator") |
| .def("__iter__", &PyOperationIterator::dunderIter) |
| .def("__next__", &PyOperationIterator::dunderNext); |
| } |
| |
| private: |
| PyOperationRef parentOperation; |
| MlirOperation next; |
| }; |
| |
| /// Operations are exposed by the C-API as a forward-only linked list. In |
| /// Python, we present them as a more full-featured list-like container but |
| /// optimize it for forward iteration. Iterable operations are always owned |
| /// by a block. |
| class PyOperationList { |
| public: |
| PyOperationList(PyOperationRef parentOperation, MlirBlock block) |
| : parentOperation(std::move(parentOperation)), block(block) {} |
| |
| PyOperationIterator dunderIter() { |
| parentOperation->checkValid(); |
| return PyOperationIterator(parentOperation, |
| mlirBlockGetFirstOperation(block)); |
| } |
| |
| intptr_t dunderLen() { |
| parentOperation->checkValid(); |
| intptr_t count = 0; |
| MlirOperation childOp = mlirBlockGetFirstOperation(block); |
| while (!mlirOperationIsNull(childOp)) { |
| count += 1; |
| childOp = mlirOperationGetNextInBlock(childOp); |
| } |
| return count; |
| } |
| |
| py::object dunderGetItem(intptr_t index) { |
| parentOperation->checkValid(); |
| if (index < 0) { |
| throw SetPyError(PyExc_IndexError, |
| "attempt to access out of bounds operation"); |
| } |
| MlirOperation childOp = mlirBlockGetFirstOperation(block); |
| while (!mlirOperationIsNull(childOp)) { |
| if (index == 0) { |
| return PyOperation::forOperation(parentOperation->getContext(), childOp) |
| ->createOpView(); |
| } |
| childOp = mlirOperationGetNextInBlock(childOp); |
| index -= 1; |
| } |
| throw SetPyError(PyExc_IndexError, |
| "attempt to access out of bounds operation"); |
| } |
| |
| static void bind(py::module &m) { |
| py::class_<PyOperationList>(m, "OperationList") |
| .def("__getitem__", &PyOperationList::dunderGetItem) |
| .def("__iter__", &PyOperationList::dunderIter) |
| .def("__len__", &PyOperationList::dunderLen); |
| } |
| |
| private: |
| PyOperationRef parentOperation; |
| MlirBlock block; |
| }; |
| |
| } // namespace |
| |
| //------------------------------------------------------------------------------ |
| // PyMlirContext |
| //------------------------------------------------------------------------------ |
| |
| PyMlirContext::PyMlirContext(MlirContext context) : context(context) { |
| py::gil_scoped_acquire acquire; |
| auto &liveContexts = getLiveContexts(); |
| liveContexts[context.ptr] = this; |
| } |
| |
| PyMlirContext::~PyMlirContext() { |
| // Note that the only public way to construct an instance is via the |
| // forContext method, which always puts the associated handle into |
| // liveContexts. |
| py::gil_scoped_acquire acquire; |
| getLiveContexts().erase(context.ptr); |
| mlirContextDestroy(context); |
| } |
| |
| py::object PyMlirContext::getCapsule() { |
| return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get())); |
| } |
| |
| py::object PyMlirContext::createFromCapsule(py::object capsule) { |
| MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); |
| if (mlirContextIsNull(rawContext)) |
| throw py::error_already_set(); |
| return forContext(rawContext).releaseObject(); |
| } |
| |
| PyMlirContext *PyMlirContext::createNewContextForInit() { |
| MlirContext context = mlirContextCreate(); |
| mlirRegisterAllDialects(context); |
| return new PyMlirContext(context); |
| } |
| |
| PyMlirContextRef PyMlirContext::forContext(MlirContext context) { |
| py::gil_scoped_acquire acquire; |
| auto &liveContexts = getLiveContexts(); |
| auto it = liveContexts.find(context.ptr); |
| if (it == liveContexts.end()) { |
| // Create. |
| PyMlirContext *unownedContextWrapper = new PyMlirContext(context); |
| py::object pyRef = py::cast(unownedContextWrapper); |
| assert(pyRef && "cast to py::object failed"); |
| liveContexts[context.ptr] = unownedContextWrapper; |
| return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); |
| } |
| // Use existing. |
| py::object pyRef = py::cast(it->second); |
| return PyMlirContextRef(it->second, std::move(pyRef)); |
| } |
| |
| PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { |
| static LiveContextMap liveContexts; |
| return liveContexts; |
| } |
| |
| size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } |
| |
| size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } |
| |
| size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } |
| |
| pybind11::object PyMlirContext::contextEnter() { |
| return PyThreadContextEntry::pushContext(*this); |
| } |
| |
| void PyMlirContext::contextExit(pybind11::object excType, |
| pybind11::object excVal, |
| pybind11::object excTb) { |
| PyThreadContextEntry::popContext(*this); |
| } |
| |
| PyMlirContext &DefaultingPyMlirContext::resolve() { |
| PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); |
| if (!context) { |
| throw SetPyError( |
| PyExc_RuntimeError, |
| "An MLIR function requires a Context but none was provided in the call " |
| "or from the surrounding environment. Either pass to the function with " |
| "a 'context=' argument or establish a default using 'with Context():'"); |
| } |
| return *context; |
| } |
| |
| //------------------------------------------------------------------------------ |
| // PyThreadContextEntry management |
| //------------------------------------------------------------------------------ |
| |
| std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() { |
| static thread_local std::vector<PyThreadContextEntry> stack; |
| return stack; |
| } |
| |
| PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { |
| auto &stack = getStack(); |
| if (stack.empty()) |
| return nullptr; |
| return &stack.back(); |
| } |
| |
| void PyThreadContextEntry::push(FrameKind frameKind, py::object context, |
| py::object insertionPoint, |
| py::object location) { |
| auto &stack = getStack(); |
| stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), |
| std::move(location)); |
| // If the new stack has more than one entry and the context of the new top |
| // entry matches the previous, copy the insertionPoint and location from the |
| // previous entry if missing from the new top entry. |
| if (stack.size() > 1) { |
| auto &prev = *(stack.rbegin() + 1); |
| auto ¤t = stack.back(); |
| if (current.context.is(prev.context)) { |
| // Default non-context objects from the previous entry. |
| if (!current.insertionPoint) |
| current.insertionPoint = prev.insertionPoint; |
| if (!current.location) |
| current.location = prev.location; |
| } |
| } |
| } |
| |
| PyMlirContext *PyThreadContextEntry::getContext() { |
| if (!context) |
| return nullptr; |
| return py::cast<PyMlirContext *>(context); |
| } |
| |
| PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { |
| if (!insertionPoint) |
| return nullptr; |
| return py::cast<PyInsertionPoint *>(insertionPoint); |
| } |
| |
| PyLocation *PyThreadContextEntry::getLocation() { |
| if (!location) |
| return nullptr; |
| return py::cast<PyLocation *>(location); |
| } |
| |
| PyMlirContext *PyThreadContextEntry::getDefaultContext() { |
| auto *tos = getTopOfStack(); |
| return tos ? tos->getContext() : nullptr; |
| } |
| |
| PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() { |
| auto *tos = getTopOfStack(); |
| return tos ? tos->getInsertionPoint() : nullptr; |
| } |
| |
| PyLocation *PyThreadContextEntry::getDefaultLocation() { |
| auto *tos = getTopOfStack(); |
| return tos ? tos->getLocation() : nullptr; |
| } |
| |
| py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { |
| py::object contextObj = py::cast(context); |
| push(FrameKind::Context, /*context=*/contextObj, |
| /*insertionPoint=*/py::object(), |
| /*location=*/py::object()); |
| return contextObj; |
| } |
| |
| void PyThreadContextEntry::popContext(PyMlirContext &context) { |
| auto &stack = getStack(); |
| if (stack.empty()) |
| throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); |
| auto &tos = stack.back(); |
| if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) |
| throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); |
| stack.pop_back(); |
| } |
| |
| py::object |
| PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { |
| py::object contextObj = |
| insertionPoint.getBlock().getParentOperation()->getContext().getObject(); |
| py::object insertionPointObj = py::cast(insertionPoint); |
| push(FrameKind::InsertionPoint, |
| /*context=*/contextObj, |
| /*insertionPoint=*/insertionPointObj, |
| /*location=*/py::object()); |
| return insertionPointObj; |
| } |
| |
| void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { |
| auto &stack = getStack(); |
| if (stack.empty()) |
| throw SetPyError(PyExc_RuntimeError, |
| "Unbalanced InsertionPoint enter/exit"); |
| auto &tos = stack.back(); |
| if (tos.frameKind != FrameKind::InsertionPoint && |
| tos.getInsertionPoint() != &insertionPoint) |
| throw SetPyError(PyExc_RuntimeError, |
| "Unbalanced InsertionPoint enter/exit"); |
| stack.pop_back(); |
| } |
| |
| py::object PyThreadContextEntry::pushLocation(PyLocation &location) { |
| py::object contextObj = location.getContext().getObject(); |
| py::object locationObj = py::cast(location); |
| push(FrameKind::Location, /*context=*/contextObj, |
| /*insertionPoint=*/py::object(), |
| /*location=*/locationObj); |
| return locationObj; |
| } |
| |
| void PyThreadContextEntry::popLocation(PyLocation &location) { |
| auto &stack = getStack(); |
| if (stack.empty()) |
| throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); |
| auto &tos = stack.back(); |
| if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) |
| throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); |
| stack.pop_back(); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // PyDialect, PyDialectDescriptor, PyDialects |
| //------------------------------------------------------------------------------ |
| |
| MlirDialect PyDialects::getDialectForKey(const std::string &key, |
| bool attrError) { |
| // If the "std" dialect was asked for, substitute the empty namespace :( |
| static const std::string emptyKey; |
| const std::string *canonKey = key == "std" ? &emptyKey : &key; |
| MlirDialect dialect = mlirContextGetOrLoadDialect( |
| getContext()->get(), {canonKey->data(), canonKey->size()}); |
| if (mlirDialectIsNull(dialect)) { |
| throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, |
| Twine("Dialect '") + key + "' not found"); |
| } |
| return dialect; |
| } |
| |
| //------------------------------------------------------------------------------ |
| // PyLocation |
| //------------------------------------------------------------------------------ |
| |
| py::object PyLocation::getCapsule() { |
| return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this)); |
| } |
| |
| PyLocation PyLocation::createFromCapsule(py::object capsule) { |
| MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); |
| if (mlirLocationIsNull(rawLoc)) |
| throw py::error_already_set(); |
| return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), |
| rawLoc); |
| } |
| |
| py::object PyLocation::contextEnter() { |
| return PyThreadContextEntry::pushLocation(*this); |
| } |
| |
| void PyLocation::contextExit(py::object excType, py::object excVal, |
| py::object excTb) { |
| PyThreadContextEntry::popLocation(*this); |
| } |
| |
| PyLocation &DefaultingPyLocation::resolve() { |
| auto *location = PyThreadContextEntry::getDefaultLocation(); |
| if (!location) { |
| throw SetPyError( |
| PyExc_RuntimeError, |
| "An MLIR function requires a Location but none was provided in the " |
| "call or from the surrounding environment. Either pass to the function " |
| "with a 'loc=' argument or establish a default using 'with loc:'"); |
| } |
| return *location; |
| } |
| |
| //------------------------------------------------------------------------------ |
| // PyModule |
| //------------------------------------------------------------------------------ |
| |
| PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) |
| : BaseContextObject(std::move(contextRef)), module(module) {} |
| |
| PyModule::~PyModule() { |
| py::gil_scoped_acquire acquire; |
| auto &liveModules = getContext()->liveModules; |
| assert(liveModules.count(module.ptr) == 1 && |
| "destroying module not in live map"); |
| liveModules.erase(module.ptr); |
| mlirModuleDestroy(module); |
| } |
| |
| PyModuleRef PyModule::forModule(MlirModule module) { |
| MlirContext context = mlirModuleGetContext(module); |
| PyMlirContextRef contextRef = PyMlirContext::forContext(context); |
| |
| py::gil_scoped_acquire acquire; |
| auto &liveModules = contextRef->liveModules; |
| auto it = liveModules.find(module.ptr); |
| if (it == liveModules.end()) { |
| // Create. |
| PyModule *unownedModule = new PyModule(std::move(contextRef), module); |
| // Note that the default return value policy on cast is automatic_reference, |
| // which does not take ownership (delete will not be called). |
| // Just be explicit. |
| py::object pyRef = |
| py::cast(unownedModule, py::return_value_policy::take_ownership); |
| unownedModule->handle = pyRef; |
| liveModules[module.ptr] = |
| std::make_pair(unownedModule->handle, unownedModule); |
| return PyModuleRef(unownedModule, std::move(pyRef)); |
| } |
| // Use existing. |
| PyModule *existing = it->second.second; |
| py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); |
| return PyModuleRef(existing, std::move(pyRef)); |
| } |
| |
| py::object PyModule::createFromCapsule(py::object capsule) { |
| MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); |
| if (mlirModuleIsNull(rawModule)) |
| throw py::error_already_set(); |
| return forModule(rawModule).releaseObject(); |
| } |
| |
| py::object PyModule::getCapsule() { |
| return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get())); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // PyOperation |
| //------------------------------------------------------------------------------ |
| |
| PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) |
| : BaseContextObject(std::move(contextRef)), operation(operation) {} |
| |
| PyOperation::~PyOperation() { |
| auto &liveOperations = getContext()->liveOperations; |
| assert(liveOperations.count(operation.ptr) == 1 && |
| "destroying operation not in live map"); |
| liveOperations.erase(operation.ptr); |
| if (!isAttached()) { |
| mlirOperationDestroy(operation); |
| } |
| } |
| |
| PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, |
| MlirOperation operation, |
| py::object parentKeepAlive) { |
| auto &liveOperations = contextRef->liveOperations; |
| // Create. |
| PyOperation *unownedOperation = |
| new PyOperation(std::move(contextRef), operation); |
| // Note that the default return value policy on cast is automatic_reference, |
| // which does not take ownership (delete will not be called). |
| // Just be explicit. |
| py::object pyRef = |
| py::cast(unownedOperation, py::return_value_policy::take_ownership); |
| unownedOperation->handle = pyRef; |
| if (parentKeepAlive) { |
| unownedOperation->parentKeepAlive = std::move(parentKeepAlive); |
| } |
| liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); |
| return PyOperationRef(unownedOperation, std::move(pyRef)); |
| } |
| |
| PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, |
| MlirOperation operation, |
| py::object parentKeepAlive) { |
| auto &liveOperations = contextRef->liveOperations; |
| auto it = liveOperations.find(operation.ptr); |
| if (it == liveOperations.end()) { |
| // Create. |
| return createInstance(std::move(contextRef), operation, |
| std::move(parentKeepAlive)); |
| } |
| // Use existing. |
| PyOperation *existing = it->second.second; |
| py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first); |
| return PyOperationRef(existing, std::move(pyRef)); |
| } |
| |
| PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, |
| MlirOperation operation, |
| py::object parentKeepAlive) { |
| auto &liveOperations = contextRef->liveOperations; |
| assert(liveOperations.count(operation.ptr) == 0 && |
| "cannot create detached operation that already exists"); |
| (void)liveOperations; |
| |
| PyOperationRef created = createInstance(std::move(contextRef), operation, |
| std::move(parentKeepAlive)); |
| created->attached = false; |
| return created; |
| } |
| |
| void PyOperation::checkValid() const { |
| if (!valid) { |
| throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); |
| } |
| } |
| |
| void PyOperationBase::print(py::object fileObject, bool binary, |
| llvm::Optional<int64_t> largeElementsLimit, |
| bool enableDebugInfo, bool prettyDebugInfo, |
| bool printGenericOpForm, bool useLocalScope) { |
| PyOperation &operation = getOperation(); |
| operation.checkValid(); |
| if (fileObject.is_none()) |
| fileObject = py::module::import("sys").attr("stdout"); |
| |
| if (!printGenericOpForm && !mlirOperationVerify(operation)) { |
| fileObject.attr("write")("// Verification failed, printing generic form\n"); |
| printGenericOpForm = true; |
| } |
| |
| MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); |
| if (largeElementsLimit) |
| mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); |
| if (enableDebugInfo) |
| mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); |
| if (printGenericOpForm) |
| mlirOpPrintingFlagsPrintGenericOpForm(flags); |
| |
| PyFileAccumulator accum(fileObject, binary); |
| py::gil_scoped_release(); |
| mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), |
| accum.getUserData()); |
| mlirOpPrintingFlagsDestroy(flags); |
| } |
| |
| py::object PyOperationBase::getAsm(bool binary, |
| llvm::Optional<int64_t> largeElementsLimit, |
| bool enableDebugInfo, bool prettyDebugInfo, |
| bool printGenericOpForm, |
| bool useLocalScope) { |
| py::object fileObject; |
| if (binary) { |
| fileObject = py::module::import("io").attr("BytesIO")(); |
| } else { |
| fileObject = py::module::import("io").attr("StringIO")(); |
| } |
| print(fileObject, /*binary=*/binary, |
| /*largeElementsLimit=*/largeElementsLimit, |
| /*enableDebugInfo=*/enableDebugInfo, |
| /*prettyDebugInfo=*/prettyDebugInfo, |
| /*printGenericOpForm=*/printGenericOpForm, |
| /*useLocalScope=*/useLocalScope); |
| |
| return fileObject.attr("getvalue")(); |
| } |
| |
| PyOperationRef PyOperation::getParentOperation() { |
| if (!isAttached()) |
| throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); |
| MlirOperation operation = mlirOperationGetParentOperation(get()); |
| if (mlirOperationIsNull(operation)) |
| throw SetPyError(PyExc_ValueError, "Operation has no parent."); |
| return PyOperation::forOperation(getContext(), operation); |
| } |
| |
| PyBlock PyOperation::getBlock() { |
| PyOperationRef parentOperation = getParentOperation(); |
| MlirBlock block = mlirOperationGetBlock(get()); |
| assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); |
| return PyBlock{std::move(parentOperation), block}; |
| } |
| |
| py::object PyOperation::create( |
| std::string name, llvm::Optional<std::vector<PyValue *>> operands, |
| llvm::Optional<std::vector<PyType *>> results, |
| llvm::Optional<py::dict> attributes, |
| llvm::Optional<std::vector<PyBlock *>> successors, int regions, |
| DefaultingPyLocation location, py::object maybeIp) { |
| llvm::SmallVector<MlirValue, 4> mlirOperands; |
| llvm::SmallVector<MlirType, 4> mlirResults; |
| llvm::SmallVector<MlirBlock, 4> mlirSuccessors; |
| llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes; |
| |
| // General parameter validation. |
| if (regions < 0) |
| throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); |
| |
| // Unpack/validate operands. |
| if (operands) { |
| mlirOperands.reserve(operands->size()); |
| for (PyValue *operand : *operands) { |
| if (!operand) |
| throw SetPyError(PyExc_ValueError, "operand value cannot be None"); |
| mlirOperands.push_back(operand->get()); |
| } |
| } |
| |
| // Unpack/validate results. |
| if (results) { |
| mlirResults.reserve(results->size()); |
| for (PyType *result : *results) { |
| // TODO: Verify result type originate from the same context. |
| if (!result) |
| throw SetPyError(PyExc_ValueError, "result type cannot be None"); |
| mlirResults.push_back(*result); |
| } |
| } |
| // Unpack/validate attributes. |
| if (attributes) { |
| mlirAttributes.reserve(attributes->size()); |
| for (auto &it : *attributes) { |
| std::string key; |
| try { |
| key = it.first.cast<std::string>(); |
| } catch (py::cast_error &err) { |
| std::string msg = "Invalid attribute key (not a string) when " |
| "attempting to create the operation \"" + |
| name + "\" (" + err.what() + ")"; |
| throw py::cast_error(msg); |
| } |
| try { |
| auto &attribute = it.second.cast<PyAttribute &>(); |
| // TODO: Verify attribute originates from the same context. |
| mlirAttributes.emplace_back(std::move(key), attribute); |
| } catch (py::reference_cast_error &) { |
| // This exception seems thrown when the value is "None". |
| std::string msg = |
| "Found an invalid (`None`?) attribute value for the key \"" + key + |
| "\" when attempting to create the operation \"" + name + "\""; |
| throw py::cast_error(msg); |
| } catch (py::cast_error &err) { |
| std::string msg = "Invalid attribute value for the key \"" + key + |
| "\" when attempting to create the operation \"" + |
| name + "\" (" + err.what() + ")"; |
| throw py::cast_error(msg); |
| } |
| } |
| } |
| // Unpack/validate successors. |
| if (successors) { |
| llvm::SmallVector<MlirBlock, 4> mlirSuccessors; |
| mlirSuccessors.reserve(successors->size()); |
| for (auto *successor : *successors) { |
| // TODO: Verify successor originate from the same context. |
| if (!successor) |
| throw SetPyError(PyExc_ValueError, "successor block cannot be None"); |
| mlirSuccessors.push_back(successor->get()); |
| } |
| } |
| |
| // Apply unpacked/validated to the operation state. Beyond this |
| // point, exceptions cannot be thrown or else the state will leak. |
| MlirOperationState state = |
| mlirOperationStateGet(toMlirStringRef(name), location); |
| if (!mlirOperands.empty()) |
| mlirOperationStateAddOperands(&state, mlirOperands.size(), |
| mlirOperands.data()); |
| if (!mlirResults.empty()) |
| mlirOperationStateAddResults(&state, mlirResults.size(), |
| mlirResults.data()); |
| if (!mlirAttributes.empty()) { |
| // Note that the attribute names directly reference bytes in |
| // mlirAttributes, so that vector must not be changed from here |
| // on. |
| llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes; |
| mlirNamedAttributes.reserve(mlirAttributes.size()); |
| for (auto &it : mlirAttributes) |
| mlirNamedAttributes.push_back( |
| mlirNamedAttributeGet(toMlirStringRef(it.first), it.second)); |
| mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(), |
| mlirNamedAttributes.data()); |
| } |
| if (!mlirSuccessors.empty()) |
| mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(), |
| mlirSuccessors.data()); |
| if (regions) { |
| llvm::SmallVector<MlirRegion, 4> mlirRegions; |
| mlirRegions.resize(regions); |
| for (int i = 0; i < regions; ++i) |
| mlirRegions[i] = mlirRegionCreate(); |
| mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(), |
| mlirRegions.data()); |
| } |
| |
| // Construct the operation. |
| MlirOperation operation = mlirOperationCreate(&state); |
| PyOperationRef created = |
| PyOperation::createDetached(location->getContext(), operation); |
| |
| // InsertPoint active? |
| if (!maybeIp.is(py::cast(false))) { |
| PyInsertionPoint *ip; |
| if (maybeIp.is_none()) { |
| ip = PyThreadContextEntry::getDefaultInsertionPoint(); |
| } else { |
| ip = py::cast<PyInsertionPoint *>(maybeIp); |
| } |
| if (ip) |
| ip->insert(*created.get()); |
| } |
| |
| return created->createOpView(); |
| } |
| |
| py::object PyOperation::createOpView() { |
| MlirIdentifier ident = mlirOperationGetName(get()); |
| MlirStringRef identStr = mlirIdentifierStr(ident); |
| auto opViewClass = PyGlobals::get().lookupRawOpViewClass( |
| StringRef(identStr.data, identStr.length)); |
| if (opViewClass) |
| return (*opViewClass)(getRef().getObject()); |
| return py::cast(PyOpView(getRef().getObject())); |
| } |
| |
/// Constructs an OpView over the operation behind `operationObject`.
PyOpView::PyOpView(py::object operationObject)
    // Casting through the PyOperationBase base-class and then back to the
    // Operation lets us accept any PyOperationBase subclass.
    : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
      // Store the Python object of the resolved operation itself rather than
      // the (possibly different) object that was passed in. This reads the
      // `operation` member initialized just above.
      operationObject(operation.getRef().getObject()) {}
| |
| py::object PyOpView::createRawSubclass(py::object userClass) { |
| // This is... a little gross. The typical pattern is to have a pure python |
| // class that extends OpView like: |
| // class AddFOp(_cext.ir.OpView): |
| // def __init__(self, loc, lhs, rhs): |
| // operation = loc.context.create_operation( |
| // "addf", lhs, rhs, results=[lhs.type]) |
| // super().__init__(operation) |
| // |
| // I.e. The goal of the user facing type is to provide a nice constructor |
| // that has complete freedom for the op under construction. This is at odds |
| // with our other desire to sometimes create this object by just passing an |
| // operation (to initialize the base class). We could do *arg and **kwargs |
| // munging to try to make it work, but instead, we synthesize a new class |
| // on the fly which extends this user class (AddFOp in this example) and |
| // *give it* the base class's __init__ method, thus bypassing the |
| // intermediate subclass's __init__ method entirely. While slightly, |
| // underhanded, this is safe/legal because the type hierarchy has not changed |
| // (we just added a new leaf) and we aren't mucking around with __new__. |
| // Typically, this new class will be stored on the original as "_Raw" and will |
| // be used for casts and other things that need a variant of the class that |
| // is initialized purely from an operation. |
| py::object parentMetaclass = |
| py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type); |
| py::dict attributes; |
| // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from |
| // now. |
| // auto opViewType = py::type::of<PyOpView>(); |
| auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); |
| attributes["__init__"] = opViewType.attr("__init__"); |
| py::str origName = userClass.attr("__name__"); |
| py::str newName = py::str("_") + origName; |
| return parentMetaclass(newName, py::make_tuple(userClass), attributes); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // PyInsertionPoint. |
| //------------------------------------------------------------------------------ |
| |
/// Builds an insertion point for `block` with no reference operation.
PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}

/// Builds an insertion point that refers to `beforeOperationBase` and to the
/// block that contains that operation.
PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
    : refOperation(beforeOperationBase.getOperation().getRef()),
      // NOTE(review): this initializer reads `refOperation`, so it relies on
      // that member being declared before `block` in the class -- confirm
      // against the header.
      block((*refOperation)->getBlock()) {}
| |
| void PyInsertionPoint::insert(PyOperationBase &operationBase) { |
| PyOperation &operation = operationBase.getOperation(); |
| if (operation.isAttached()) |
| throw SetPyError(PyExc_ValueError, |
| "Attempt to insert operation that is already attached"); |
| block.getParentOperation()->checkValid(); |
| MlirOperation beforeOp = {nullptr}; |
| if (refOperation) { |
| // Insert before operation. |
| (*refOperation)->checkValid(); |
| beforeOp = (*refOperation)->get(); |
| } |
| mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation); |
| operation.setAttached(); |
| } |
| |
| PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { |
| MlirOperation firstOp = mlirBlockGetFirstOperation(block.get()); |
| if (mlirOperationIsNull(firstOp)) { |
| // Just insert at end. |
| return PyInsertionPoint(block); |
| } |
| |
| // Insert before first op. |
| PyOperationRef firstOpRef = PyOperation::forOperation( |
| block.getParentOperation()->getContext(), firstOp); |
| return PyInsertionPoint{block, std::move(firstOpRef)}; |
| } |
| |
| PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { |
| MlirOperation terminator = mlirBlockGetTerminator(block.get()); |
| if (mlirOperationIsNull(terminator)) |
| throw SetPyError(PyExc_ValueError, "Block has no terminator"); |
| PyOperationRef terminatorOpRef = PyOperation::forOperation( |
| block.getParentOperation()->getContext(), terminator); |
| return PyInsertionPoint{block, std::move(terminatorOpRef)}; |
| } |
| |
/// Context-manager entry: pushes this insertion point via
/// PyThreadContextEntry and returns whatever the push returns.
py::object PyInsertionPoint::contextEnter() {
  return PyThreadContextEntry::pushInsertionPoint(*this);
}
| |
| void PyInsertionPoint::contextExit(pybind11::object excType, |
| pybind11::object excVal, |
| pybind11::object excTb) { |
| PyThreadContextEntry::popInsertionPoint(*this); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // PyAttribute. |
| //------------------------------------------------------------------------------ |
| |
/// Compares the wrapped MlirAttribute handles with mlirAttributeEqual.
bool PyAttribute::operator==(const PyAttribute &other) {
  return mlirAttributeEqual(attr, other.attr);
}
| |
| py::object PyAttribute::getCapsule() { |
| return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this)); |
| } |
| |
| PyAttribute PyAttribute::createFromCapsule(py::object capsule) { |
| MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); |
| if (mlirAttributeIsNull(rawAttr)) |
| throw py::error_already_set(); |
| return PyAttribute( |
| PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // PyNamedAttribute. |
| //------------------------------------------------------------------------------ |
| |
/// Builds a named attribute from `attr` and a name string that this object
/// keeps its own copy of.
PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
    // The name is moved into a separately allocated std::string held by the
    // `ownedName` member.
    : ownedName(new std::string(std::move(ownedName))) {
  // The string ref handed to the C API points at the bytes of the member
  // string (`this->ownedName`), not at the constructor parameter.
  namedAttr = mlirNamedAttributeGet(toMlirStringRef(*this->ownedName), attr);
}
| |
| //------------------------------------------------------------------------------ |
| // PyType. |
| //------------------------------------------------------------------------------ |
| |
/// Compares the wrapped MlirType handles with mlirTypeEqual.
bool PyType::operator==(const PyType &other) {
  return mlirTypeEqual(type, other.type);
}
| |
| py::object PyType::getCapsule() { |
| return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this)); |
| } |
| |
| PyType PyType::createFromCapsule(py::object capsule) { |
| MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); |
| if (mlirTypeIsNull(rawType)) |
| throw py::error_already_set(); |
| return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), |
| rawType); |
| } |
| |
| //------------------------------------------------------------------------------ |
| // PyValue and subclasses. |
| //------------------------------------------------------------------------------ |
| |
| namespace { |
| /// CRTP base class for Python MLIR values that subclass Value and should be |
| /// castable from it. The value hierarchy is one level deep and is not supposed |
| /// to accommodate other levels unless core MLIR changes. |
| template <typename DerivedTy> |
| class PyConcreteValue : public PyValue { |
| public: |
| // Derived classes must define statics for: |
| // IsAFunctionTy isaFunction |
| // const char *pyClassName |
| // and redefine bindDerived. |
| using ClassTy = py::class_<DerivedTy, PyValue>; |
| using IsAFunctionTy = bool (*)(MlirValue); |
| |
| PyConcreteValue() = default; |
| PyConcreteValue(PyOperationRef operationRef, MlirValue value) |
| : PyValue(operationRef, value) {} |
| PyConcreteValue(PyValue &orig) |
| : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} |
| |
| /// Attempts to cast the original value to the derived type and throws on |
| /// type mismatches. |
| static MlirValue castFrom(PyValue &orig) { |
| if (!DerivedTy::isaFunction(orig.get())) { |
| auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); |
| throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + |
| DerivedTy::pyClassName + |
| " (from " + origRepr + ")"); |
| } |
| return orig.get(); |
| } |
| |
| /// Binds the Python module objects to functions of this class. |
| static void bind(py::module &m) { |
| auto cls = ClassTy(m, DerivedTy::pyClassName); |
| cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>()); |
| DerivedTy::bindDerived(cls); |
| } |
| |
| /// Implemented by derived classes to add methods to the Python subclass. |
| static void bindDerived(ClassTy &m) {} |
| }; |
| |
| /// Python wrapper for MlirBlockArgument. |
| class PyBlockArgument : public PyConcreteValue<PyBlockArgument> { |
| public: |
| static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument; |
| static constexpr const char *pyClassName = "BlockArgument"; |
| using PyConcreteValue::PyConcreteValue; |
| |
| static void bindDerived(ClassTy &c) { |
| c.def_property_readonly("owner", [](PyBlockArgument &self) { |
| return PyBlock(self.getParentOperation(), |
| mlirBlockArgumentGetOwner(self.get())); |
| }); |
| c.def_property_readonly("arg_number", [](PyBlockArgument &self) { |
| return mlirBlockArgumentGetArgNumber(self.get()); |
| }); |
| c.def("set_type", [](PyBlockArgument &self, PyType type) { |
| return mlirBlockArgumentSetType(self.get(), type); |
| }); |
| } |
| }; |
| |
| /// Python wrapper for MlirOpResult. |
| class PyOpResult : public PyConcreteValue<PyOpResult> { |
| public: |
| static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; |
| static constexpr const char *pyClassName = "OpResult"; |
| using PyConcreteValue::PyConcreteValue; |
| |
| static void bindDerived(ClassTy &c) { |
| c.def_property_readonly("owner", [](PyOpResult &self) { |
| assert( |
| mlirOperationEqual(self.getParentOperation()->get(), |
| mlirOpResultGetOwner(self.get())) && |
| "expected the owner of the value in Python to match that in the IR"); |
| return self.getParentOperation(); |
| }); |
| c.def_property_readonly("result_number", [](PyOpResult &self) { |
| return mlirOpResultGetResultNumber(self.get()); |
| }); |
| } |
| }; |
| |
| /// A list of block arguments. Internally, these are stored as consecutive |
| /// elements, random access is cheap. The argument list is associated with the |
| /// operation that contains the block (detached blocks are not allowed in |
| /// Python bindings) and extends its lifetime. |
| class PyBlockArgumentList { |
| public: |
| PyBlockArgumentList(PyOperationRef operation, MlirBlock block) |
| : operation(std::move(operation)), block(block) {} |
| |
| /// Returns the length of the block argument list. |
| intptr_t dunderLen() { |
| operation->checkValid(); |
| return mlirBlockGetNumArguments(block); |
| } |
| |
| /// Returns `index`-th element of the block argument list. |
| PyBlockArgument dunderGetItem(intptr_t index) { |
| if (index < 0 || index >= dunderLen()) { |
| throw SetPyError(PyExc_IndexError, |
| "attempt to access out of bounds region"); |
| } |
| PyValue value(operation, mlirBlockGetArgument(block, index)); |
| return PyBlockArgument(value); |
| } |
| |
| /// Defines a Python class in the bindings. |
| static void bind(py::module &m) { |
| py::class_<PyBlockArgumentList>(m, "BlockArgumentList") |
| .def("__len__", &PyBlockArgumentList::dunderLen) |
| .def("__getitem__", &PyBlockArgumentList::dunderGetItem); |
| } |
| |
| private: |
| PyOperationRef operation; |
| MlirBlock block; |
| }; |
| |
| /// A list of operation operands. Internally, these are stored as consecutive |
| /// elements, random access is cheap. The result list is associated with the |
| /// operation whose results these are, and extends the lifetime of this |
| /// operation. |
| class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> { |
| public: |
| static constexpr const char *pyClassName = "OpOperandList"; |
| |
| PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, |
| intptr_t length = -1, intptr_t step = 1) |
| : Sliceable(startIndex, |
| length == -1 ? mlirOperationGetNumOperands(operation->get()) |
| : length, |
| step), |
| operation(operation) {} |
| |
| intptr_t getNumElements() { |
| operation->checkValid(); |
| return mlirOperationGetNumOperands(operation->get()); |
| } |
| |
| PyValue getElement(intptr_t pos) { |
| return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); |
| } |
| |
| PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { |
| return PyOpOperandList(operation, startIndex, length, step); |
| } |
| |
| private: |
| PyOperationRef operation; |
| }; |
| |
| /// A list of operation results. Internally, these are stored as consecutive |
| /// elements, random access is cheap. The result list is associated with the |
| /// operation whose results these are, and extends the lifetime of this |
| /// operation. |
| class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> { |
| public: |
| static constexpr const char *pyClassName = "OpResultList"; |
| |
| PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, |
| intptr_t length = -1, intptr_t step = 1) |
| : Sliceable(startIndex, |
| length == -1 ? mlirOperationGetNumResults(operation->get()) |
| : length, |
| step), |
| operation(operation) {} |
| |
| intptr_t getNumElements() { |
| operation->checkValid(); |
| return mlirOperationGetNumResults(operation->get()); |
| } |
| |
| PyOpResult getElement(intptr_t index) { |
| PyValue value(operation, mlirOperationGetResult(operation->get(), index)); |
| return PyOpResult(value); |
| } |
| |
| PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { |
| return PyOpResultList(operation, startIndex, length, step); |
| } |
| |
| private: |
| PyOperationRef operation; |
| }; |
| |
| /// A list of operation attributes. Can be indexed by name, producing |
| /// attributes, or by index, producing named attributes. |
| class PyOpAttributeMap { |
| public: |
| PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} |
| |
| PyAttribute dunderGetItemNamed(const std::string &name) { |
| MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), |
| toMlirStringRef(name)); |
| if (mlirAttributeIsNull(attr)) { |
| throw SetPyError(PyExc_KeyError, |
| "attempt to access a non-existent attribute"); |
| } |
| return PyAttribute(operation->getContext(), attr); |
| } |
| |
| PyNamedAttribute dunderGetItemIndexed(intptr_t index) { |
| if (index < 0 || index >= dunderLen()) { |
| throw SetPyError(PyExc_IndexError, |
| "attempt to access out of bounds attribute"); |
| } |
| MlirNamedAttribute namedAttr = |
| mlirOperationGetAttribute(operation->get(), index); |
| return PyNamedAttribute(namedAttr.attribute, |
| std::string(namedAttr.name.data)); |
| } |
| |
| void dunderSetItem(const std::string &name, PyAttribute attr) { |
| mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), |
| attr); |
| } |
| |
| void dunderDelItem(const std::string &name) { |
| int removed = mlirOperationRemoveAttributeByName(operation->get(), |
| toMlirStringRef(name)); |
| if (!removed) |
| throw SetPyError(PyExc_KeyError, |
| "attempt to delete a non-existent attribute"); |
| } |
| |
| intptr_t dunderLen() { |
| return mlirOperationGetNumAttributes(operation->get()); |
| } |
| |
| bool dunderContains(const std::string &name) { |
| return !mlirAttributeIsNull(mlirOperationGetAttributeByName( |
| operation->get(), toMlirStringRef(name))); |
| } |
| |
| static void bind(py::module &m) { |
| py::class_<PyOpAttributeMap>(m, "OpAttributeMap") |
| .def("__contains__", &PyOpAttributeMap::dunderContains) |
| .def("__len__", &PyOpAttributeMap::dunderLen) |
| .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) |
| .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed) |
| .def("__setitem__", &PyOpAttributeMap::dunderSetItem) |
| .def("__delitem__", &PyOpAttributeMap::dunderDelItem); |
| } |
| |
| private: |
| PyOperationRef operation; |
| }; |
| |
| } // end namespace |
| |
| //------------------------------------------------------------------------------ |
| // Builtin attribute subclasses. |
| //------------------------------------------------------------------------------ |
| |
| namespace { |
| |
| /// CRTP base classes for Python attributes that subclass Attribute and should |
| /// be castable from it (i.e. via something like StringAttr(attr)). |
| /// By default, attribute class hierarchies are one level deep (i.e. a |
| /// concrete attribute class extends PyAttribute); however, intermediate |
| /// python-visible base classes can be modeled by specifying a BaseTy. |
| template <typename DerivedTy, typename BaseTy = PyAttribute> |
| class PyConcreteAttribute : public BaseTy { |
| public: |
| // Derived classes must define statics for: |
| // IsAFunctionTy isaFunction |
| // const char *pyClassName |
| using ClassTy = py::class_<DerivedTy, BaseTy>; |
| using IsAFunctionTy = bool (*)(MlirAttribute); |
| |
| PyConcreteAttribute() = default; |
| PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) |
| : BaseTy(std::move(contextRef), attr) {} |
| PyConcreteAttribute(PyAttribute &orig) |
| : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} |
| |
| static MlirAttribute castFrom(PyAttribute &orig) { |
| if (!DerivedTy::isaFunction(orig)) { |
| auto origRepr = py::repr(py::cast(orig)).cast<std::string>(); |
| throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") + |
| DerivedTy::pyClassName + |
| " (from " + origRepr + ")"); |
| } |
| return orig; |
| } |
| |
| static void bind(py::module &m) { |
| auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol()); |
| cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>()); |
| DerivedTy::bindDerived(cls); |
| } |
| |
| /// Implemented by derived classes to add methods to the Python subclass. |
| static void bindDerived(ClassTy &m) {} |
| }; |
| |
| /// Float Point Attribute subclass - FloatAttr. |
| class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> { |
| public: |
| static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; |
| static constexpr const char *pyClassName = "FloatAttr"; |
| using PyConcreteAttribute::PyConcreteAttribute; |
| |
| static void bindDerived(ClassTy &c) { |
| c.def_static( |
| "get", |
| [](PyType &type, double value, DefaultingPyLocation loc) { |
| MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc); |
| // TODO: Rework error reporting once diagnostic engine is exposed |
| // in C API. |
| if (mlirAttributeIsNull(attr)) { |
| throw SetPyError(PyExc_ValueError, |
| Twine("invalid '") + |
| py::repr(py::cast(type)).cast<std::string>() + |
| "' and expected floating point type."); |
| } |
| return PyFloatAttribute(type.getContext(), attr); |
| }, |
| py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), |
| "Gets an uniqued float point attribute associated to a type"); |
| c.def_static( |
| "get_f32", |
| [](double value, DefaultingPyMlirContext context) { |
| MlirAttribute attr = mlirFloatAttrDoubleGet( |
| context->get(), mlirF32TypeGet(context->get()), value); |
| return PyFloatAttribute(context->getRef(), attr); |
| }, |
| py::arg("value"), py::arg("context") = py::none(), |
| "Gets an uniqued float point attribute associated to a f32 type"); |
| c.def_static( |
| "get_f64", |
| [](double value, DefaultingPyMlirContext context) { |
| MlirAttribute attr = mlirFloatAttrDoubleGet( |
| context->get(), mlirF64TypeGet(context->get()), value); |
| return PyFloatAttribute(context->getRef(), attr); |
| }, |
| py::arg("value"), py::arg("context") = py::none(), |
| "Gets an uniqued float point attribute associated to a f64 type"); |
| c.def_property_readonly( |
| "value", |
| [](PyFloatAttribute &self) { |
| return mlirFloatAttrGetValueDouble(self); |
| }, |
| "Returns the value of the float point attribute"); |
| } |
| }; |
| |
| /// Integer Attribute subclass - IntegerAttr. |
| class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> { |
| public: |
| static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; |
| static constexpr const char *pyClassName = "IntegerAttr"; |
| using PyConcreteAttribute::PyConcreteAttribute; |
| |
| static void bindDerived(ClassTy &c) { |
| c.def_static( |
| "get", |
| [](PyType &type, int64_t value) { |
| MlirAttribute attr = mlirIntegerAttrGet(type, value); |
| return PyIntegerAttribute(type.getContext(), attr); |
| }, |
| py::arg("type"), py::arg("value"), |
| "Gets an uniqued integer attribute associated to a type"); |
| c.def_property_readonly( |
| "value", |
| [](PyIntegerAttribute &self) { |
| return mlirIntegerAttrGetValueInt(self); |
| }, |
| "Returns the value of the integer attribute"); |
| } |
| }; |
| |
| /// Bool Attribute subclass - BoolAttr. |
| class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> { |
| public: |
| static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; |
| static constexpr const char *pyClassName = "BoolAttr"; |
| using PyConcreteAttribute::PyConcreteAttribute; |
| |
| static void bindDerived(ClassTy &c) { |
| c.def_static( |
| "get", |
| [](bool value, DefaultingPyMlirContext context) { |
| MlirAttribute attr = mlirBoolAttrGet(context->get(), value); |
| return PyBoolAttribute(context->getRef(), attr); |
| }, |
| py::arg("value"), py::arg("context") = py::none(), |
| "Gets an uniqued bool attribute"); |
| c.def_property_readonly( |
| "value", |
| [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, |
| "Returns the value of the bool attribute"); |
| } |
| }; |
| |
| class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> { |
| public: |
| static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; |
| static constexpr const char *pyClassName = "StringAttr"; |
| using PyConcreteAttribute::PyConcreteAttribute; |
| |
| static void bindDerived(ClassTy &c) { |
| c.def_static( |
| "get", |
| [](std::string value, DefaultingPyMlirContext context) { |
| MlirAttribute attr = |
| mlirStringAttrGet(context->get(), toMlirStringRef(value)); |
| return PyStringAttribute(context->getRef(), attr); |
| }, |
| py::arg("value"), py::arg("context") = py::none(), |
| "Gets a uniqued string attribute"); |
| c.def_static( |
| "get_typed", |
| [](PyType &type, std::string value) { |
| MlirAttribute attr = |
| mlirStringAttrTypedGet(type, toMlirStringRef(value)); |
| return PyStringAttribute(type.getContext(), attr); |
| }, |
| |
| "Gets a uniqued string attribute associated to a type"); |
| c.def_property_readonly( |
| "value", |
| [](PyStringAttribute &self) { |
| MlirStringRef stringRef = mlirStringAttrGetValue(self); |
| return py::str(stringRef.data, stringRef.length); |
| }, |
| "Returns the value of the string attribute"); |
| } |
| }; |
| |
| // TODO: Support construction of bool elements. |
| // TODO: Support construction of string elements. |
| class PyDenseElementsAttribute |
| : public PyConcreteAttribute<PyDenseElementsAttribute> { |
| public: |
| static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; |
| static constexpr const char *pyClassName = "DenseElementsAttr"; |
| using PyConcreteAttribute::PyConcreteAttribute; |
| |
| static PyDenseElementsAttribute |
| getFromBuffer(py::buffer array, bool signless, |
| DefaultingPyMlirContext contextWrapper) { |
| // Request a contiguous view. In exotic cases, this will cause a copy. |
| int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; |
| Py_buffer *view = new Py_buffer(); |
| if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { |
| delete view; |
| throw py::error_already_set(); |
| } |
| py::buffer_info arrayInfo(view); |
| |
| MlirContext context = contextWrapper->get(); |
| // Switch on the types that can be bulk loaded between the Python and |
| // MLIR-C APIs. |
| // See: https://docs.python.org/3/library/struct.html#format-characters |
| if (arrayInfo.format == "f") { |
| // f32 |
| assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); |
| return PyDenseElementsAttribute( |
| contextWrapper->getRef(), |
| bulkLoad(context, mlirDenseElementsAttrFloatGet, |
| mlirF32TypeGet(context), arrayInfo)); |
| } else if (arrayInfo.format == "d") { |
| // f64 |
| assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); |
| return PyDenseElementsAttribute( |
| contextWrapper->getRef(), |
| bulkLoad(context, mlirDenseElementsAttrDoubleGet, |
| mlirF64TypeGet(context), arrayInfo)); |
| } else if (isSignedIntegerFormat(arrayInfo.format)) { |
| if (arrayInfo.itemsize == 4) { |
| // i32 |
| MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) |
| : mlirIntegerTypeSignedGet(context, 32); |
| return PyDenseElementsAttribute(contextWrapper->getRef(), |
| bulkLoad(context, |
| mlirDenseElementsAttrInt32Get, |
| elementType, arrayInfo)); |
| } else if (arrayInfo.itemsize == 8) { |
| // i64 |
| MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) |
| : mlirIntegerTypeSignedGet(context, 64); |
| return PyDenseElementsAttribute(contextWrapper->getRef(), |
| bulkLoad(context, |
| mlirDenseElementsAttrInt64Get, |
| elementType, arrayInfo)); |
| } |
| } else if (isUnsignedIntegerFormat(arrayInfo.format)) { |
| if (arrayInfo.itemsize == 4) { |
| // unsigned i32 |
| MlirType elementType = signless |
| ? mlirIntegerTypeGet(context, 32) |
| : mlirIntegerTypeUnsignedGet(context, 32); |
| return PyDenseElementsAttribute(contextWrapper->getRef(), |
| bulkLoad(context, |
| mlirDenseElementsAttrUInt32Get, |
| elementType, arrayInfo)); |
| } else if (arrayInfo.itemsize == 8) { |
| // unsigned i64 |
| MlirType elementType = signless |
| ? mlirIntegerTypeGet(context, 64) |
| : mlirIntegerTypeUnsignedGet(context, 64); |
| return PyDenseElementsAttribute(contextWrapper->getRef(), |
| bulkLoad(context, |
| mlirDenseElementsAttrUInt64Get, |
| elementType, arrayInfo)); |
| } |
| } |
| |
| // TODO: Fall back to string-based get. |
| std::string message = "unimplemented array format conversion from format: "; |
| message.append(arrayInfo.format); |
| throw SetPyError(PyExc_ValueError, message); |
| } |
| |
| static PyDenseElementsAttribute getSplat(PyType shapedType, |
| PyAttribute &elementAttr) { |
| auto contextWrapper = |
| PyMlirContext::forContext(mlirTypeGetContext(shapedType)); |
| if (!mlirAttributeIsAInteger(elementAttr) && |
| !mlirAttributeIsAFloat(elementAttr)) { |
| std::string message = "Illegal element type for DenseElementsAttr: "; |
| message.append(py::repr(py::cast(elementAttr))); |
| throw SetPyError(PyExc_ValueError, message); |
| } |
| if (!mlirTypeIsAShaped(shapedType) || |
| !mlirShapedTypeHasStaticShape(shapedType)) { |
| std::string message = |
| "Expected a static ShapedType for the shaped_type parameter: "; |
| message.append(py::repr(py::cast(shapedType))); |
| throw SetPyError(PyExc_ValueError, message); |
| } |
| MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); |
| MlirType attrType = mlirAttributeGetType(elementAttr); |
| if (!mlirTypeEqual(shapedElementType, attrType)) { |
| std::string message = |
| "Shaped element type and attribute type must be equal: shaped="; |
| message.append(py::repr(py::cast(shapedType))); |
| message.append(", element="); |
| message.append(py::repr(py::cast(elementAttr))); |
| throw SetPyError(PyExc_ValueError, message); |
| } |
| |
| MlirAttribute elements = |
| mlirDenseElementsAttrSplatGet(shapedType, elementAttr); |
| return PyDenseElementsAttribute(contextWrapper->getRef(), elements); |
| } |
| |
/// Returns the number of elements reported by mlirElementsAttrGetNumElements
/// for this attribute (bound as Python `__len__`).
intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
| |
| /// Exposes the attribute's contents through the Python buffer protocol.
| ///
| /// Supported element types are f32, f64 and 32/64-bit integers (signless and
| /// signed map to the signed accessor, unsigned to the unsigned one). Any
| /// other element type raises ValueError.
| py::buffer_info accessBuffer() {
|   MlirType attrType = mlirAttributeGetType(*this);
|   MlirType elemType = mlirShapedTypeGetElementType(attrType);
| 
|   // Floating point element types.
|   if (mlirTypeIsAF32(elemType))
|     return bufferInfo(attrType, mlirDenseElementsAttrGetFloatValue);
|   if (mlirTypeIsAF64(elemType))
|     return bufferInfo(attrType, mlirDenseElementsAttrGetDoubleValue);
| 
|   // Integer element types: dispatch on bit width, then on signedness.
|   if (mlirTypeIsAInteger(elemType)) {
|     const bool viewAsSigned = mlirIntegerTypeIsSignless(elemType) ||
|                               mlirIntegerTypeIsSigned(elemType);
|     const bool viewAsUnsigned =
|         !viewAsSigned && mlirIntegerTypeIsUnsigned(elemType);
|     switch (mlirIntegerTypeGetWidth(elemType)) {
|     case 32:
|       if (viewAsSigned)
|         return bufferInfo(attrType, mlirDenseElementsAttrGetInt32Value);
|       if (viewAsUnsigned)
|         return bufferInfo(attrType, mlirDenseElementsAttrGetUInt32Value);
|       break;
|     case 64:
|       if (viewAsSigned)
|         return bufferInfo(attrType, mlirDenseElementsAttrGetInt64Value);
|       if (viewAsUnsigned)
|         return bufferInfo(attrType, mlirDenseElementsAttrGetUInt64Value);
|       break;
|     default:
|       break;
|     }
|   }
| 
|   // Everything else has no buffer mapping yet.
|   std::string error = "unimplemented array format.";
|   throw SetPyError(PyExc_ValueError, error);
| }
| |
| /// Registers the Python-visible members of DenseElementsAttr: `__len__`,
| /// the `get` / `get_splat` factories, the `is_splat` property and the
| /// buffer protocol hook.
| static void bindDerived(ClassTy &c) {
|   c.def("__len__", &PyDenseElementsAttribute::dunderLen);
|   c.def_static("get", PyDenseElementsAttribute::getFromBuffer,
|                py::arg("array"), py::arg("signless") = true,
|                py::arg("context") = py::none(),
|                "Gets from a buffer or ndarray");
|   c.def_static("get_splat", PyDenseElementsAttribute::getSplat,
|                py::arg("shaped_type"), py::arg("element_attr"),
|                "Gets a DenseElementsAttr where all values are the same");
|   c.def_property_readonly("is_splat",
|                           [](PyDenseElementsAttribute &attr) -> bool {
|                             return mlirDenseElementsAttrIsSplat(attr);
|                           });
|   c.def_buffer(&PyDenseElementsAttribute::accessBuffer);
| }
| |
| private: |
| /// Builds a dense elements attribute from a Python buffer.
| ///
| /// The buffer's first `ndim` extents become the shape of a ranked tensor of
| /// `mlirElementType`; `ctor` is then invoked with that type, the buffer's
| /// element count and its data pointer. `context` is not used here.
| template <typename ElementTy>
| static MlirAttribute
| bulkLoad(MlirContext context,
|          MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
|          MlirType mlirElementType, py::buffer_info &arrayInfo) {
|   const auto extentsBegin = arrayInfo.shape.begin();
|   SmallVector<int64_t, 4> dims(extentsBegin, extentsBegin + arrayInfo.ndim);
|   auto tensorType =
|       mlirRankedTensorTypeGet(dims.size(), dims.data(), mlirElementType);
|   return ctor(tensorType, arrayInfo.size,
|               static_cast<const ElementTy *>(arrayInfo.ptr));
| }
| |
| /// Returns true when the first character of `format` is one of the
| /// format codes 'I', 'B', 'H', 'L' or 'Q'. An empty string yields false.
| static bool isUnsignedIntegerFormat(const std::string &format) {
|   if (format.empty())
|     return false;
|   switch (format.front()) {
|   case 'I':
|   case 'B':
|   case 'H':
|   case 'L':
|   case 'Q':
|     return true;
|   default:
|     return false;
|   }
| }
| |
| /// Returns true when the first character of `format` is one of the
| /// format codes 'i', 'b', 'h', 'l' or 'q'. An empty string yields false.
| static bool isSignedIntegerFormat(const std::string &format) {
|   if (format.empty())
|     return false;
|   switch (format.front()) {
|   case 'i':
|   case 'b':
|   case 'h':
|   case 'l':
|   case 'q':
|     return true;
|   default:
|     return false;
|   }
| }
| |
| /// Describes the attribute's raw storage as a read-only Python buffer.
| ///
| /// `shapedType` supplies the rank and dimension sizes. `value` is never
| /// called; it only lets the compiler deduce the element type `Type` from one
| /// of the typed element accessors.
| ///
| /// Strides are expressed in bytes. Non-splat attributes get dense row-major
| /// strides. Splat attributes store a single element, so they get all-zero
| /// strides: every index then aliases that one element instead of reading
| /// past the end of the storage. This also keeps the stride count equal to
| /// the rank for rank-0 attributes, which py::buffer_info requires.
| template <typename Type>
| py::buffer_info bufferInfo(MlirType shapedType,
|                            Type (*value)(MlirAttribute, intptr_t)) {
|   intptr_t rank = mlirShapedTypeGetRank(shapedType);
|   // Prepare the data for the buffer_info.
|   // Buffer is configured for read-only access below, so casting away the
|   // constness of the raw data pointer does not permit mutation from Python.
|   Type *data = static_cast<Type *>(
|       const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
|   // Prepare the shape for the buffer_info.
|   SmallVector<intptr_t, 4> shape;
|   for (intptr_t i = 0; i < rank; ++i)
|     shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
|   // Prepare the strides for the buffer_info.
|   SmallVector<intptr_t, 4> strides;
|   if (mlirDenseElementsAttrIsSplat(*this)) {
|     // Splats are special: only the single value is stored.
|     strides.assign(rank, 0);
|   } else {
|     // Innermost dimension is contiguous; each outer stride is the next
|     // inner stride times that inner dimension's extent.
|     strides.assign(rank, sizeof(Type));
|     for (intptr_t i = rank - 2; i >= 0; --i)
|       strides[i] = strides[i + 1] * shape[i + 1];
|   }
|   return py::buffer_info(data, sizeof(Type),
|                          py::format_descriptor<Type>::format(), rank, shape,
|                          strides, /*readonly=*/true);
| }
| }; // namespace |
| |
| /// Refinement of the PyDenseElementsAttribute for attributes containing integer
| /// (and boolean) values. Supports element access.
| class PyDenseIntElementsAttribute
|     : public PyConcreteAttribute<PyDenseIntElementsAttribute,
|                                  PyDenseElementsAttribute> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
|   static constexpr const char *pyClassName = "DenseIntElementsAttr";
|   using PyConcreteAttribute::PyConcreteAttribute;
| 
|   /// Returns the element at the given linear position.
|   ///
|   /// Raises IndexError when `pos` is negative or not below the element
|   /// count, and TypeError when the element type is not a 1-, 32- or 64-bit
|   /// integer.
|   py::int_ dunderGetItem(intptr_t pos) {
|     const bool inBounds = pos >= 0 && pos < dunderLen();
|     if (!inBounds) {
|       throw SetPyError(PyExc_IndexError,
|                        "attempt to access out of bounds element");
|     }
| 
|     MlirType elementType =
|         mlirShapedTypeGetElementType(mlirAttributeGetType(*this));
|     assert(mlirTypeIsAInteger(elementType) &&
|            "expected integer element type in dense int elements attribute");
|     // Pick the C accessor matching the element's bit width and signedness.
|     // py::int_ converts implicitly from each accessor's C++ return type.
|     // TODO: consider caching the type properties in the constructor to avoid
|     // querying them on each element access.
|     const unsigned bitWidth = mlirIntegerTypeGetWidth(elementType);
|     const bool unsignedElement = mlirIntegerTypeIsUnsigned(elementType);
|     switch (bitWidth) {
|     case 1:
|       // Booleans use the same accessor regardless of signedness.
|       return mlirDenseElementsAttrGetBoolValue(*this, pos);
|     case 32:
|       if (unsignedElement)
|         return mlirDenseElementsAttrGetUInt32Value(*this, pos);
|       return mlirDenseElementsAttrGetInt32Value(*this, pos);
|     case 64:
|       if (unsignedElement)
|         return mlirDenseElementsAttrGetUInt64Value(*this, pos);
|       return mlirDenseElementsAttrGetInt64Value(*this, pos);
|     default:
|       break;
|     }
|     throw SetPyError(PyExc_TypeError, "Unsupported integer type");
|   }
| 
|   /// Exposes element access to Python as `__getitem__`.
|   static void bindDerived(ClassTy &c) {
|     c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
|   }
| };
| |
| /// Refinement of PyDenseElementsAttribute for attributes containing
| /// floating-point values. Supports element access.
| class PyDenseFPElementsAttribute
|     : public PyConcreteAttribute<PyDenseFPElementsAttribute,
|                                  PyDenseElementsAttribute> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
|   static constexpr const char *pyClassName = "DenseFPElementsAttr";
|   using PyConcreteAttribute::PyConcreteAttribute;
| 
|   /// Returns the element at the given linear position.
|   ///
|   /// Raises IndexError when `pos` is negative or not below the element
|   /// count, and TypeError when the element type is neither f32 nor f64.
|   py::float_ dunderGetItem(intptr_t pos) {
|     const bool inBounds = pos >= 0 && pos < dunderLen();
|     if (!inBounds) {
|       throw SetPyError(PyExc_IndexError,
|                        "attempt to access out of bounds element");
|     }
| 
|     MlirType elementType =
|         mlirShapedTypeGetElementType(mlirAttributeGetType(*this));
|     // Pick the C accessor matching the element type; py::float_ converts
|     // implicitly from both float and double.
|     // TODO: consider caching the type properties in the constructor to avoid
|     // querying them on each element access.
|     if (mlirTypeIsAF32(elementType))
|       return mlirDenseElementsAttrGetFloatValue(*this, pos);
|     if (mlirTypeIsAF64(elementType))
|       return mlirDenseElementsAttrGetDoubleValue(*this, pos);
|     throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
|   }
| 
|   /// Exposes element access to Python as `__getitem__`.
|   static void bindDerived(ClassTy &c) {
|     c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
|   }
| };
| |
| /// Type Attribute subclass - TypeAttr. Wraps a type as an attribute.
| class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
|   static constexpr const char *pyClassName = "TypeAttr";
|   using PyConcreteAttribute::PyConcreteAttribute;
| 
|   /// Registers the static `get` factory and the read-only `value` property
|   /// that returns the wrapped type.
|   static void bindDerived(ClassTy &c) {
|     auto makeTypeAttr = [](PyType value, DefaultingPyMlirContext context) {
|       MlirAttribute wrapped = mlirTypeAttrGet(value.get());
|       return PyTypeAttribute(context->getRef(), wrapped);
|     };
|     auto readValue = [](PyTypeAttribute &self) {
|       MlirType held = mlirTypeAttrGetValue(self.get());
|       return PyType(self.getContext()->getRef(), held);
|     };
|     c.def_static("get", makeTypeAttr, py::arg("value"),
|                  py::arg("context") = py::none(),
|                  "Gets a uniqued Type attribute");
|     c.def_property_readonly("value", readValue);
|   }
| };
| |
| /// Unit Attribute subclass. Unit attributes don't have values.
| class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
|   static constexpr const char *pyClassName = "UnitAttr";
|   using PyConcreteAttribute::PyConcreteAttribute;
| 
|   /// Registers the static `get` factory; its `context` argument defaults to
|   /// None.
|   static void bindDerived(ClassTy &c) {
|     auto makeUnit = [](DefaultingPyMlirContext context) {
|       MlirAttribute unit = mlirUnitAttrGet(context->get());
|       return PyUnitAttribute(context->getRef(), unit);
|     };
|     c.def_static("get", makeUnit, py::arg("context") = py::none(),
|                  "Create a Unit attribute.");
|   }
| };
| |
| } // namespace |
| |
| //------------------------------------------------------------------------------ |
| // Builtin type subclasses. |
| //------------------------------------------------------------------------------ |
| |
| namespace { |
| |
| /// CRTP base class for Python types that subclass Type and should be
| /// castable from it (i.e. via something like IntegerType(t)).
| ///
| /// Type class hierarchies are one level deep by default (a concrete type
| /// class extends PyType). Passing a BaseTy models an intermediate
| /// python-visible base class instead.
| ///
| /// Derived classes must define these statics:
| ///   IsAFunctionTy isaFunction
| ///   const char *pyClassName
| template <typename DerivedTy, typename BaseTy = PyType>
| class PyConcreteType : public BaseTy {
| public:
|   using ClassTy = py::class_<DerivedTy, BaseTy>;
|   using IsAFunctionTy = bool (*)(MlirType);
| 
|   PyConcreteType() = default;
|   PyConcreteType(PyMlirContextRef contextRef, MlirType t)
|       : BaseTy(std::move(contextRef), t) {}
|   /// Casting constructor: validates `orig` through castFrom first.
|   PyConcreteType(PyType &orig)
|       : PyConcreteType(orig.getContext(), castFrom(orig)) {}
| 
|   /// Returns the MlirType held by `orig` when DerivedTy::isaFunction accepts
|   /// it; raises ValueError naming the target class and `orig` otherwise.
|   static MlirType castFrom(PyType &orig) {
|     if (DerivedTy::isaFunction(orig))
|       return orig;
|     auto sourceRepr = py::repr(py::cast(orig)).cast<std::string>();
|     throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
|                                            DerivedTy::pyClassName +
|                                            " (from " + sourceRepr + ")");
|   }
| 
|   /// Creates the Python class in module `m`, adds the cast-from-Type
|   /// constructor, then lets the derived class add its own members.
|   static void bind(py::module &m) {
|     ClassTy pyClass(m, DerivedTy::pyClassName);
|     pyClass.def(py::init<PyType &>(), py::keep_alive<0, 1>());
|     DerivedTy::bindDerived(pyClass);
|   }
| 
|   /// Implemented by derived classes to add methods to the Python subclass.
|   static void bindDerived(ClassTy &m) {}
| };
| |
| /// Integer Type subclass - IntegerType.
| class PyIntegerType : public PyConcreteType<PyIntegerType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
|   static constexpr const char *pyClassName = "IntegerType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers one factory per signedness flavour plus the read-only
|   /// `width`, `is_signless`, `is_signed` and `is_unsigned` properties.
|   static void bindDerived(ClassTy &c) {
|     // Factories.
|     c.def_static(
|          "get_signless",
|          [](unsigned bits, DefaultingPyMlirContext context) {
|            MlirType signlessType = mlirIntegerTypeGet(context->get(), bits);
|            return PyIntegerType(context->getRef(), signlessType);
|          },
|          py::arg("width"), py::arg("context") = py::none(),
|          "Create a signless integer type")
|         .def_static(
|             "get_signed",
|             [](unsigned bits, DefaultingPyMlirContext context) {
|               MlirType signedType =
|                   mlirIntegerTypeSignedGet(context->get(), bits);
|               return PyIntegerType(context->getRef(), signedType);
|             },
|             py::arg("width"), py::arg("context") = py::none(),
|             "Create a signed integer type")
|         .def_static(
|             "get_unsigned",
|             [](unsigned bits, DefaultingPyMlirContext context) {
|               MlirType unsignedType =
|                   mlirIntegerTypeUnsignedGet(context->get(), bits);
|               return PyIntegerType(context->getRef(), unsignedType);
|             },
|             py::arg("width"), py::arg("context") = py::none(),
|             "Create an unsigned integer type");
| 
|     // Read-only queries.
|     c.def_property_readonly(
|          "width",
|          [](PyIntegerType &type) { return mlirIntegerTypeGetWidth(type); },
|          "Returns the width of the integer type")
|         .def_property_readonly(
|             "is_signless",
|             [](PyIntegerType &type) -> bool {
|               return mlirIntegerTypeIsSignless(type);
|             },
|             "Returns whether this is a signless integer")
|         .def_property_readonly(
|             "is_signed",
|             [](PyIntegerType &type) -> bool {
|               return mlirIntegerTypeIsSigned(type);
|             },
|             "Returns whether this is a signed integer")
|         .def_property_readonly(
|             "is_unsigned",
|             [](PyIntegerType &type) -> bool {
|               return mlirIntegerTypeIsUnsigned(type);
|             },
|             "Returns whether this is an unsigned integer");
|   }
| };
| |
| /// Index Type subclass - IndexType.
| class PyIndexType : public PyConcreteType<PyIndexType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
|   static constexpr const char *pyClassName = "IndexType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory; its `context` argument defaults to
|   /// None.
|   static void bindDerived(ClassTy &c) {
|     auto makeType = [](DefaultingPyMlirContext context) {
|       return PyIndexType(context->getRef(), mlirIndexTypeGet(context->get()));
|     };
|     c.def_static("get", makeType, py::arg("context") = py::none(),
|                  "Create a index type.");
|   }
| };
| |
| /// Floating Point Type subclass - BF16Type.
| class PyBF16Type : public PyConcreteType<PyBF16Type> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
|   static constexpr const char *pyClassName = "BF16Type";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory; its `context` argument defaults to
|   /// None.
|   static void bindDerived(ClassTy &c) {
|     auto makeType = [](DefaultingPyMlirContext context) {
|       return PyBF16Type(context->getRef(), mlirBF16TypeGet(context->get()));
|     };
|     c.def_static("get", makeType, py::arg("context") = py::none(),
|                  "Create a bf16 type.");
|   }
| };
| |
| /// Floating Point Type subclass - F16Type.
| class PyF16Type : public PyConcreteType<PyF16Type> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
|   static constexpr const char *pyClassName = "F16Type";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory; its `context` argument defaults to
|   /// None.
|   static void bindDerived(ClassTy &c) {
|     auto makeType = [](DefaultingPyMlirContext context) {
|       return PyF16Type(context->getRef(), mlirF16TypeGet(context->get()));
|     };
|     c.def_static("get", makeType, py::arg("context") = py::none(),
|                  "Create a f16 type.");
|   }
| };
| |
| /// Floating Point Type subclass - F32Type.
| class PyF32Type : public PyConcreteType<PyF32Type> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
|   static constexpr const char *pyClassName = "F32Type";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory; its `context` argument defaults to
|   /// None.
|   static void bindDerived(ClassTy &c) {
|     auto makeType = [](DefaultingPyMlirContext context) {
|       return PyF32Type(context->getRef(), mlirF32TypeGet(context->get()));
|     };
|     c.def_static("get", makeType, py::arg("context") = py::none(),
|                  "Create a f32 type.");
|   }
| };
| |
| /// Floating Point Type subclass - F64Type.
| class PyF64Type : public PyConcreteType<PyF64Type> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
|   static constexpr const char *pyClassName = "F64Type";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory; its `context` argument defaults to
|   /// None.
|   static void bindDerived(ClassTy &c) {
|     auto makeType = [](DefaultingPyMlirContext context) {
|       return PyF64Type(context->getRef(), mlirF64TypeGet(context->get()));
|     };
|     c.def_static("get", makeType, py::arg("context") = py::none(),
|                  "Create a f64 type.");
|   }
| };
| |
| /// None Type subclass - NoneType.
| class PyNoneType : public PyConcreteType<PyNoneType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
|   static constexpr const char *pyClassName = "NoneType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory; its `context` argument defaults to
|   /// None.
|   static void bindDerived(ClassTy &c) {
|     auto makeType = [](DefaultingPyMlirContext context) {
|       return PyNoneType(context->getRef(), mlirNoneTypeGet(context->get()));
|     };
|     c.def_static("get", makeType, py::arg("context") = py::none(),
|                  "Create a none type.");
|   }
| };
| |
| /// Complex Type subclass - ComplexType.
| class PyComplexType : public PyConcreteType<PyComplexType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
|   static constexpr const char *pyClassName = "ComplexType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory (raises ValueError unless the
|   /// element type is an integer or float type) and the read-only
|   /// `element_type` property.
|   static void bindDerived(ClassTy &c) {
|     c.def_static(
|          "get",
|          [](PyType &elementType) {
|            // Reject anything that is not a floating point or integer
|            // scalar type up front.
|            if (!mlirTypeIsAIntegerOrFloat(elementType)) {
|              throw SetPyError(
|                  PyExc_ValueError,
|                  Twine("invalid '") +
|                      py::repr(py::cast(elementType)).cast<std::string>() +
|                      "' and expected floating point or integer type.");
|            }
|            return PyComplexType(elementType.getContext(),
|                                 mlirComplexTypeGet(elementType));
|          },
|          "Create a complex type")
|         .def_property_readonly(
|             "element_type",
|             [](PyComplexType &self) -> PyType {
|               return PyType(self.getContext(),
|                             mlirComplexTypeGetElementType(self));
|             },
|             "Returns element type.");
|   }
| };
| |
| /// Shaped Type subclass - ShapedType. Python-visible base class for the
| /// vector, tensor and memref type wrappers declared in this file.
| class PyShapedType : public PyConcreteType<PyShapedType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
|   static constexpr const char *pyClassName = "ShapedType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers shape queries. Members that need a rank raise ValueError on
|   /// unranked types; members taking a dimension index additionally raise
|   /// IndexError when the index is outside [0, rank).
|   static void bindDerived(ClassTy &c) {
|     c.def_property_readonly(
|         "element_type",
|         [](PyShapedType &self) {
|           MlirType t = mlirShapedTypeGetElementType(self);
|           return PyType(self.getContext(), t);
|         },
|         "Returns the element type of the shaped type.");
|     c.def_property_readonly(
|         "has_rank",
|         [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
|         "Returns whether the given shaped type is ranked.");
|     c.def_property_readonly(
|         "rank",
|         [](PyShapedType &self) {
|           self.requireHasRank();
|           return mlirShapedTypeGetRank(self);
|         },
|         "Returns the rank of the given ranked shaped type.");
|     c.def_property_readonly(
|         "has_static_shape",
|         [](PyShapedType &self) -> bool {
|           return mlirShapedTypeHasStaticShape(self);
|         },
|         "Returns whether the given shaped type has a static shape.");
|     c.def(
|         "is_dynamic_dim",
|         [](PyShapedType &self, intptr_t dim) -> bool {
|           // Validate before calling into the C API, which is handed the
|           // index as-is.
|           self.requireValidDim(dim);
|           return mlirShapedTypeIsDynamicDim(self, dim);
|         },
|         "Returns whether the dim-th dimension of the given shaped type is "
|         "dynamic.");
|     c.def(
|         "get_dim_size",
|         [](PyShapedType &self, intptr_t dim) {
|           self.requireValidDim(dim);
|           return mlirShapedTypeGetDimSize(self, dim);
|         },
|         "Returns the dim-th dimension of the given ranked shaped type.");
|     c.def_static(
|         "is_dynamic_size",
|         [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
|         "Returns whether the given dimension size indicates a dynamic "
|         "dimension.");
|     c.def(
|         "is_dynamic_stride_or_offset",
|         [](PyShapedType &self, int64_t val) -> bool {
|           // The rank requirement is kept for compatibility even though the
|           // query itself only inspects `val`.
|           self.requireHasRank();
|           return mlirShapedTypeIsDynamicStrideOrOffset(val);
|         },
|         "Returns whether the given value is used as a placeholder for dynamic "
|         "strides and offsets in shaped types.");
|   }
| 
| private:
|   /// Raises ValueError unless the wrapped type is ranked.
|   void requireHasRank() {
|     if (!mlirShapedTypeHasRank(*this)) {
|       throw SetPyError(
|           PyExc_ValueError,
|           "calling this method requires that the type has a rank.");
|     }
|   }
| 
|   /// Raises ValueError for unranked types and IndexError unless
|   /// 0 <= dim < rank.
|   void requireValidDim(intptr_t dim) {
|     requireHasRank();
|     if (dim < 0 || dim >= mlirShapedTypeGetRank(*this)) {
|       throw SetPyError(PyExc_IndexError,
|                        "dimension index is out of range for the rank of the "
|                        "shaped type.");
|     }
|   }
| };
| |
| /// Vector Type subclass - VectorType.
| class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
|   static constexpr const char *pyClassName = "VectorType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory, which raises ValueError when the
|   /// checked C API constructor returns a null type.
|   static void bindDerived(ClassTy &c) {
|     auto makeVector = [](std::vector<int64_t> shape, PyType &elementType,
|                          DefaultingPyLocation loc) {
|       MlirType built = mlirVectorTypeGetChecked(shape.size(), shape.data(),
|                                                 elementType, loc);
|       if (!mlirTypeIsNull(built))
|         return PyVectorType(elementType.getContext(), built);
|       // TODO: Rework error reporting once diagnostic engine is exposed
|       // in C API.
|       throw SetPyError(
|           PyExc_ValueError,
|           Twine("invalid '") +
|               py::repr(py::cast(elementType)).cast<std::string>() +
|               "' and expected floating point or integer type.");
|     };
|     c.def_static("get", makeVector, py::arg("shape"), py::arg("elementType"),
|                  py::arg("loc") = py::none(), "Create a vector type");
|   }
| };
| |
| /// Ranked Tensor Type subclass - RankedTensorType.
| class PyRankedTensorType
|     : public PyConcreteType<PyRankedTensorType, PyShapedType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
|   static constexpr const char *pyClassName = "RankedTensorType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory, which raises ValueError when the
|   /// checked C API constructor returns a null type.
|   static void bindDerived(ClassTy &c) {
|     auto makeTensor = [](std::vector<int64_t> shape, PyType &elementType,
|                          DefaultingPyLocation loc) {
|       MlirType built = mlirRankedTensorTypeGetChecked(
|           shape.size(), shape.data(), elementType, loc);
|       if (!mlirTypeIsNull(built))
|         return PyRankedTensorType(elementType.getContext(), built);
|       // TODO: Rework error reporting once diagnostic engine is exposed
|       // in C API.
|       throw SetPyError(
|           PyExc_ValueError,
|           Twine("invalid '") +
|               py::repr(py::cast(elementType)).cast<std::string>() +
|               "' and expected floating point, integer, vector or complex "
|               "type.");
|     };
|     c.def_static("get", makeTensor, py::arg("shape"),
|                  py::arg("element_type"), py::arg("loc") = py::none(),
|                  "Create a ranked tensor type");
|   }
| };
| |
| /// Unranked Tensor Type subclass - UnrankedTensorType.
| class PyUnrankedTensorType
|     : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
|   static constexpr const char *pyClassName = "UnrankedTensorType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory, which raises ValueError when the
|   /// checked C API constructor returns a null type.
|   static void bindDerived(ClassTy &c) {
|     auto makeTensor = [](PyType &elementType, DefaultingPyLocation loc) {
|       MlirType built = mlirUnrankedTensorTypeGetChecked(elementType, loc);
|       if (!mlirTypeIsNull(built))
|         return PyUnrankedTensorType(elementType.getContext(), built);
|       // TODO: Rework error reporting once diagnostic engine is exposed
|       // in C API.
|       throw SetPyError(
|           PyExc_ValueError,
|           Twine("invalid '") +
|               py::repr(py::cast(elementType)).cast<std::string>() +
|               "' and expected floating point, integer, vector or complex "
|               "type.");
|     };
|     c.def_static("get", makeTensor, py::arg("element_type"),
|                  py::arg("loc") = py::none(),
|                  "Create a unranked tensor type");
|   }
| };
| |
| /// Ranked MemRef Type subclass - MemRefType.
| class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
| public:
|   // Must be the memref predicate: PyConcreteType::castFrom uses it to decide
|   // whether a generic Type may be cast to MemRefType.
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
|   static constexpr const char *pyClassName = "MemRefType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the `get_contiguous_memref` factory (raises ValueError when
|   /// the checked C API constructor returns a null type) and the read-only
|   /// `num_affine_maps` and `memory_space` properties.
|   static void bindDerived(ClassTy &c) {
|     // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
|     // once the affine map binding is completed.
|     c.def_static(
|          "get_contiguous_memref",
|          [](PyType &elementType, std::vector<int64_t> shape,
|             unsigned memorySpace, DefaultingPyLocation loc) {
|            MlirType t = mlirMemRefTypeContiguousGetChecked(
|                elementType, shape.size(), shape.data(), memorySpace, loc);
|            // TODO: Rework error reporting once diagnostic engine is exposed
|            // in C API.
|            if (mlirTypeIsNull(t)) {
|              throw SetPyError(
|                  PyExc_ValueError,
|                  Twine("invalid '") +
|                      py::repr(py::cast(elementType)).cast<std::string>() +
|                      "' and expected floating point, integer, vector or "
|                      "complex "
|                      "type.");
|            }
|            return PyMemRefType(elementType.getContext(), t);
|          },
|          py::arg("element_type"), py::arg("shape"), py::arg("memory_space"),
|          py::arg("loc") = py::none(), "Create a memref type")
|         .def_property_readonly(
|             "num_affine_maps",
|             [](PyMemRefType &self) -> intptr_t {
|               return mlirMemRefTypeGetNumAffineMaps(self);
|             },
|             "Returns the number of affine layout maps in the given MemRef "
|             "type.")
|         .def_property_readonly(
|             "memory_space",
|             [](PyMemRefType &self) -> unsigned {
|               return mlirMemRefTypeGetMemorySpace(self);
|             },
|             "Returns the memory space of the given MemRef type.");
|   }
| };
| |
| /// Unranked MemRef Type subclass - UnrankedMemRefType.
| class PyUnrankedMemRefType
|     : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
|   static constexpr const char *pyClassName = "UnrankedMemRefType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory (raises ValueError when the checked
|   /// C API constructor returns a null type) and the read-only
|   /// `memory_space` property.
|   static void bindDerived(ClassTy &c) {
|     auto makeMemRef = [](PyType &elementType, unsigned memorySpace,
|                          DefaultingPyLocation loc) {
|       MlirType built =
|           mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc);
|       if (!mlirTypeIsNull(built))
|         return PyUnrankedMemRefType(elementType.getContext(), built);
|       // TODO: Rework error reporting once diagnostic engine is exposed
|       // in C API.
|       throw SetPyError(
|           PyExc_ValueError,
|           Twine("invalid '") +
|               py::repr(py::cast(elementType)).cast<std::string>() +
|               "' and expected floating point, integer, vector or complex "
|               "type.");
|     };
|     auto readMemorySpace = [](PyUnrankedMemRefType &self) -> unsigned {
|       return mlirUnrankedMemrefGetMemorySpace(self);
|     };
|     c.def_static("get", makeMemRef, py::arg("element_type"),
|                  py::arg("memory_space"), py::arg("loc") = py::none(),
|                  "Create a unranked memref type");
|     c.def_property_readonly(
|         "memory_space", readMemorySpace,
|         "Returns the memory space of the given Unranked MemRef type.");
|   }
| };
| |
| /// Tuple Type subclass - TupleType.
| class PyTupleType : public PyConcreteType<PyTupleType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
|   static constexpr const char *pyClassName = "TupleType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the `get_tuple` factory, the `get_type` element accessor
|   /// (raises IndexError for positions outside [0, num_types)) and the
|   /// read-only `num_types` property.
|   static void bindDerived(ClassTy &c) {
|     c.def_static(
|         "get_tuple",
|         [](py::list elementList, DefaultingPyMlirContext context) {
|           // Mapping py::list to SmallVector. Each element's cast to PyType
|           // raises if the element is not a Type.
|           SmallVector<MlirType, 4> elements;
|           elements.reserve(py::len(elementList));
|           for (auto element : elementList)
|             elements.push_back(element.cast<PyType>());
|           // Take the count from the converted vector so it always matches
|           // the data pointer passed alongside it.
|           MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
|                                         elements.data());
|           return PyTupleType(context->getRef(), t);
|         },
|         py::arg("elements"), py::arg("context") = py::none(),
|         "Create a tuple type");
|     c.def(
|         "get_type",
|         [](PyTupleType &self, intptr_t pos) -> PyType {
|           // Validate before calling into the C API, which is handed the
|           // position as-is.
|           if (pos < 0 || pos >= mlirTupleTypeGetNumTypes(self)) {
|             throw SetPyError(PyExc_IndexError,
|                              "attempt to access out of bounds tuple element "
|                              "type");
|           }
|           MlirType t = mlirTupleTypeGetType(self, pos);
|           return PyType(self.getContext(), t);
|         },
|         "Returns the pos-th type in the tuple type.");
|     c.def_property_readonly(
|         "num_types",
|         [](PyTupleType &self) -> intptr_t {
|           return mlirTupleTypeGetNumTypes(self);
|         },
|         "Returns the number of types contained in a tuple.");
|   }
| };
| |
| /// Function Type subclass - FunctionType.
| class PyFunctionType : public PyConcreteType<PyFunctionType> {
| public:
|   static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
|   static constexpr const char *pyClassName = "FunctionType";
|   using PyConcreteType::PyConcreteType;
| 
|   /// Registers the static `get` factory and the read-only `inputs` and
|   /// `results` properties, each of which returns a fresh Python list of
|   /// Type objects.
|   static void bindDerived(ClassTy &c) {
|     c.def_static(
|          "get",
|          [](std::vector<PyType> inputs, std::vector<PyType> results,
|             DefaultingPyMlirContext context) {
|            // Unwrap the Python-side wrappers into raw C API handles.
|            SmallVector<MlirType, 4> rawInputs(inputs.begin(), inputs.end());
|            SmallVector<MlirType, 4> rawResults(results.begin(),
|                                                results.end());
|            MlirType fnType = mlirFunctionTypeGet(
|                context->get(), rawInputs.size(), rawInputs.data(),
|                rawResults.size(), rawResults.data());
|            return PyFunctionType(context->getRef(), fnType);
|          },
|          py::arg("inputs"), py::arg("results"),
|          py::arg("context") = py::none(),
|          "Gets a FunctionType from a list of input and result types")
|         .def_property_readonly(
|             "inputs",
|             [](PyFunctionType &self) {
|               auto owningContext = self.getContext();
|               MlirType fnType = self;
|               const intptr_t numInputs = mlirFunctionTypeGetNumInputs(self);
|               py::list inputTypes;
|               for (intptr_t idx = 0; idx < numInputs; ++idx) {
|                 inputTypes.append(PyType(
|                     owningContext, mlirFunctionTypeGetInput(fnType, idx)));
|               }
|               return inputTypes;
|             },
|             "Returns the list of input types in the FunctionType.")
|         .def_property_readonly(
|             "results",
|             [](PyFunctionType &self) {
|               auto owningContext = self.getContext();
|               const intptr_t numResults = mlirFunctionTypeGetNumResults(self);
|               py::list resultTypes;
|               for (intptr_t idx = 0; idx < numResults; ++idx) {
|                 resultTypes.append(PyType(
|                     owningContext, mlirFunctionTypeGetResult(self, idx)));
|               }
|               return resultTypes;
|             },
|             "Returns the list of result types in the FunctionType.");
|   }
| };
| |
| } // namespace |
| |
| //------------------------------------------------------------------------------ |
| // Populates the pybind11 IR submodule. |
| //------------------------------------------------------------------------------ |
| |
| void mlir::python::populateIRSubmodule(py::module &m) { |
| //---------------------------------------------------------------------------- |
| // Mapping of MlirContext |
| //---------------------------------------------------------------------------- |
| py::class_<PyMlirContext>(m, "Context") |
| .def(py::init<>(&PyMlirContext::createNewContextForInit)) |
| .def_static("_get_live_count", &PyMlirContext::getLiveCount) |
| .def("_get_context_again", |
| [](PyMlirContext &self) { |
| PyMlirContextRef ref = PyMlirContext::forContext(self.get()); |
| return ref.releaseObject(); |
| }) |
| .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) |
| .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) |
| .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
| &PyMlirContext::getCapsule) |
| .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) |
| .def("__enter__", &PyMlirContext::contextEnter) |
| .def("__exit__", &PyMlirContext::contextExit) |
| .def_property_readonly_static( |
| "current", |
| [](py::object & /*class*/) { |
| auto *context = PyThreadContextEntry::getDefaultContext(); |
| if (!context) |
| throw SetPyError(PyExc_ValueError, "No current Context"); |
| return context; |
| }, |
| "Gets the Context bound to the current thread or raises ValueError") |
| .def_property_readonly( |
| "dialects", |
| [](PyMlirContext &self) { return PyDialects(self.getRef()); }, |
| "Gets a container for accessing dialects by name") |
| .def_property_readonly( |
| "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, |
| "Alias for 'dialect'") |
| .def( |
| "get_dialect_descriptor", |
| [=](PyMlirContext &self, std::string &name) { |
| MlirDialect dialect = mlirContextGetOrLoadDialect( |
| self.get(), {name.data(), name.size()}); |
| if (mlirDialectIsNull(dialect)) { |
| throw SetPyError(PyExc_ValueError, |
| Twine("Dialect '") + name + "' not found"); |
| } |
| return PyDialectDescriptor(self.getRef(), dialect); |
| }, |
| "Gets or loads a dialect by name, returning its descriptor object") |
| .def_property( |
| "allow_unregistered_dialects", |
| [](PyMlirContext &self) -> bool { |
| return mlirContextGetAllowUnregisteredDialects(self.get()); |
| }, |
| [](PyMlirContext &self, bool value) { |
| mlirContextSetAllowUnregisteredDialects(self.get(), value); |
| }); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of PyDialectDescriptor |
| //---------------------------------------------------------------------------- |
| py::class_<PyDialectDescriptor>(m, "DialectDescriptor") |
| .def_property_readonly("namespace", |
| [](PyDialectDescriptor &self) { |
| MlirStringRef ns = |
| mlirDialectGetNamespace(self.get()); |
| return py::str(ns.data, ns.length); |
| }) |
| .def("__repr__", [](PyDialectDescriptor &self) { |
| MlirStringRef ns = mlirDialectGetNamespace(self.get()); |
| std::string repr("<DialectDescriptor "); |
| repr.append(ns.data, ns.length); |
| repr.append(">"); |
| return repr; |
| }); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of PyDialects |
| //---------------------------------------------------------------------------- |
| py::class_<PyDialects>(m, "Dialects") |
| .def("__getitem__", |
| [=](PyDialects &self, std::string keyName) { |
| MlirDialect dialect = |
| self.getDialectForKey(keyName, /*attrError=*/false); |
| py::object descriptor = |
| py::cast(PyDialectDescriptor{self.getContext(), dialect}); |
| return createCustomDialectWrapper(keyName, std::move(descriptor)); |
| }) |
| .def("__getattr__", [=](PyDialects &self, std::string attrName) { |
| MlirDialect dialect = |
| self.getDialectForKey(attrName, /*attrError=*/true); |
| py::object descriptor = |
| py::cast(PyDialectDescriptor{self.getContext(), dialect}); |
| return createCustomDialectWrapper(attrName, std::move(descriptor)); |
| }); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of PyDialect |
| //---------------------------------------------------------------------------- |
| py::class_<PyDialect>(m, "Dialect") |
| .def(py::init<py::object>(), "descriptor") |
| .def_property_readonly( |
| "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) |
| .def("__repr__", [](py::object self) { |
| auto clazz = self.attr("__class__"); |
| return py::str("<Dialect ") + |
| self.attr("descriptor").attr("namespace") + py::str(" (class ") + |
| clazz.attr("__module__") + py::str(".") + |
| clazz.attr("__name__") + py::str(")>"); |
| }); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of Location |
| //---------------------------------------------------------------------------- |
| py::class_<PyLocation>(m, "Location") |
| .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) |
| .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) |
| .def("__enter__", &PyLocation::contextEnter) |
| .def("__exit__", &PyLocation::contextExit) |
| .def("__eq__", |
| [](PyLocation &self, PyLocation &other) -> bool { |
| return mlirLocationEqual(self, other); |
| }) |
| .def("__eq__", [](PyLocation &self, py::object other) { return false; }) |
| .def_property_readonly_static( |
| "current", |
| [](py::object & /*class*/) { |
| auto *loc = PyThreadContextEntry::getDefaultLocation(); |
| if (!loc) |
| throw SetPyError(PyExc_ValueError, "No current Location"); |
| return loc; |
| }, |
| "Gets the Location bound to the current thread or raises ValueError") |
| .def_static( |
| "unknown", |
| [](DefaultingPyMlirContext context) { |
| return PyLocation(context->getRef(), |
| mlirLocationUnknownGet(context->get())); |
| }, |
| py::arg("context") = py::none(), |
| "Gets a Location representing an unknown location") |
| .def_static( |
| "file", |
| [](std::string filename, int line, int col, |
| DefaultingPyMlirContext context) { |
| return PyLocation( |
| context->getRef(), |
| mlirLocationFileLineColGet( |
| context->get(), toMlirStringRef(filename), line, col)); |
| }, |
| py::arg("filename"), py::arg("line"), py::arg("col"), |
| py::arg("context") = py::none(), kContextGetFileLocationDocstring) |
| .def_property_readonly( |
| "context", |
| [](PyLocation &self) { return self.getContext().getObject(); }, |
| "Context that owns the Location") |
| .def("__repr__", [](PyLocation &self) { |
| PyPrintAccumulator printAccum; |
| mlirLocationPrint(self, printAccum.getCallback(), |
| printAccum.getUserData()); |
| return printAccum.join(); |
| }); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of Module |
| //---------------------------------------------------------------------------- |
| py::class_<PyModule>(m, "Module") |
| .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) |
| .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) |
| .def_static( |
| "parse", |
| [](const std::string moduleAsm, DefaultingPyMlirContext context) { |
| MlirModule module = mlirModuleCreateParse( |
| context->get(), toMlirStringRef(moduleAsm)); |
| // TODO: Rework error reporting once diagnostic engine is exposed |
| // in C API. |
| if (mlirModuleIsNull(module)) { |
| throw SetPyError( |
| PyExc_ValueError, |
| "Unable to parse module assembly (see diagnostics)"); |
| } |
| return PyModule::forModule(module).releaseObject(); |
| }, |
| py::arg("asm"), py::arg("context") = py::none(), |
| kModuleParseDocstring) |
| .def_static( |
| "create", |
| [](DefaultingPyLocation loc) { |
| MlirModule module = mlirModuleCreateEmpty(loc); |
| return PyModule::forModule(module).releaseObject(); |
| }, |
| py::arg("loc") = py::none(), "Creates an empty module") |
| .def_property_readonly( |
| "context", |
| [](PyModule &self) { return self.getContext().getObject(); }, |
| "Context that created the Module") |
| .def_property_readonly( |
| "operation", |
| [](PyModule &self) { |
| return PyOperation::forOperation(self.getContext(), |
| mlirModuleGetOperation(self.get()), |
| self.getRef().releaseObject()) |
| .releaseObject(); |
| }, |
| "Accesses the module as an operation") |
| .def_property_readonly( |
| "body", |
| [](PyModule &self) { |
| PyOperationRef module_op = PyOperation::forOperation( |
| self.getContext(), mlirModuleGetOperation(self.get()), |
| self.getRef().releaseObject()); |
| PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); |
| return returnBlock; |
| }, |
| "Return the block for this module") |
| .def( |
| "dump", |
| [](PyModule &self) { |
| mlirOperationDump(mlirModuleGetOperation(self.get())); |
| }, |
| kDumpDocstring) |
| .def( |
| "__str__", |
| [](PyModule &self) { |
| MlirOperation operation = mlirModuleGetOperation(self.get()); |
| PyPrintAccumulator printAccum; |
| mlirOperationPrint(operation, printAccum.getCallback(), |
| printAccum.getUserData()); |
| return printAccum.join(); |
| }, |
| kOperationStrDunderDocstring); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of Operation. |
| //---------------------------------------------------------------------------- |
| py::class_<PyOperationBase>(m, "_OperationBase") |
| .def("__eq__", |
| [](PyOperationBase &self, PyOperationBase &other) { |
| return &self.getOperation() == &other.getOperation(); |
| }) |
| .def("__eq__", |
| [](PyOperationBase &self, py::object other) { return false; }) |
| .def_property_readonly("attributes", |
| [](PyOperationBase &self) { |
| return PyOpAttributeMap( |
| self.getOperation().getRef()); |
| }) |
| .def_property_readonly("operands", |
| [](PyOperationBase &self) { |
| return PyOpOperandList( |
| self.getOperation().getRef()); |
| }) |
| .def_property_readonly("regions", |
| [](PyOperationBase &self) { |
| return PyRegionList( |
| self.getOperation().getRef()); |
| }) |
| .def_property_readonly( |
| "results", |
| [](PyOperationBase &self) { |
| return PyOpResultList(self.getOperation().getRef()); |
| }, |
| "Returns the list of Operation results.") |
| .def_property_readonly( |
| "result", |
| [](PyOperationBase &self) { |
| auto &operation = self.getOperation(); |
| auto numResults = mlirOperationGetNumResults(operation); |
| if (numResults != 1) { |
| auto name = mlirIdentifierStr(mlirOperationGetName(operation)); |
| throw SetPyError( |
| PyExc_ValueError, |
| Twine("Cannot call .result on operation ") + |
| StringRef(name.data, name.length) + " which has " + |
| Twine(numResults) + |
| " results (it is only valid for operations with a " |
| "single result)"); |
| } |
| return PyOpResult(operation.getRef(), |
| mlirOperationGetResult(operation, 0)); |
| }, |
| "Shortcut to get an op result if it has only one (throws an error " |
| "otherwise).") |
| .def("__iter__", |
| [](PyOperationBase &self) { |
| return PyRegionIterator(self.getOperation().getRef()); |
| }) |
| .def( |
| "__str__", |
| [](PyOperationBase &self) { |
| return self.getAsm(/*binary=*/false, |
| /*largeElementsLimit=*/llvm::None, |
| /*enableDebugInfo=*/false, |
| /*prettyDebugInfo=*/false, |
| /*printGenericOpForm=*/false, |
| /*useLocalScope=*/false); |
| }, |
| "Returns the assembly form of the operation.") |
| .def("print", &PyOperationBase::print, |
| // Careful: Lots of arguments must match up with print method. |
| py::arg("file") = py::none(), py::arg("binary") = false, |
| py::arg("large_elements_limit") = py::none(), |
| py::arg("enable_debug_info") = false, |
| py::arg("pretty_debug_info") = false, |
| py::arg("print_generic_op_form") = false, |
| py::arg("use_local_scope") = false, kOperationPrintDocstring) |
| .def("get_asm", &PyOperationBase::getAsm, |
| // Careful: Lots of arguments must match up with get_asm method. |
| py::arg("binary") = false, |
| py::arg("large_elements_limit") = py::none(), |
| py::arg("enable_debug_info") = false, |
| py::arg("pretty_debug_info") = false, |
| py::arg("print_generic_op_form") = false, |
| py::arg("use_local_scope") = false, kOperationGetAsmDocstring); |
| |
| py::class_<PyOperation, PyOperationBase>(m, "Operation") |
| .def_static("create", &PyOperation::create, py::arg("name"), |
| py::arg("operands") = py::none(), |
| py::arg("results") = py::none(), |
| py::arg("attributes") = py::none(), |
| py::arg("successors") = py::none(), py::arg("regions") = 0, |
| py::arg("loc") = py::none(), py::arg("ip") = py::none(), |
| kOperationCreateDocstring) |
| .def_property_readonly( |
| "context", |
| [](PyOperation &self) { return self.getContext().getObject(); }, |
| "Context that owns the Operation") |
| .def_property_readonly("opview", &PyOperation::createOpView); |
| |
| py::class_<PyOpView, PyOperationBase>(m, "OpView") |
| .def(py::init<py::object>()) |
| .def_property_readonly("operation", &PyOpView::getOperationObject) |
| .def_property_readonly( |
| "context", |
| [](PyOpView &self) { |
| return self.getOperation().getContext().getObject(); |
| }, |
| "Context that owns the Operation") |
| .def("__str__", |
| [](PyOpView &self) { return py::str(self.getOperationObject()); }); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of PyRegion. |
| //---------------------------------------------------------------------------- |
| py::class_<PyRegion>(m, "Region") |
| .def_property_readonly( |
| "blocks", |
| [](PyRegion &self) { |
| return PyBlockList(self.getParentOperation(), self.get()); |
| }, |
| "Returns a forward-optimized sequence of blocks.") |
| .def( |
| "__iter__", |
| [](PyRegion &self) { |
| self.checkValid(); |
| MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get()); |
| return PyBlockIterator(self.getParentOperation(), firstBlock); |
| }, |
| "Iterates over blocks in the region.") |
| .def("__eq__", |
| [](PyRegion &self, PyRegion &other) { |
| return self.get().ptr == other.get().ptr; |
| }) |
| .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of PyBlock. |
| //---------------------------------------------------------------------------- |
| py::class_<PyBlock>(m, "Block") |
| .def_property_readonly( |
| "arguments", |
| [](PyBlock &self) { |
| return PyBlockArgumentList(self.getParentOperation(), self.get()); |
| }, |
| "Returns a list of block arguments.") |
| .def_property_readonly( |
| "operations", |
| [](PyBlock &self) { |
| return PyOperationList(self.getParentOperation(), self.get()); |
| }, |
| "Returns a forward-optimized sequence of operations.") |
| .def( |
| "__iter__", |
| [](PyBlock &self) { |
| self.checkValid(); |
| MlirOperation firstOperation = |
| mlirBlockGetFirstOperation(self.get()); |
| return PyOperationIterator(self.getParentOperation(), |
| firstOperation); |
| }, |
| "Iterates over operations in the block.") |
| .def("__eq__", |
| [](PyBlock &self, PyBlock &other) { |
| return self.get().ptr == other.get().ptr; |
| }) |
| .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) |
| .def( |
| "__str__", |
| [](PyBlock &self) { |
| self.checkValid(); |
| PyPrintAccumulator printAccum; |
| mlirBlockPrint(self.get(), printAccum.getCallback(), |
| printAccum.getUserData()); |
| return printAccum.join(); |
| }, |
| "Returns the assembly form of the block."); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of PyInsertionPoint. |
| //---------------------------------------------------------------------------- |
| |
| py::class_<PyInsertionPoint>(m, "InsertionPoint") |
| .def(py::init<PyBlock &>(), py::arg("block"), |
| "Inserts after the last operation but still inside the block.") |
| .def("__enter__", &PyInsertionPoint::contextEnter) |
| .def("__exit__", &PyInsertionPoint::contextExit) |
| .def_property_readonly_static( |
| "current", |
| [](py::object & /*class*/) { |
| auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); |
| if (!ip) |
| throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); |
| return ip; |
| }, |
| "Gets the InsertionPoint bound to the current thread or raises " |
| "ValueError if none has been set") |
| .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"), |
| "Inserts before a referenced operation.") |
| .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, |
| py::arg("block"), "Inserts at the beginning of the block.") |
| .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, |
| py::arg("block"), "Inserts before the block terminator.") |
| .def("insert", &PyInsertionPoint::insert, py::arg("operation"), |
| "Inserts an operation."); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of PyAttribute. |
| //---------------------------------------------------------------------------- |
| py::class_<PyAttribute>(m, "Attribute") |
| .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, |
| &PyAttribute::getCapsule) |
| .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) |
| .def_static( |
| "parse", |
| [](std::string attrSpec, DefaultingPyMlirContext context) { |
| MlirAttribute type = mlirAttributeParseGet( |
| context->get(), toMlirStringRef(attrSpec)); |
| // TODO: Rework error reporting once diagnostic engine is exposed |
| // in C API. |
| if (mlirAttributeIsNull(type)) { |
| throw SetPyError(PyExc_ValueError, |
| Twine("Unable to parse attribute: '") + |
| attrSpec + "'"); |
| } |
| return PyAttribute(context->getRef(), type); |
| }, |
| py::arg("asm"), py::arg("context") = py::none(), |
| "Parses an attribute from an assembly form") |
| .def_property_readonly( |
| "context", |
| [](PyAttribute &self) { return self.getContext().getObject(); }, |
| "Context that owns the Attribute") |
| .def_property_readonly("type", |
| [](PyAttribute &self) { |
| return PyType(self.getContext()->getRef(), |
| mlirAttributeGetType(self)); |
| }) |
| .def( |
| "get_named", |
| [](PyAttribute &self, std::string name) { |
| return PyNamedAttribute(self, std::move(name)); |
| }, |
| py::keep_alive<0, 1>(), "Binds a name to the attribute") |
| .def("__eq__", |
| [](PyAttribute &self, PyAttribute &other) { return self == other; }) |
| .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) |
| .def( |
| "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, |
| kDumpDocstring) |
| .def( |
| "__str__", |
| [](PyAttribute &self) { |
| PyPrintAccumulator printAccum; |
| mlirAttributePrint(self, printAccum.getCallback(), |
| printAccum.getUserData()); |
| return printAccum.join(); |
| }, |
| "Returns the assembly form of the Attribute.") |
| .def("__repr__", [](PyAttribute &self) { |
| // Generally, assembly formats are not printed for __repr__ because |
| // this can cause exceptionally long debug output and exceptions. |
| // However, attribute values are generally considered useful and are |
| // printed. This may need to be re-evaluated if debug dumps end up |
| // being excessive. |
| PyPrintAccumulator printAccum; |
| printAccum.parts.append("Attribute("); |
| mlirAttributePrint(self, printAccum.getCallback(), |
| printAccum.getUserData()); |
| printAccum.parts.append(")"); |
| return printAccum.join(); |
| }); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of PyNamedAttribute |
| //---------------------------------------------------------------------------- |
| py::class_<PyNamedAttribute>(m, "NamedAttribute") |
| .def("__repr__", |
| [](PyNamedAttribute &self) { |
| PyPrintAccumulator printAccum; |
| printAccum.parts.append("NamedAttribute("); |
| printAccum.parts.append(self.namedAttr.name.data); |
| printAccum.parts.append("="); |
| mlirAttributePrint(self.namedAttr.attribute, |
| printAccum.getCallback(), |
| printAccum.getUserData()); |
| printAccum.parts.append(")"); |
| return printAccum.join(); |
| }) |
| .def_property_readonly( |
| "name", |
| [](PyNamedAttribute &self) { |
| return py::str(self.namedAttr.name.data, |
| self.namedAttr.name.length); |
| }, |
| "The name of the NamedAttribute binding") |
| .def_property_readonly( |
| "attr", |
| [](PyNamedAttribute &self) { |
| // TODO: When named attribute is removed/refactored, also remove |
| // this constructor (it does an inefficient table lookup). |
| auto contextRef = PyMlirContext::forContext( |
| mlirAttributeGetContext(self.namedAttr.attribute)); |
| return PyAttribute(std::move(contextRef), self.namedAttr.attribute); |
| }, |
| py::keep_alive<0, 1>(), |
| "The underlying generic attribute of the NamedAttribute binding"); |
| |
| // Builtin attribute bindings. |
| PyFloatAttribute::bind(m); |
| PyIntegerAttribute::bind(m); |
| PyBoolAttribute::bind(m); |
| PyStringAttribute::bind(m); |
| PyDenseElementsAttribute::bind(m); |
| PyDenseIntElementsAttribute::bind(m); |
| PyDenseFPElementsAttribute::bind(m); |
| PyTypeAttribute::bind(m); |
| PyUnitAttribute::bind(m); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of PyType. |
| //---------------------------------------------------------------------------- |
| py::class_<PyType>(m, "Type") |
| .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) |
| .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) |
| .def_static( |
| "parse", |
| [](std::string typeSpec, DefaultingPyMlirContext context) { |
| MlirType type = |
| mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); |
| // TODO: Rework error reporting once diagnostic engine is exposed |
| // in C API. |
| if (mlirTypeIsNull(type)) { |
| throw SetPyError(PyExc_ValueError, |
| Twine("Unable to parse type: '") + typeSpec + |
| "'"); |
| } |
| return PyType(context->getRef(), type); |
| }, |
| py::arg("asm"), py::arg("context") = py::none(), |
| kContextParseTypeDocstring) |
| .def_property_readonly( |
| "context", [](PyType &self) { return self.getContext().getObject(); }, |
| "Context that owns the Type") |
| .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) |
| .def("__eq__", [](PyType &self, py::object &other) { return false; }) |
| .def( |
| "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) |
| .def( |
| "__str__", |
| [](PyType &self) { |
| PyPrintAccumulator printAccum; |
| mlirTypePrint(self, printAccum.getCallback(), |
| printAccum.getUserData()); |
| return printAccum.join(); |
| }, |
| "Returns the assembly form of the type.") |
| .def("__repr__", [](PyType &self) { |
| // Generally, assembly formats are not printed for __repr__ because |
| // this can cause exceptionally long debug output and exceptions. |
| // However, types are an exception as they typically have compact |
| // assembly forms and printing them is useful. |
| PyPrintAccumulator printAccum; |
| printAccum.parts.append("Type("); |
| mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); |
| printAccum.parts.append(")"); |
| return printAccum.join(); |
| }); |
| |
| // Builtin type bindings. |
| PyIntegerType::bind(m); |
| PyIndexType::bind(m); |
| PyBF16Type::bind(m); |
| PyF16Type::bind(m); |
| PyF32Type::bind(m); |
| PyF64Type::bind(m); |
| PyNoneType::bind(m); |
| PyComplexType::bind(m); |
| PyShapedType::bind(m); |
| PyVectorType::bind(m); |
| PyRankedTensorType::bind(m); |
| PyUnrankedTensorType::bind(m); |
| PyMemRefType::bind(m); |
| PyUnrankedMemRefType::bind(m); |
| PyTupleType::bind(m); |
| PyFunctionType::bind(m); |
| |
| //---------------------------------------------------------------------------- |
| // Mapping of Value. |
| //---------------------------------------------------------------------------- |
| py::class_<PyValue>(m, "Value") |
| .def_property_readonly( |
| "context", |
| [](PyValue &self) { return self.getParentOperation()->getContext(); }, |
| "Context in which the value lives.") |
| .def( |
| "dump", [](PyValue &self) { mlirValueDump(self.get()); }, |
| kDumpDocstring) |
| .def("__eq__", |
| [](PyValue &self, PyValue &other) { |
| return self.get().ptr == other.get().ptr; |
| }) |
| .def("__eq__", [](PyValue &self, py::object other) { return false; }) |
| .def( |
| "__str__", |
| [](PyValue &self) { |
| PyPrintAccumulator printAccum; |
| printAccum.parts.append("Value("); |
| mlirValuePrint(self.get(), printAccum.getCallback(), |
| printAccum.getUserData()); |
| printAccum.parts.append(")"); |
| return printAccum.join(); |
| }, |
| kValueDunderStrDocstring) |
| .def_property_readonly("type", [](PyValue &self) { |
| return PyType(self.getParentOperation()->getContext(), |
| mlirValueGetType(self.get())); |
| }); |
| PyBlockArgument::bind(m); |
| PyOpResult::bind(m); |
| |
| // Container bindings. |
| PyBlockArgumentList::bind(m); |
| PyBlockIterator::bind(m); |
| PyBlockList::bind(m); |
| PyOperationIterator::bind(m); |
| PyOperationList::bind(m); |
| PyOpAttributeMap::bind(m); |
| PyOpOperandList::bind(m); |
| PyOpResultList::bind(m); |
| PyRegionIterator::bind(m); |
| PyRegionList::bind(m); |
| } |