| //===- AsyncRefCounting.cpp - Implementation of Async Ref Counting --------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements automatic reference counting for Async dialect data |
| // types. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "PassDetail.h" |
| #include "mlir/Analysis/Liveness.h" |
| #include "mlir/Dialect/Async/IR/Async.h" |
| #include "mlir/Dialect/Async/Passes.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/PatternMatch.h" |
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| #include "llvm/ADT/SmallSet.h" |
| |
| using namespace mlir; |
| using namespace mlir::async; |
| |
| #define DEBUG_TYPE "async-ref-counting" |
| |
| namespace { |
| |
| class AsyncRefCountingPass : public AsyncRefCountingBase<AsyncRefCountingPass> { |
| public: |
| AsyncRefCountingPass() = default; |
| void runOnFunction() override; |
| |
| private: |
| /// Adds an automatic reference counting to the `value`. |
| /// |
| /// All values are semantically created with a reference count of +1 and it is |
| /// the responsibility of the last async value user to drop reference count. |
| /// |
| /// Async values created when: |
| /// 1. Operation returns async result (e.g. the result of an |
| /// `async.execute`). |
| /// 2. Async value passed in as a block argument. |
| /// |
| /// To implement automatic reference counting, we must insert a +1 reference |
| /// before each `async.execute` operation using the value, and drop it after |
| /// the last use inside the async body region (we currently drop the reference |
| /// before the `async.yield` terminator). |
| /// |
| /// Automatic reference counting algorithm outline: |
| /// |
| /// 1. `ReturnLike` operations forward the reference counted values without |
| /// modifying the reference count. |
| /// |
| /// 2. Use liveness analysis to find blocks in the CFG where the lifetime of |
| /// reference counted values ends, and insert `drop_ref` operations after |
| /// the last use of the value. |
| /// |
| /// 3. Insert `add_ref` before the `async.execute` operation capturing the |
| /// value, and pairing `drop_ref` before the async body region terminator, |
| /// to release the captured reference counted value when execution |
| /// completes. |
| /// |
| /// 4. If the reference counted value is passed only to some of the block |
| /// successors, insert `drop_ref` operations in the beginning of the blocks |
| /// that do not have reference counted value uses. |
| /// |
| /// |
| /// Example: |
| /// |
| /// %token = ... |
| /// async.execute { |
| /// async.await %token : !async.token // await #1 |
| /// async.yield |
| /// } |
| /// async.await %token : !async.token // await #2 |
| /// |
| /// Based on the liveness analysis await #2 is the last use of the %token, |
| /// however the execution of the async region can be delayed, and to guarantee |
| /// that the %token is still alive when await #1 executes we need to |
| /// explicitly extend its lifetime using `add_ref` operation. |
| /// |
| /// After automatic reference counting: |
| /// |
| /// %token = ... |
| /// |
| /// // Make sure that %token is alive inside async.execute. |
| /// async.add_ref %token {count = 1 : i32} : !async.token |
| /// |
| /// async.execute { |
| /// async.await %token : !async.token // await #1 |
| /// |
| /// // Drop the extra reference added to keep %token alive. |
| /// async.drop_ref %token {count = 1 : i32} : !async.token |
| /// |
| /// async.yied |
| /// } |
| /// async.await %token : !async.token // await #2 |
| /// |
| /// // Drop the reference after the last use of %token. |
| /// async.drop_ref %token {count = 1 : i32} : !async.token |
| /// |
| LogicalResult addAutomaticRefCounting(Value value); |
| }; |
| |
| } // namespace |
| |
| LogicalResult AsyncRefCountingPass::addAutomaticRefCounting(Value value) { |
| MLIRContext *ctx = value.getContext(); |
| OpBuilder builder(ctx); |
| |
| // Set inserton point after the operation producing a value, or at the |
| // beginning of the block if the value defined by the block argument. |
| if (Operation *op = value.getDefiningOp()) |
| builder.setInsertionPointAfter(op); |
| else |
| builder.setInsertionPointToStart(value.getParentBlock()); |
| |
| Location loc = value.getLoc(); |
| auto i32 = IntegerType::get(32, ctx); |
| |
| // Drop the reference count immediately if the value has no uses. |
| if (value.getUses().empty()) { |
| builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1)); |
| return success(); |
| } |
| |
| // Use liveness analysis to find the placement of `drop_ref`operation. |
| auto liveness = getAnalysis<Liveness>(); |
| |
| // We analyse only the blocks of the region that defines the `value`, and do |
| // not check nested blocks attached to operations. |
| // |
| // By analyzing only the `definingRegion` CFG we potentially loose an |
| // opportunity to drop the reference count earlier and can extend the lifetime |
| // of reference counted value longer then it is really required. |
| // |
| // We also assume that all nested regions finish their execution before the |
| // completion of the owner operation. The only exception to this rule is |
| // `async.execute` operation, which is handled explicitly below. |
| Region *definingRegion = value.getParentRegion(); |
| |
| // ------------------------------------------------------------------------ // |
| // Find blocks where the `value` dies: the value is in `liveIn` set and not |
| // in the `liveOut` set. We place `drop_ref` immediately after the last use |
| // of the `value` in such regions. |
| // ------------------------------------------------------------------------ // |
| |
| // Last users of the `value` inside all blocks where the value dies. |
| llvm::SmallSet<Operation *, 4> lastUsers; |
| |
| for (Block &block : definingRegion->getBlocks()) { |
| const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); |
| |
| // Value in live input set or was defined in the block. |
| bool liveIn = blockLiveness->isLiveIn(value) || |
| blockLiveness->getBlock() == value.getParentBlock(); |
| if (!liveIn) |
| continue; |
| |
| // Value is in the live out set. |
| bool liveOut = blockLiveness->isLiveOut(value); |
| if (liveOut) |
| continue; |
| |
| // We proved that `value` dies in the `block`. Now find the last use of the |
| // `value` inside the `block`. |
| |
| // Find any user of the `value` inside the block (including uses in nested |
| // regions attached to the operations in the block). |
| Operation *userInTheBlock = nullptr; |
| for (Operation *user : value.getUsers()) { |
| userInTheBlock = block.findAncestorOpInBlock(*user); |
| if (userInTheBlock) |
| break; |
| } |
| |
| // Values with zero users handled explicitly in the beginning, if the value |
| // is in live out set it must have at least one use in the block. |
| assert(userInTheBlock && "value must have a user in the block"); |
| |
| // Find the last user of the `value` in the block; |
| Operation *lastUser = blockLiveness->getEndOperation(value, userInTheBlock); |
| assert(lastUsers.count(lastUser) == 0 && "last users must be unique"); |
| lastUsers.insert(lastUser); |
| } |
| |
| // Process all the last users of the `value` inside each block where the value |
| // dies. |
| for (Operation *lastUser : lastUsers) { |
| // Return like operations forward reference count. |
| if (lastUser->hasTrait<OpTrait::ReturnLike>()) |
| continue; |
| |
| // We can't currently handle other types of terminators. |
| if (lastUser->hasTrait<OpTrait::IsTerminator>()) |
| return lastUser->emitError() << "async reference counting can't handle " |
| "terminators that are not ReturnLike"; |
| |
| // Add a drop_ref immediately after the last user. |
| builder.setInsertionPointAfter(lastUser); |
| builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1)); |
| } |
| |
| // ------------------------------------------------------------------------ // |
| // Find blocks where the `value` is in `liveOut` set, however it is not in |
| // the `liveIn` set of all successors. If the `value` is not in the successor |
| // `liveIn` set, we add a `drop_ref` to the beginning of it. |
| // ------------------------------------------------------------------------ // |
| |
| // Successors that we'll need a `drop_ref` for the `value`. |
| llvm::SmallSet<Block *, 4> dropRefSuccessors; |
| |
| for (Block &block : definingRegion->getBlocks()) { |
| const LivenessBlockInfo *blockLiveness = liveness.getLiveness(&block); |
| |
| // Skip the block if value is not in the `liveOut` set. |
| if (!blockLiveness->isLiveOut(value)) |
| continue; |
| |
| // Find successors that do not have `value` in the `liveIn` set. |
| for (Block *successor : block.getSuccessors()) { |
| const LivenessBlockInfo *succLiveness = liveness.getLiveness(successor); |
| |
| if (!succLiveness->isLiveIn(value)) |
| dropRefSuccessors.insert(successor); |
| } |
| } |
| |
| // Drop reference in all successor blocks that do not have the `value` in |
| // their `liveIn` set. |
| for (Block *dropRefSuccessor : dropRefSuccessors) { |
| builder.setInsertionPointToStart(dropRefSuccessor); |
| builder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1)); |
| } |
| |
| // ------------------------------------------------------------------------ // |
| // Find all `async.execute` operation that take `value` as an operand |
| // (dependency token or async value), or capture implicitly by the nested |
| // region. Each `async.execute` operation will require `add_ref` operation |
| // to keep all captured values alive until it will finish its execution. |
| // ------------------------------------------------------------------------ // |
| |
| llvm::SmallSet<ExecuteOp, 4> executeOperations; |
| |
| auto trackAsyncExecute = [&](Operation *op) { |
| if (auto execute = dyn_cast<ExecuteOp>(op)) |
| executeOperations.insert(execute); |
| }; |
| |
| for (Operation *user : value.getUsers()) { |
| // Follow parent operations up until the operation in the `definingRegion`. |
| while (user->getParentRegion() != definingRegion) { |
| trackAsyncExecute(user); |
| user = user->getParentOp(); |
| assert(user != nullptr && "value user lies outside of the value region"); |
| } |
| |
| // Don't forget to process the parent in the `definingRegion` (can be the |
| // original user operation itself). |
| trackAsyncExecute(user); |
| } |
| |
| // Process all `async.execute` operations capturing `value`. |
| for (ExecuteOp execute : executeOperations) { |
| // Add a reference before the execute operation to keep the reference |
| // counted alive before the async region completes execution. |
| builder.setInsertionPoint(execute.getOperation()); |
| builder.create<AddRefOp>(loc, value, IntegerAttr::get(i32, 1)); |
| |
| // Drop the reference inside the async region before completion. |
| OpBuilder executeBuilder = OpBuilder::atBlockTerminator(execute.getBody()); |
| executeBuilder.create<DropRefOp>(loc, value, IntegerAttr::get(i32, 1)); |
| } |
| |
| return success(); |
| } |
| |
| void AsyncRefCountingPass::runOnFunction() { |
| FuncOp func = getFunction(); |
| |
| // Check that we do not have explicit `add_ref` or `drop_ref` in the IR |
| // because otherwise automatic reference counting will produce incorrect |
| // results. |
| WalkResult refCountingWalk = func.walk([&](Operation *op) -> WalkResult { |
| if (isa<AddRefOp, DropRefOp>(op)) |
| return op->emitError() << "explicit reference counting is not supported"; |
| return WalkResult::advance(); |
| }); |
| |
| if (refCountingWalk.wasInterrupted()) |
| signalPassFailure(); |
| |
| // Add reference counting to block arguments. |
| WalkResult blockWalk = func.walk([&](Block *block) -> WalkResult { |
| for (BlockArgument arg : block->getArguments()) |
| if (isRefCounted(arg.getType())) |
| if (failed(addAutomaticRefCounting(arg))) |
| return WalkResult::interrupt(); |
| |
| return WalkResult::advance(); |
| }); |
| |
| if (blockWalk.wasInterrupted()) |
| signalPassFailure(); |
| |
| // Add reference counting to operation results. |
| WalkResult opWalk = func.walk([&](Operation *op) -> WalkResult { |
| for (unsigned i = 0; i < op->getNumResults(); ++i) |
| if (isRefCounted(op->getResultTypes()[i])) |
| if (failed(addAutomaticRefCounting(op->getResult(i)))) |
| return WalkResult::interrupt(); |
| |
| return WalkResult::advance(); |
| }); |
| |
| if (opWalk.wasInterrupted()) |
| signalPassFailure(); |
| } |
| |
| std::unique_ptr<OperationPass<FuncOp>> mlir::createAsyncRefCountingPass() { |
| return std::make_unique<AsyncRefCountingPass>(); |
| } |