| //===- Async.cpp - MLIR Async Operations ----------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Async/IR/Async.h" |
| |
| #include "mlir/IR/DialectImplementation.h" |
| #include "llvm/ADT/TypeSwitch.h" |
| |
| using namespace mlir; |
| using namespace mlir::async; |
| |
| void AsyncDialect::initialize() { |
| addOperations< |
| #define GET_OP_LIST |
| #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" |
| >(); |
| addTypes<TokenType>(); |
| addTypes<ValueType>(); |
| addTypes<GroupType>(); |
| } |
| |
| /// Parse a type registered to this dialect. |
| Type AsyncDialect::parseType(DialectAsmParser &parser) const { |
| StringRef keyword; |
| if (parser.parseKeyword(&keyword)) |
| return Type(); |
| |
| if (keyword == "token") |
| return TokenType::get(getContext()); |
| |
| if (keyword == "value") { |
| Type ty; |
| if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) { |
| parser.emitError(parser.getNameLoc(), "failed to parse async value type"); |
| return Type(); |
| } |
| return ValueType::get(ty); |
| } |
| |
| parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword; |
| return Type(); |
| } |
| |
| /// Print a type registered to this dialect. |
| void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const { |
| TypeSwitch<Type>(type) |
| .Case<TokenType>([&](TokenType) { os << "token"; }) |
| .Case<ValueType>([&](ValueType valueTy) { |
| os << "value<"; |
| os.printType(valueTy.getValueType()); |
| os << '>'; |
| }) |
| .Case<GroupType>([&](GroupType) { os << "group"; }) |
| .Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); }); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// ValueType |
| //===----------------------------------------------------------------------===// |
| |
| namespace mlir { |
| namespace async { |
| namespace detail { |
| |
| // Storage for `async.value<T>` type, the only member is the wrapped type. |
| struct ValueTypeStorage : public TypeStorage { |
| ValueTypeStorage(Type valueType) : valueType(valueType) {} |
| |
| /// The hash key used for uniquing. |
| using KeyTy = Type; |
| bool operator==(const KeyTy &key) const { return key == valueType; } |
| |
| /// Construction. |
| static ValueTypeStorage *construct(TypeStorageAllocator &allocator, |
| Type valueType) { |
| return new (allocator.allocate<ValueTypeStorage>()) |
| ValueTypeStorage(valueType); |
| } |
| |
| Type valueType; |
| }; |
| |
| } // namespace detail |
| } // namespace async |
| } // namespace mlir |
| |
| ValueType ValueType::get(Type valueType) { |
| return Base::get(valueType.getContext(), valueType); |
| } |
| |
| Type ValueType::getValueType() { return getImpl()->valueType; } |
| |
| //===----------------------------------------------------------------------===// |
| // YieldOp |
| //===----------------------------------------------------------------------===// |
| |
| static LogicalResult verify(YieldOp op) { |
| // Get the underlying value types from async values returned from the |
| // parent `async.execute` operation. |
| auto executeOp = op->getParentOfType<ExecuteOp>(); |
| auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) { |
| return result.getType().cast<ValueType>().getValueType(); |
| }); |
| |
| if (op.getOperandTypes() != types) |
| return op.emitOpError("operand types do not match the types returned from " |
| "the parent ExecuteOp"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// ExecuteOp |
| //===----------------------------------------------------------------------===// |
| |
| constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes"; |
| |
| void ExecuteOp::getNumRegionInvocations( |
| ArrayRef<Attribute> operands, SmallVectorImpl<int64_t> &countPerRegion) { |
| (void)operands; |
| assert(countPerRegion.empty()); |
| countPerRegion.push_back(1); |
| } |
| |
| void ExecuteOp::getSuccessorRegions(Optional<unsigned> index, |
| ArrayRef<Attribute> operands, |
| SmallVectorImpl<RegionSuccessor> ®ions) { |
| // The `body` region branch back to the parent operation. |
| if (index.hasValue()) { |
| assert(*index == 0); |
| regions.push_back(RegionSuccessor(getResults())); |
| return; |
| } |
| |
| // Otherwise the successor is the body region. |
| regions.push_back(RegionSuccessor(&body())); |
| } |
| |
| void ExecuteOp::build(OpBuilder &builder, OperationState &result, |
| TypeRange resultTypes, ValueRange dependencies, |
| ValueRange operands, BodyBuilderFn bodyBuilder) { |
| |
| result.addOperands(dependencies); |
| result.addOperands(operands); |
| |
| // Add derived `operand_segment_sizes` attribute based on parsed operands. |
| int32_t numDependencies = dependencies.size(); |
| int32_t numOperands = operands.size(); |
| auto operandSegmentSizes = DenseIntElementsAttr::get( |
| VectorType::get({2}, IntegerType::get(32, result.getContext())), |
| {numDependencies, numOperands}); |
| result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); |
| |
| // First result is always a token, and then `resultTypes` wrapped into |
| // `async.value`. |
| result.addTypes({TokenType::get(result.getContext())}); |
| for (Type type : resultTypes) |
| result.addTypes(ValueType::get(type)); |
| |
| // Add a body region with block arguments as unwrapped async value operands. |
| Region *bodyRegion = result.addRegion(); |
| bodyRegion->push_back(new Block); |
| Block &bodyBlock = bodyRegion->front(); |
| for (Value operand : operands) { |
| auto valueType = operand.getType().dyn_cast<ValueType>(); |
| bodyBlock.addArgument(valueType ? valueType.getValueType() |
| : operand.getType()); |
| } |
| |
| // Create the default terminator if the builder is not provided and if the |
| // expected result is empty. Otherwise, leave this to the caller |
| // because we don't know which values to return from the execute op. |
| if (resultTypes.empty() && !bodyBuilder) { |
| OpBuilder::InsertionGuard guard(builder); |
| builder.setInsertionPointToStart(&bodyBlock); |
| builder.create<async::YieldOp>(result.location, ValueRange()); |
| } else if (bodyBuilder) { |
| OpBuilder::InsertionGuard guard(builder); |
| builder.setInsertionPointToStart(&bodyBlock); |
| bodyBuilder(builder, result.location, bodyBlock.getArguments()); |
| } |
| } |
| |
| static void print(OpAsmPrinter &p, ExecuteOp op) { |
| p << op.getOperationName(); |
| |
| // [%tokens,...] |
| if (!op.dependencies().empty()) |
| p << " [" << op.dependencies() << "]"; |
| |
| // (%value as %unwrapped: !async.value<!arg.type>, ...) |
| if (!op.operands().empty()) { |
| p << " ("; |
| llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable { |
| p << operand << " as " << op.body().front().getArgument(n++) << ": " |
| << operand.getType(); |
| }); |
| p << ")"; |
| } |
| |
| // -> (!async.value<!return.type>, ...) |
| p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1)); |
| p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr}); |
| p.printRegion(op.body(), /*printEntryBlockArgs=*/false); |
| } |
| |
| static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) { |
| MLIRContext *ctx = result.getContext(); |
| |
| // Sizes of parsed variadic operands, will be updated below after parsing. |
| int32_t numDependencies = 0; |
| int32_t numOperands = 0; |
| |
| auto tokenTy = TokenType::get(ctx); |
| |
| // Parse dependency tokens. |
| if (succeeded(parser.parseOptionalLSquare())) { |
| SmallVector<OpAsmParser::OperandType, 4> tokenArgs; |
| if (parser.parseOperandList(tokenArgs) || |
| parser.resolveOperands(tokenArgs, tokenTy, result.operands) || |
| parser.parseRSquare()) |
| return failure(); |
| |
| numDependencies = tokenArgs.size(); |
| } |
| |
| // Parse async value operands (%value as %unwrapped : !async.value<!type>). |
| SmallVector<OpAsmParser::OperandType, 4> valueArgs; |
| SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs; |
| SmallVector<Type, 4> valueTypes; |
| SmallVector<Type, 4> unwrappedTypes; |
| |
| if (succeeded(parser.parseOptionalLParen())) { |
| auto argsLoc = parser.getCurrentLocation(); |
| |
| // Parse a single instance of `%value as %unwrapped : !async.value<!type>`. |
| auto parseAsyncValueArg = [&]() -> ParseResult { |
| if (parser.parseOperand(valueArgs.emplace_back()) || |
| parser.parseKeyword("as") || |
| parser.parseOperand(unwrappedArgs.emplace_back()) || |
| parser.parseColonType(valueTypes.emplace_back())) |
| return failure(); |
| |
| auto valueTy = valueTypes.back().dyn_cast<ValueType>(); |
| unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type()); |
| |
| return success(); |
| }; |
| |
| // If the next token is `)` skip async value arguments parsing. |
| if (failed(parser.parseOptionalRParen())) { |
| do { |
| if (parseAsyncValueArg()) |
| return failure(); |
| } while (succeeded(parser.parseOptionalComma())); |
| |
| if (parser.parseRParen() || |
| parser.resolveOperands(valueArgs, valueTypes, argsLoc, |
| result.operands)) |
| return failure(); |
| } |
| |
| numOperands = valueArgs.size(); |
| } |
| |
| // Add derived `operand_segment_sizes` attribute based on parsed operands. |
| auto operandSegmentSizes = DenseIntElementsAttr::get( |
| VectorType::get({2}, parser.getBuilder().getI32Type()), |
| {numDependencies, numOperands}); |
| result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes); |
| |
| // Parse the types of results returned from the async execute op. |
| SmallVector<Type, 4> resultTypes; |
| if (parser.parseOptionalArrowTypeList(resultTypes)) |
| return failure(); |
| |
| // Async execute first result is always a completion token. |
| parser.addTypeToList(tokenTy, result.types); |
| parser.addTypesToList(resultTypes, result.types); |
| |
| // Parse operation attributes. |
| NamedAttrList attrs; |
| if (parser.parseOptionalAttrDictWithKeyword(attrs)) |
| return failure(); |
| result.addAttributes(attrs); |
| |
| // Parse asynchronous region. |
| Region *body = result.addRegion(); |
| if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs}, |
| /*argTypes=*/{unwrappedTypes}, |
| /*enableNameShadowing=*/false)) |
| return failure(); |
| |
| return success(); |
| } |
| |
| static LogicalResult verify(ExecuteOp op) { |
| // Unwrap async.execute value operands types. |
| auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) { |
| return operand.getType().cast<ValueType>().getValueType(); |
| }); |
| |
| // Verify that unwrapped argument types matches the body region arguments. |
| if (op.body().getArgumentTypes() != unwrappedTypes) |
| return op.emitOpError("async body region argument types do not match the " |
| "execute operation arguments types"); |
| |
| return success(); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| /// AwaitOp |
| //===----------------------------------------------------------------------===// |
| |
| void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand, |
| ArrayRef<NamedAttribute> attrs) { |
| result.addOperands({operand}); |
| result.attributes.append(attrs.begin(), attrs.end()); |
| |
| // Add unwrapped async.value type to the returned values types. |
| if (auto valueType = operand.getType().dyn_cast<ValueType>()) |
| result.addTypes(valueType.getValueType()); |
| } |
| |
| static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType, |
| Type &resultType) { |
| if (parser.parseType(operandType)) |
| return failure(); |
| |
| // Add unwrapped async.value type to the returned values types. |
| if (auto valueType = operandType.dyn_cast<ValueType>()) |
| resultType = valueType.getValueType(); |
| |
| return success(); |
| } |
| |
| static void printAwaitResultType(OpAsmPrinter &p, Operation *op, |
| Type operandType, Type resultType) { |
| p << operandType; |
| } |
| |
| static LogicalResult verify(AwaitOp op) { |
| Type argType = op.operand().getType(); |
| |
| // Awaiting on a token does not have any results. |
| if (argType.isa<TokenType>() && !op.getResultTypes().empty()) |
| return op.emitOpError("awaiting on a token must have empty result"); |
| |
| // Awaiting on a value unwraps the async value type. |
| if (auto value = argType.dyn_cast<ValueType>()) { |
| if (*op.getResultType() != value.getValueType()) |
| return op.emitOpError() |
| << "result type " << *op.getResultType() |
| << " does not match async value type " << value.getValueType(); |
| } |
| |
| return success(); |
| } |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc" |