| //===- LinalgOps.cpp - Implementation of the linalg operations ------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements the Linalg operations. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.h" |
| |
| #include "mlir/Dialect/Affine/IR/AffineOps.h" |
| #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" |
| #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" |
| #include "mlir/Dialect/StandardOps/IR/Ops.h" |
| #include "mlir/IR/Matchers.h" |
| #include "mlir/IR/OpImplementation.h" |
| #include "mlir/IR/PatternMatch.h" |
| |
| #include "llvm/ADT/DenseMap.h" |
| #include "llvm/ADT/SetVector.h" |
| #include "llvm/ADT/StringSet.h" |
| #include "llvm/Support/FormatVariadic.h" |
| #include "llvm/Support/MathExtras.h" |
| #include "llvm/Support/raw_ostream.h" |
| |
| using namespace mlir; |
| using namespace mlir::linalg; |
| |
| /// Fully compose map with operands and canonicalize the result. |
| /// Return the `createOrFold`'ed AffineApply op. |
| static Value createFoldedComposedAffineApply(OpBuilder &b, Location loc, |
| AffineMap map, |
| ValueRange operandsRef) { |
| SmallVector<Value, 4> operands(operandsRef.begin(), operandsRef.end()); |
| fullyComposeAffineMapAndOperands(&map, &operands); |
| canonicalizeMapAndOperands(&map, &operands); |
| return b.createOrFold<AffineApplyOp>(loc, map, operands); |
| } |
| |
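| /// Illustrative example (not taken from the source): applying |
| ///   affine_map<(d0, d1)[s0] -> (d0 + d1, s0)> |
| /// to values (%a, %b)[%c] produces one value per result expression, i.e. the |
| /// equivalent of `affine.apply affine_map<(d0, d1) -> (d0 + d1)>(%a, %b)` |
| /// followed by `%c` itself (a single-symbol application may fold away |
| /// entirely). |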
| SmallVector<Value, 4> mlir::linalg::applyMapToValues(OpBuilder &b, Location loc, |
| AffineMap map, |
| ValueRange values) { |
| SmallVector<Value, 4> res; |
| res.reserve(map.getNumResults()); |
| unsigned numDims = map.getNumDims(), numSym = map.getNumSymbols(); |
| // For each `expr` in `map`, apply `expr` to `values`. If the resulting |
| // application can be folded into a Value, the folding occurs eagerly. |
| for (auto expr : map.getResults()) { |
| AffineMap subMap = AffineMap::get(numDims, numSym, expr); |
| res.push_back(createFoldedComposedAffineApply(b, loc, subMap, values)); |
| } |
| return res; |
| } |
| |
| SmallVector<Value, 4> LinalgOp::createFlatListOfOperandDims(OpBuilder &b, |
| Location loc) { |
| SmallVector<Value, 4> res; |
| for (Value v : getShapedOperands()) { |
| ShapedType t = v.getType().template cast<ShapedType>(); |
| for (unsigned i = 0, e = t.getRank(); i < e; ++i) |
| res.push_back(b.create<DimOp>(loc, v, i)); |
| } |
| return res; |
| } |
| |
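| /// Illustrative example (assuming a matmul-like op): with operands |
| /// %A : memref<?x?xf32>, %B : memref<?x?xf32>, %C : memref<?x?xf32> and loops |
| /// (m, n, k), the loops-to-shapes map yields the operand dims |
| /// (m, k, k, n, m, n). Each loop dim then gets the range [0, size, 1) taken |
| /// from the first operand dim that mentions it, e.g. m from `dim %A, 0` and |
| /// n from `dim %B, 1`. |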
| SmallVector<Range, 4> LinalgOp::createLoopRanges(OpBuilder &b, Location loc) { |
| AffineMap map = getLoopsToShapesMap(); |
| unsigned numDims = map.getNumDims(), numRes = map.getNumResults(); |
| auto viewSizes = createFlatListOfOperandDims(b, loc); |
| SmallVector<Range, 4> res(numDims); |
| Value zeroVal = b.create<ConstantIndexOp>(loc, 0); |
| Value oneVal = b.create<ConstantIndexOp>(loc, 1); |
| for (unsigned idx = 0; idx < numRes; ++idx) { |
| auto result = map.getResult(idx); |
| if (auto d = result.dyn_cast<AffineDimExpr>()) { |
| if (res[d.getPosition()].offset) |
| continue; |
| res[d.getPosition()] = Range{zeroVal, viewSizes[idx], oneVal}; |
| } |
| } |
| return res; |
| } |
| |
| /// Forward declarations. |
| template <typename NamedStructuredOpType> |
| static void buildNamedStructuredOpRegionAndAttributes( |
| OpBuilder &opBuilder, OperationState &result, TypeRange inputTypes, |
| TypeRange outputBufferTypes, TypeRange initTensorTypes, |
| TypeRange resultTypes); |
| |
| static ParseResult |
| parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, |
| SmallVectorImpl<Type> &inputTypes, |
| SmallVectorImpl<Type> &outputBufferTypes, |
| SmallVectorImpl<Type> &initTensorTypes); |
| |
| template <typename NamedStructuredOpType> |
| static ParseResult |
| parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, |
| TypeRange inputTypes, TypeRange outputBufferTypes, |
| TypeRange initTensorTypes, TypeRange resultTypes); |
| static ParseResult |
| parseNamedStructuredOpResults(OpAsmParser &parser, |
| SmallVectorImpl<Type> &resultTypes); |
| |
| template <typename NamedStructuredOpType> |
| static ParseResult parseNamedStructuredOp(OpAsmParser &parser, |
| OperationState &result); |
| |
| template <typename NamedStructuredOpType> |
| static void printCommonStructuredOpParts(OpAsmPrinter &p, |
| NamedStructuredOpType op); |
| |
| static void printNamedStructuredOpResults(OpAsmPrinter &p, |
| TypeRange resultTypes); |
| |
| template <typename NamedStructuredOpType> |
| static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op); |
| |
| template <typename NamedStructuredOpType> |
| static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op); |
| |
| /// This is a common class used for patterns of the form |
| /// ``` |
| /// someop(memrefcast) -> someop |
| /// ``` |
| /// It folds the source of the memref_cast into the root operation directly. |
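| /// |
| /// For example (illustrative, using the fill op): |
| /// ``` |
| ///   %0 = memref_cast %arg : memref<4xf32> to memref<?xf32> |
| ///   linalg.fill(%0, %cst) : memref<?xf32>, f32 |
| /// ``` |
| /// folds into |
| /// ``` |
| ///   linalg.fill(%arg, %cst) : memref<4xf32>, f32 |
| /// ``` |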
| static LogicalResult foldMemRefCast(Operation *op) { |
| bool folded = false; |
| for (OpOperand &operand : op->getOpOperands()) { |
| auto castOp = operand.get().getDefiningOp<MemRefCastOp>(); |
| if (castOp && canFoldIntoConsumerOp(castOp)) { |
| operand.set(castOp.getOperand()); |
| folded = true; |
| } |
| } |
| return success(folded); |
| } |
| |
| ///////////////////// Operations defined with Tablegen ///////////////////////// |
| // For the operations that do not correspond to library calls (i.e. those |
| // defined in LinalgOps.td), we define an overloaded `print` function and a |
| // `parse<ClassName>` function. |
| |
| //===----------------------------------------------------------------------===// |
| // GenericOps |
| //===----------------------------------------------------------------------===// |
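| /// Illustrative use of the body-builder overloads below (operand names |
| /// assumed); the callback receives one block argument per shaped operand, of |
| /// that operand's element type: |
| /// ``` |
| ///   GenericOp::build(b, state, /*resultTensorTypes=*/{}, {in}, {out}, |
| ///                    /*initTensors=*/{}, maps, iters, |
| ///                    [](OpBuilder &b, Location loc, ValueRange args) { |
| ///                      b.create<linalg::YieldOp>(loc, args.front()); |
| ///                    }); |
| /// ``` |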
| void GenericOp::build( |
| OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, |
| ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors, |
| ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, |
| StringRef doc, StringRef libraryCall, |
| function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { |
| build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors, |
| builder.getAffineMapArrayAttr(indexingMaps), |
| builder.getStrArrayAttr(iteratorTypes), |
| doc.empty() ? StringAttr() : builder.getStringAttr(doc), |
| libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall), |
| ArrayAttr()); |
| if (!bodyBuild) |
| return; |
| |
| SmallVector<Type, 4> blockArgTypes; |
| for (ValueRange container : {inputs, outputBuffers, initTensors}) |
| for (Value v : container) |
| blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType()); |
| |
| OpBuilder::InsertionGuard guard(builder); |
| auto ®ion = *result.regions.front(); |
| Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes); |
| bodyBuild(builder, result.location, bodyBlock->getArguments()); |
| } |
| |
| void GenericOp::build( |
| OpBuilder &builder, OperationState &result, ValueRange inputs, |
| ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps, |
| ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall, |
| function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { |
| build(builder, result, TypeRange{}, inputs, outputBuffers, ValueRange{}, |
| indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild); |
| } |
| |
| void GenericOp::build( |
| OpBuilder &builder, OperationState &result, ValueRange inputs, |
| ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps, |
| ArrayRef<StringRef> iteratorTypes, |
| function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { |
| build(builder, result, inputs, outputBuffers, indexingMaps, iteratorTypes, |
| /*doc=*/"", |
| /*libraryCall=*/"", bodyBuild); |
| } |
| |
| void GenericOp::build( |
| OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, |
| ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors, |
| ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, |
| function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) { |
| build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors, |
| indexingMaps, iteratorTypes, |
| /*doc=*/"", |
| /*libraryCall=*/"", bodyBuild); |
| } |
| void IndexedGenericOp::build( |
| OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, |
| ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors, |
| ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, |
| StringRef doc, StringRef libraryCall, |
| function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> |
| bodyBuild) { |
| build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors, |
| builder.getAffineMapArrayAttr(indexingMaps), |
| builder.getStrArrayAttr(iteratorTypes), |
| doc.empty() ? StringAttr() : builder.getStringAttr(doc), |
| libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall), |
| ArrayAttr()); |
| if (!bodyBuild) |
| return; |
| |
| unsigned nLoops = iteratorTypes.size(); |
| SmallVector<Type, 4> blockArgTypes(nLoops, builder.getIndexType()); |
| for (ValueRange container : {inputs, outputBuffers, initTensors}) |
| for (Value v : container) |
| blockArgTypes.push_back(v.getType().cast<ShapedType>().getElementType()); |
| |
| OpBuilder::InsertionGuard guard(builder); |
| auto ®ion = *result.regions.front(); |
| Block *bodyBlock = builder.createBlock(®ion, region.end(), blockArgTypes); |
| bodyBuild(builder, result.location, |
| bodyBlock->getArguments().take_front(nLoops), |
| bodyBlock->getArguments().drop_front(nLoops)); |
| } |
| |
| void IndexedGenericOp::build( |
| OpBuilder &builder, OperationState &result, ValueRange inputs, |
| ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps, |
| ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall, |
| function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> |
| bodyBuild) { |
| build(builder, result, TypeRange{}, inputs, outputBuffers, ValueRange{}, |
| indexingMaps, iteratorTypes, doc, libraryCall, bodyBuild); |
| } |
| |
| void IndexedGenericOp::build( |
| OpBuilder &builder, OperationState &result, ValueRange inputs, |
| ValueRange outputBuffers, ArrayRef<AffineMap> indexingMaps, |
| ArrayRef<StringRef> iteratorTypes, |
| function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> |
| bodyBuild) { |
| build(builder, result, inputs, outputBuffers, indexingMaps, iteratorTypes, |
| /*doc=*/"", /*libraryCall=*/"", bodyBuild); |
| } |
| |
| void IndexedGenericOp::build( |
| OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes, |
| ValueRange inputs, ValueRange outputBuffers, ValueRange initTensors, |
| ArrayRef<AffineMap> indexingMaps, ArrayRef<StringRef> iteratorTypes, |
| function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> |
| bodyBuild) { |
| build(builder, result, resultTensorTypes, inputs, outputBuffers, initTensors, |
| indexingMaps, iteratorTypes, |
| /*doc=*/"", |
| /*libraryCall=*/"", bodyBuild); |
| } |
| |
| template <typename GenericOpType> |
| static void printGenericOp(OpAsmPrinter &p, GenericOpType op) { |
| p << op.getOperationName() << " "; |
| |
| // Print extra attributes. |
| auto genericAttrNames = op.linalgTraitAttrNames(); |
| |
| llvm::StringSet<> genericAttrNamesSet; |
| genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end()); |
| SmallVector<NamedAttribute, 8> genericAttrs; |
| for (auto attr : op.getAttrs()) |
| if (genericAttrNamesSet.count(attr.first.strref()) > 0) |
| genericAttrs.push_back(attr); |
| if (!genericAttrs.empty()) { |
| auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext()); |
| p << genericDictAttr; |
| } |
| |
| // Printing is shared with named ops, except for the region and attributes |
| printCommonStructuredOpParts(p, op); |
| |
| genericAttrNames.push_back("operand_segment_sizes"); |
| genericAttrNamesSet.insert(genericAttrNames.back()); |
| |
| bool hasExtraAttrs = false; |
| for (NamedAttribute n : op.getAttrs()) { |
| if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.first.strref()))) |
| break; |
| } |
| if (hasExtraAttrs) { |
| p << " attrs = "; |
| p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/genericAttrNames); |
| } |
| |
| // Print region. |
| if (!op.region().empty()) |
| p.printRegion(op.region()); |
| |
| // Print results. |
| printNamedStructuredOpResults(p, op.result_tensors().getTypes()); |
| } |
| |
| static void print(OpAsmPrinter &p, GenericOp op) { printGenericOp(p, op); } |
| |
| static void print(OpAsmPrinter &p, IndexedGenericOp op) { |
| printGenericOp(p, op); |
| } |
| |
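| // Schematic form accepted below (illustrative): a leading trait dictionary, |
| // the common structured-op operand parts, an optional `attrs = {...}` |
| // dictionary, the region, and an optional `-> result-types` suffix. |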
| static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) { |
| DictionaryAttr dictAttr; |
| // Parse the core linalg traits into a dictAttr; the attribute name is |
| // unimportant since result.attributes is overwritten just below. The core |
| // linalg traits must contain the information necessary to pass the |
| // verifier. |
| if (parser.parseAttribute(dictAttr, "_", result.attributes)) |
| return failure(); |
| result.attributes.assign(dictAttr.getValue().begin(), |
| dictAttr.getValue().end()); |
| |
| // Parsing is shared with named ops, except for the region. |
| SmallVector<Type, 1> inputTypes, outputBufferTypes, initTensorTypes; |
| if (parseCommonStructuredOpParts(parser, result, inputTypes, |
| outputBufferTypes, initTensorTypes)) |
| return failure(); |
| |
| // Optional attributes may be added. |
| if (succeeded(parser.parseOptionalKeyword("attrs"))) |
| if (failed(parser.parseEqual()) || |
| failed(parser.parseOptionalAttrDict(result.attributes))) |
| return failure(); |
| |
| SmallVector<OpAsmParser::OperandType, 8> regionOperands; |
| std::unique_ptr<Region> region = std::make_unique<Region>(); |
| SmallVector<Type, 8> operandTypes, regionTypes; |
| if (parser.parseRegion(*region, regionOperands, regionTypes)) |
| return failure(); |
| result.addRegion(std::move(region)); |
| |
| // A generic op may specify that a subset of its outputs are tensors. Such |
| // outputs are specified in the result type. |
| // TODO: may need to move output parsing before region parsing. |
| // Need to wait for declarative assembly resolution to decide. |
| SmallVector<Type, 1> outputTensorsTypes; |
| if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) |
| return failure(); |
| result.addTypes(outputTensorsTypes); |
| |
| return success(); |
| } |
| |
| static void getGenericEffectsImpl( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| &effects, |
| ValueRange results, ValueRange inputBuffers, ValueRange outputBuffers) { |
| for (Value value : results) { |
| effects.emplace_back(MemoryEffects::Allocate::get(), value, |
| SideEffects::DefaultResource::get()); |
| } |
| for (Value value : inputBuffers) { |
| effects.emplace_back(MemoryEffects::Read::get(), value, |
| SideEffects::DefaultResource::get()); |
| } |
| for (Value value : outputBuffers) { |
| effects.emplace_back(MemoryEffects::Read::get(), value, |
| SideEffects::DefaultResource::get()); |
| effects.emplace_back(MemoryEffects::Write::get(), value, |
| SideEffects::DefaultResource::get()); |
| } |
| } |
| |
| void GenericOp::getEffects( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| &effects) { |
| getGenericEffectsImpl(effects, getOperation()->getResults(), |
| getInputBuffers(), getOutputBuffers()); |
| } |
| |
| void IndexedGenericOp::getEffects( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| &effects) { |
| getGenericEffectsImpl(effects, getOperation()->getResults(), |
| getInputBuffers(), getOutputBuffers()); |
| } |
| |
| namespace { |
| |
| template <typename GenericOpType> |
| struct BlockArgsVerifier { |
| static LogicalResult verify(GenericOpType op, Block &block); |
| }; |
| |
| template <typename GenericOpType> |
| LogicalResult BlockArgsVerifier<GenericOpType>::verify(GenericOpType op, |
| Block &block) { |
| auto nOperands = op.getNumOperands(); |
| if (block.getNumArguments() != nOperands) |
| return op.emitOpError("expected number of block arguments to match number " |
| "of operands"); |
| |
| // Note: the number and type of yield values are checked in the YieldOp. |
| auto nInputViews = op.getNumInputs(); |
| for (unsigned i = 0; i < nOperands; ++i) { |
| auto viewType = op.getShapedType(i); |
| if (viewType.getElementType() != block.getArgument(i).getType()) |
| return op.emitOpError("expected block argument ") |
| << (i + 1) << " of the same type as elemental type of " |
| << ((i < nInputViews) ? "input " : "output ") |
| << "operand: " << viewType; |
| } |
| return success(); |
| } |
| |
| template <> |
| LogicalResult BlockArgsVerifier<IndexedGenericOp>::verify(IndexedGenericOp op, |
| Block &block) { |
| auto nInputViews = op.getNumInputs(); |
| auto nLoops = op.getNumLoops(); |
| auto nOperands = op.getNumOperands(); |
| if (block.getNumArguments() != nOperands + nLoops) |
| return op.emitOpError( |
| "expected number of block arguments to match number of operands + " |
| "number of loops"); |
| |
| // Note: the number and type of yield values are checked in the YieldOp. |
| for (unsigned i = 0; i < nLoops; ++i) |
| if (!block.getArgument(i).getType().isIndex()) |
| return op.emitOpError("expected block argument ") |
| << (i + 1) << " to be an index"; |
| |
| for (unsigned i = 0; i < nOperands; ++i) { |
| unsigned memrefArgIndex = i + nLoops; |
| auto viewType = op.getShapedType(i); |
| if (viewType.getElementType() != |
| block.getArgument(memrefArgIndex).getType()) |
| return op.emitOpError("expected block argument ") |
| << (memrefArgIndex + 1) |
| << " of the same type as elemental type of " |
| << ((i < nInputViews) ? "input " : "output ") |
| << "operand: " << viewType; |
| } |
| return success(); |
| } |
| |
| template <typename GenericOpType> |
| struct AnnotationsVerifier { |
| static LogicalResult verify(GenericOpType op) { return success(); } |
| }; |
| |
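| /// Illustrative annotation (attribute shape inferred from the checks below, |
| /// attribute name assumed): `sparse = [["D", "S"], ["D", "D"]]`, i.e. one |
| /// array per tensor and one "D" (dense) or "S" (sparse) entry per dimension |
| /// of that tensor; the output tensor (the last entry) must be all-dense for |
| /// now. |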
| template <> |
| LogicalResult AnnotationsVerifier<GenericOp>::verify(GenericOp op) { |
| ArrayAttr sparseAttr = op.sparseAttr(); |
| if (!sparseAttr) |
| return success(); |
| // Verify consistency of sparse annotations. |
| if (!op.hasTensorSemantics()) |
| return op.emitOpError("expected sparse annotations on tensors only"); |
| if (op.getNumOutputs() != 1) |
| return op.emitOpError("expected single output tensor"); |
| unsigned numTensors = op.getNumInputsAndOutputs(); |
| if (sparseAttr.size() != numTensors) |
| return op.emitOpError("expected one sparse annotation for each tensor"); |
| for (unsigned t = 0; t < numTensors; t++) { |
| auto dimAttr = sparseAttr[t].dyn_cast_or_null<ArrayAttr>(); |
| if (!dimAttr) |
| return op.emitOpError("expected sparse annotation array for tensor ") |
| << t; |
| unsigned rank = op.getShapedType(t).getRank(); |
| if (dimAttr.size() != rank) |
| return op.emitOpError("expected sparse annotation with rank ") |
| << rank << " for tensor " << t; |
| // Per-dimension annotations for each tensor consist of only "D" or "S". |
| for (unsigned d = 0; d < rank; d++) { |
| if (isDenseDim(dimAttr[d])) { |
| continue; |
| } else if (isSparseDim(dimAttr[d])) { |
| if (t == numTensors - 1) |
| return op.emitOpError("sparse output tensors not supported (yet)"); |
| continue; |
| } |
| return op.emitOpError("expected sparse annotation at position ") |
| << d << " for tensor " << t; |
| } |
| } |
| return success(); |
| } |
| |
| } // namespace |
| |
| template <typename GenericOpType> |
| static LogicalResult verifyGenericOp(GenericOpType op) { |
| auto nLoops = op.getNumLoops(); |
| |
| if (op.inputs().size() + op.output_buffers().size() + |
| op.init_tensors().size() + op.getNumResults() == |
| 0) |
| return op.emitOpError("expected at least 1 Shaped operand or return"); |
| |
| auto ®ion = op.region(); |
| if (!llvm::hasSingleElement(region)) |
| return op.emitOpError("expected region with 1 block"); |
| if (failed(BlockArgsVerifier<GenericOpType>::verify(op, region.front()))) |
| return failure(); |
| |
| if (op.indexing_maps().size() != op.getNumInputsAndOutputs()) |
| return op.emitOpError("expected the number of indexing_map (") |
| << op.indexing_maps().size() |
| << ") to be equal to the number of inputs and outputs (" |
| << op.getNumInputsAndOutputs() << ")"; |
| |
| SmallVector<AffineMap, 4> indexingMaps; |
| indexingMaps.reserve(op.indexing_maps().size()); |
| for (auto en : llvm::enumerate(op.indexing_maps())) { |
| auto idx = en.index(); |
| auto m = en.value().template cast<AffineMapAttr>().getValue(); |
| indexingMaps.push_back(m); // Save reference to map for further checks. |
| auto view = op.getShapedType(idx); |
| |
| if (m.getNumSymbols() != 0) |
| return op.emitOpError("unexpected symbols in indexing_map #") << idx; |
| |
| if (m.getNumDims() != nLoops) |
| return op.emitOpError("expected indexing_map #") |
| << idx << " to have " << nLoops |
| << " dim(s) to match the number of loops"; |
| |
| if (m.getNumResults() != view.getRank()) |
| return op.emitOpError("expected indexing_map #") |
| << idx << " results to match view rank: " << view; |
| } |
| |
| if (!op.getShapesToLoopsMap()) |
| return op.emitOpError("expected the shape-to-loops map to be non-null"); |
| |
| if (failed(AnnotationsVerifier<GenericOpType>::verify(op))) |
| return failure(); |
| |
| return success(); |
| } |
| |
| static LogicalResult verify(GenericOp op) { return verifyGenericOp(op); } |
| |
| static LogicalResult verify(IndexedGenericOp op) { return verifyGenericOp(op); } |
| |
| //===----------------------------------------------------------------------===// |
| // ReshapeOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Collapse reassociation maps that are used in pair of reshape ops where one |
| /// is a producer and other is the consumer. Only valid to use this method when |
| /// both the producer and consumer are collapsing dimensions or both are |
| /// expanding dimensions. |
| /// |
| /// For example, |
| /// mapsProducer = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>, |
| /// affine_map<(d0, d1, d2, d3, d4) -> (d2)>, |
| /// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] |
| /// mapsConsumer = [affine_map<(d0, d1, d2) -> (d0, d1)>, |
| /// affine_map<(d0, d1, d2) -> (d2)>] |
| /// |
| /// is folded into |
| /// |
| /// result = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>, |
| /// affine_map<(d0, d1, d2, d3, d4) -> (d3, d4)>] |
| static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer, |
| ArrayRef<AffineMap> mapsConsumer, |
| MLIRContext *context) { |
| // Handle the corner case of the result being a rank 0 shaped type. Return an |
| // empty ArrayAttr. |
| if (mapsConsumer.empty() && !mapsProducer.empty()) |
| return ArrayAttr::get(ArrayRef<Attribute>(), context); |
| if (mapsProducer.empty() || mapsConsumer.empty() || |
| mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() || |
| mapsProducer.size() != mapsConsumer[0].getNumDims()) |
| return nullptr; |
| unsigned numLhsDims = mapsProducer[0].getNumDims(); |
| unsigned currDim = 0; |
| SmallVector<AffineExpr, 4> reassociations; |
| SmallVector<Attribute, 4> reassociationMaps; |
| for (AffineMap rhs : mapsConsumer) { |
| for (AffineExpr rhsExpr : rhs.getResults()) { |
| AffineDimExpr dimExpr = rhsExpr.cast<AffineDimExpr>(); |
| for (int i = 0, e = mapsProducer[dimExpr.getPosition()].getNumResults(); |
| i < e; ++i) { |
| reassociations.push_back(getAffineDimExpr(currDim++, context)); |
| } |
| } |
| reassociationMaps.push_back(AffineMapAttr::get(AffineMap::get( |
| numLhsDims, /*numSymbols =*/0, reassociations, context))); |
| reassociations.clear(); |
| } |
| return ArrayAttr::get(reassociationMaps, context); |
| } |
| |
| namespace { |
| /// Pattern to collapse producer/consumer reshape ops that are both collapsing |
| /// dimensions or are both expanding dimensions. |
| template <typename ReshapeOpTy> |
| struct CollapseReshapeOps : public OpRewritePattern<ReshapeOpTy> { |
| using OpRewritePattern<ReshapeOpTy>::OpRewritePattern; |
| LogicalResult matchAndRewrite(ReshapeOpTy reshapeOp, |
| PatternRewriter &rewriter) const override { |
| auto srcReshapeOp = reshapeOp.src().template getDefiningOp<ReshapeOpTy>(); |
| if (!srcReshapeOp) |
| return failure(); |
| |
| auto areReshapeOpsFoldable = [](ShapedType largerType, |
| ShapedType intermediateType, |
| ShapedType smallerType) -> bool { |
| return largerType.getRank() > intermediateType.getRank() && |
| intermediateType.getRank() > smallerType.getRank(); |
| }; |
| // Check if producer and consumer are both expanding dims. |
| if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(), |
| srcReshapeOp.getSrcType())) { |
| rewriter.replaceOpWithNewOp<ReshapeOpTy>( |
| reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(), |
| collapseReassociationMaps(reshapeOp.getReassociationMaps(), |
| srcReshapeOp.getReassociationMaps(), |
| rewriter.getContext())); |
| return success(); |
| } |
| // Check if producer and consumer are both collapsing dims. |
| if (areReshapeOpsFoldable(srcReshapeOp.getSrcType(), reshapeOp.getSrcType(), |
| reshapeOp.getResultType())) { |
| rewriter.replaceOpWithNewOp<ReshapeOpTy>( |
| reshapeOp, reshapeOp.getResultType(), srcReshapeOp.src(), |
| collapseReassociationMaps(srcReshapeOp.getReassociationMaps(), |
| reshapeOp.getReassociationMaps(), |
| rewriter.getContext())); |
| return success(); |
| } |
| return failure(); |
| } |
| }; |
| } // namespace |
| |
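| /// Illustrative fold (reassociation maps elided, shapes assumed static): |
| /// ``` |
| ///   %1 = linalg.tensor_reshape %0 [...] : tensor<4x2xf32> into tensor<8xf32> |
| ///   %2 = linalg.tensor_reshape %1 [...] : tensor<8xf32> into tensor<4x2xf32> |
| /// ``` |
| /// folds %2 to %0, since %2 has exactly the type of %0. |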
| template <typename ReshapeOpTy> |
| static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, |
| ArrayRef<Attribute> operands) { |
| // Fold producer-consumer reshape ops where the operand type of the producer |
| // is the same as the return type of the consumer. This can only be verified |
| // if the shapes in question are static. |
| ReshapeOpTy reshapeSrcOp = |
| reshapeOp.src().template getDefiningOp<ReshapeOpTy>(); |
| if (reshapeSrcOp && reshapeSrcOp.getSrcType().hasStaticShape() && |
| reshapeOp.getResultType().hasStaticShape() && |
| reshapeSrcOp.getSrcType() == reshapeOp.getResultType()) |
| return reshapeSrcOp.src(); |
| // Reshape of a constant can be replaced with a new constant. |
| if (auto elements = operands.front().dyn_cast_or_null<DenseElementsAttr>()) { |
| return elements.reshape( |
| reshapeOp.getResult().getType().template cast<ShapedType>()); |
| } |
| return nullptr; |
| } |
| |
| /// Return true if the reassociation specification is valid, false otherwise. |
| /// When false, the `invalidIndex` integer pointer is optionally filled with the |
| /// index of the offending reassociation map. |
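| /// |
| /// For example (illustrative), for a rank-4 source the maps |
| ///   [affine_map<(d0, d1, d2, d3) -> (d0, d1)>, |
| ///    affine_map<(d0, d1, d2, d3) -> (d2, d3)>] |
| /// form a valid reassociation: every map has the same number of dims and |
| /// their results cover d0..d3 exactly once, in order. |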
| static bool isReassociationValid(ArrayRef<AffineMap> reassociation, |
| int *invalidIndex = nullptr) { |
| if (reassociation.empty()) |
| return true; |
| unsigned nDims = reassociation[0].getNumDims(); |
| unsigned nextExpectedDim = 0; |
| for (auto it : llvm::enumerate(reassociation)) { |
| auto m = it.value(); |
| if (m.getNumDims() != nDims || m.getNumSymbols() != 0) { |
| if (invalidIndex) |
| *invalidIndex = it.index(); |
| return false; |
| } |
| for (auto e : m.getResults()) { |
| auto d = e.dyn_cast<AffineDimExpr>(); |
| if (!d || d.getPosition() != nextExpectedDim++) { |
| if (invalidIndex) |
| *invalidIndex = it.index(); |
| return false; |
| } |
| } |
| } |
| if (nextExpectedDim != nDims) { |
| if (invalidIndex) |
| *invalidIndex = reassociation.size() - 1; |
| return false; |
| } |
| return true; |
| } |
| |
| /// Detect whether memref dims [dim, dim + extent) can be reshaped without |
| /// copies. |
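| /// |
| /// For example (illustrative), sizes [4, 8, 2] with strides [16, 2, 1] form a |
| /// reshapable band at dim 0 with extent 3, since 16 == 2 * 8 and 2 == 1 * 2. |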
| static bool isReshapableDimBand(unsigned dim, unsigned extent, |
| ArrayRef<int64_t> sizes, |
| ArrayRef<AffineExpr> strides) { |
| assert(sizes.size() == strides.size() && "mismatched ranks"); |
| // Iterate only up to `e - 1` so that the `idx + 1` accesses below stay in |
| // bounds. |
| for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) { |
| // Only bands of static shapes are reshapable. This is due to the fact that |
| // there is no relation between dynamic sizes and dynamic strides: we do not |
| // have enough information to know whether a "-1" size corresponds to the |
| // proper symbol in the AffineExpr of a stride. |
| if (ShapedType::isDynamic(sizes[idx + 1])) |
| return false; |
| // TODO: Refine this by passing the proper nDims and nSymbols so we can |
| // simplify on the fly and catch more reshapable cases. |
| if (strides[idx] != strides[idx + 1] * sizes[idx + 1]) |
| return false; |
| } |
| return true; |
| } |
| |
| /// Compute the MemRefType obtained by applying the `reassociation` (which is |
| /// expected to be valid) to `type`. |
| /// If `type` is a contiguous MemRefType, this always produces a contiguous |
| /// MemRefType. |
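| /// |
| /// For example (illustrative), collapsing a contiguous memref<4x8x2xf32> with |
| /// reassociation [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] produces |
| /// memref<32x2xf32>. |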
| static MemRefType |
| computeReshapeCollapsedType(MemRefType type, |
| ArrayRef<AffineMap> reassociation) { |
| auto sizes = type.getShape(); |
| AffineExpr offset; |
| SmallVector<AffineExpr, 4> strides; |
| auto status = getStridesAndOffset(type, strides, offset); |
| (void)status; |
| assert(succeeded(status) && "expected strided memref"); |
| |
| SmallVector<int64_t, 4> newSizes; |
| newSizes.reserve(reassociation.size()); |
| SmallVector<AffineExpr, 4> newStrides; |
| newStrides.reserve(reassociation.size()); |
| |
| // Use the fact that reassociation is valid to simplify the logic: only use |
| // each map's rank. |
| assert(isReassociationValid(reassociation) && "invalid reassociation"); |
| unsigned currentDim = 0; |
| for (AffineMap m : reassociation) { |
| unsigned dim = m.getNumResults(); |
| int64_t size = 1; |
| AffineExpr stride = strides[currentDim + dim - 1]; |
| if (!isReshapableDimBand(currentDim, dim, sizes, strides)) { |
| size = ShapedType::kDynamicSize; |
| stride = AffineExpr(); |
| } else { |
| for (unsigned d = 0; d < dim; ++d) |
| size *= sizes[currentDim + d]; |
| } |
| newSizes.push_back(size); |
| newStrides.push_back(stride); |
| currentDim += dim; |
| } |
| |
| // Early-exit: if `type` is contiguous, the result must be contiguous. |
| if (canonicalizeStridedLayout(type).getAffineMaps().empty()) |
| return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({}); |
| |
| // Convert back to int64_t because we don't have enough information to create |
| // new strided layouts from AffineExpr only. This corresponds to a case where |
| // copies may be necessary. |
| int64_t intOffset = ShapedType::kDynamicStrideOrOffset; |
| if (auto o = offset.dyn_cast<AffineConstantExpr>()) |
| intOffset = o.getValue(); |
| SmallVector<int64_t, 4> intStrides; |
| intStrides.reserve(strides.size()); |
| for (auto stride : newStrides) { |
| if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>()) |
| intStrides.push_back(cst.getValue()); |
| else |
| intStrides.push_back(ShapedType::kDynamicStrideOrOffset); |
| } |
| auto layout = |
| makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext()); |
| return canonicalizeStridedLayout( |
| MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout})); |
| } |
| |
| /// Helper function that asserts the attributes in `attrs` have the proper |
| /// type and returns the corresponding vector of AffineMaps. |
| /// TODO: this should be evolved into a generic |
| /// `getRangeOfType<AffineMap>(ArrayAttr attrs)` that does not copy. |
| static SmallVector<AffineMap, 4> getAffineMaps(ArrayAttr attrs) { |
| return llvm::to_vector<4>(llvm::map_range( |
| attrs, [](Attribute a) { return a.cast<AffineMapAttr>().getValue(); })); |
| } |
| |
| template <typename AffineExprTy> |
| unsigned getMaxPosOfType(ArrayRef<ReassociationExprs> exprArrays) { |
| unsigned pos = 0; |
| for (const auto &exprs : exprArrays) { |
| for (auto expr : exprs) { |
| expr.walk([&pos](AffineExpr e) { |
| if (auto d = e.dyn_cast<AffineExprTy>()) |
| pos = std::max(pos, d.getPosition()); |
| }); |
| } |
| } |
| return pos; |
| } |
| |
| static SmallVector<AffineMap, 4> |
| getSymbolLessAffineMaps(ArrayRef<ReassociationExprs> reassociation) { |
| unsigned maxDim = getMaxPosOfType<AffineDimExpr>(reassociation); |
| assert(getMaxPosOfType<AffineSymbolExpr>(reassociation) == 0 && |
| "Expected symbol-less expressions"); |
| SmallVector<AffineMap, 4> maps; |
| maps.reserve(reassociation.size()); |
| for (const auto &exprs : reassociation) { |
| assert(!exprs.empty()); |
| maps.push_back(AffineMap::get(maxDim + 1, 0, exprs, exprs[0].getContext())); |
| } |
| return maps; |
| } |
| |
| static SmallVector<SmallVector<AffineExpr, 2>, 2> |
| convertReassociationIndicesToMaps( |
| OpBuilder &b, ArrayRef<ReassociationIndices> reassociationIndices) { |
| SmallVector<SmallVector<AffineExpr, 2>, 2> reassociationMaps; |
| for (const auto &indices : reassociationIndices) { |
| SmallVector<AffineExpr, 2> reassociationMap; |
| reassociationMap.reserve(indices.size()); |
| for (int64_t index : indices) |
| reassociationMap.push_back(b.getAffineDimExpr(index)); |
| reassociationMaps.push_back(std::move(reassociationMap)); |
| } |
| return reassociationMaps; |
| } |
| |
| void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result, |
| Value src, |
| ArrayRef<ReassociationExprs> reassociation, |
| ArrayRef<NamedAttribute> attrs) { |
| auto maps = getSymbolLessAffineMaps(reassociation); |
| auto memRefType = src.getType().cast<MemRefType>(); |
| auto resultType = computeReshapeCollapsedType(memRefType, maps); |
| build(b, result, resultType, src, attrs); |
| result.addAttribute(ReshapeOp::getReassociationAttrName(), |
| b.getAffineMapArrayAttr(maps)); |
| } |
| |
| void mlir::linalg::ReshapeOp::build(OpBuilder &b, OperationState &result, |
| Type resultType, Value src, |
| ArrayRef<ReassociationExprs> reassociation, |
| ArrayRef<NamedAttribute> attrs) { |
| auto maps = getSymbolLessAffineMaps(reassociation); |
| build(b, result, resultType, src, attrs); |
| result.addAttribute(ReshapeOp::getReassociationAttrName(), |
| b.getAffineMapArrayAttr(maps)); |
| } |
| |
| Value mlir::linalg::ReshapeOp::getViewSource() { return src(); } |
| |
| // Common verifier for reshape-like types. Fills `expandedType` and |
| // `collapsedType` with the proper `src` or `result` type. |
| template <typename Op, typename T> |
| static LogicalResult verifyReshapeLikeTypes(Op op, T &expandedType, |
| T &collapsedType) { |
| expandedType = op.getSrcType(); |
| collapsedType = op.getResultType(); |
| unsigned expandedRank = expandedType.getRank(); |
| unsigned collapsedRank = collapsedType.getRank(); |
| bool isCollapse = expandedRank > collapsedRank; |
| if (!isCollapse) { |
| std::swap(expandedRank, collapsedRank); |
| std::swap(expandedType, collapsedType); |
| } |
| if (expandedRank == 0) |
| return op.emitOpError("expected non-zero memref ranks"); |
| if (expandedRank == collapsedRank) |
| return op.emitOpError("expected to collapse or expand dims"); |
| |
| if (collapsedRank == 0) { |
| // If the collapsed rank is 0, then the expanded type must be statically |
| // shaped with all dimension sizes equal to 1. |
| if (llvm::any_of(expandedType.getShape(), |
| [](int64_t dim) -> bool { return dim != 1; })) |
| return op.emitOpError( |
| "invalid to reshape tensor/memref with non-unit extent dimensions to " |
| "zero-rank tensor/memref"); |
| return success(); |
| } |
| if (collapsedRank != op.reassociation().size()) |
| return op.emitOpError("expected rank of the collapsed type(") |
| << collapsedRank << ") to be the number of reassociation maps(" |
| << op.reassociation().size() << ")"; |
| auto maps = getAffineMaps(op.reassociation()); |
| for (auto it : llvm::enumerate(maps)) |
| if (it.value().getNumDims() != expandedRank) |
| return op.emitOpError("expected reassociation map #") |
| << it.index() << " of same rank as expanded memref(" |
| << expandedRank << "), but got " << it.value().getNumDims(); |
| int invalidIdx = 0; |
| if (!isReassociationValid(maps, &invalidIdx)) |
| return op.emitOpError("expected reassociation map #") |
| << invalidIdx << " to be valid and contiguous"; |
| return success(); |
| } |
| |
| static LogicalResult verify(ReshapeOp op) { |
| MemRefType expandedType, collapsedType; |
| if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType))) |
| return failure(); |
| auto maps = getAffineMaps(op.reassociation()); |
| MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps); |
| if (collapsedType != expectedType) |
| return op.emitOpError("expected collapsed type to be ") |
| << expectedType << ", but got " << collapsedType; |
| return success(); |
| } |
| |
| void ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, |
| MLIRContext *context) { |
| results.insert<CollapseReshapeOps<ReshapeOp>>(context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // TensorReshapeOp |
| //===----------------------------------------------------------------------===// |
| |
| /// Compute the RankedTensorType obtained by applying `reassociation` to `type`. |
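| /// |
| /// For example (illustrative), collapsing tensor<?x4x5xf32> with |
| /// reassociation [(d0, d1, d2) -> (d0, d1), (d0, d1, d2) -> (d2)] yields |
| /// tensor<?x5xf32>: a band that contains a dynamic size collapses to a |
| /// dynamic size. |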
| static RankedTensorType |
| computeTensorReshapeCollapsedType(RankedTensorType type, |
| ArrayRef<AffineMap> reassociation) { |
| auto shape = type.getShape(); |
| SmallVector<int64_t, 4> newShape; |
| newShape.reserve(reassociation.size()); |
| |
| // Use the fact that reassociation is valid to simplify the logic: only use |
| // each map's rank. |
| assert(isReassociationValid(reassociation) && "invalid reassociation"); |
| unsigned currentDim = 0; |
| for (AffineMap m : reassociation) { |
| unsigned dim = m.getNumResults(); |
| auto band = shape.slice(currentDim, dim); |
| int64_t size = 1; |
| if (llvm::is_contained(band, ShapedType::kDynamicSize)) |
| size = ShapedType::kDynamicSize; |
| else |
| for (unsigned d = 0; d < dim; ++d) |
| size *= shape[currentDim + d]; |
| newShape.push_back(size); |
| currentDim += dim; |
| } |
| |
| return RankedTensorType::get(newShape, type.getElementType()); |
| } |
| |
| void mlir::linalg::TensorReshapeOp::build( |
| OpBuilder &b, OperationState &result, Value src, |
| ArrayRef<ReassociationExprs> reassociation, |
| ArrayRef<NamedAttribute> attrs) { |
| auto maps = getSymbolLessAffineMaps(reassociation); |
| auto resultType = computeTensorReshapeCollapsedType( |
| src.getType().cast<RankedTensorType>(), maps); |
| build(b, result, resultType, src, attrs); |
| result.addAttribute(TensorReshapeOp::getReassociationAttrName(), |
| b.getAffineMapArrayAttr(maps)); |
| } |
| |
| void mlir::linalg::TensorReshapeOp::build( |
| OpBuilder &b, OperationState &result, Type resultType, Value src, |
| ArrayRef<ReassociationExprs> reassociation, |
| ArrayRef<NamedAttribute> attrs) { |
| auto maps = getSymbolLessAffineMaps(reassociation); |
| build(b, result, resultType, src, attrs); |
| result.addAttribute(TensorReshapeOp::getReassociationAttrName(), |
| b.getAffineMapArrayAttr(maps)); |
| } |
| |
| static LogicalResult verify(TensorReshapeOp op) { |
| RankedTensorType expandedType, collapsedType; |
| if (failed(verifyReshapeLikeTypes(op, expandedType, collapsedType))) |
| return failure(); |
| auto maps = getAffineMaps(op.reassociation()); |
| // TODO: expanding a ? with a non-constant is under-specified. Error |
| // out. |
| RankedTensorType expectedType = |
| computeTensorReshapeCollapsedType(expandedType, maps); |
| if (collapsedType != expectedType) |
| return op.emitOpError("expected collapsed type to be ") |
| << expectedType << ", but got " << collapsedType; |
| return success(); |
| } |
| |
| namespace { |
| /// Reshape of a splat constant can be replaced with a constant of the result |
| /// type. |
| struct FoldReshapeWithConstant : OpRewritePattern<TensorReshapeOp> { |
| using OpRewritePattern<TensorReshapeOp>::OpRewritePattern; |
| LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp, |
| PatternRewriter &rewriter) const override { |
| DenseElementsAttr attr; |
| if (!matchPattern(reshapeOp.src(), m_Constant(&attr))) |
| return failure(); |
| if (!attr || !attr.isSplat()) |
| return failure(); |
| DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( |
| reshapeOp.getResultType(), attr.getRawData(), true); |
| rewriter.replaceOpWithNewOp<ConstantOp>(reshapeOp, newAttr); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| void TensorReshapeOp::getCanonicalizationPatterns( |
| OwningRewritePatternList &results, MLIRContext *context) { |
| results.insert<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant>( |
| context); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // SliceOp |
| //===----------------------------------------------------------------------===// |
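| // Printed form (schematically, following the printer below): |
| //   linalg.slice %base[%i0, %i1, ...] attr-dict |
| //     : base-view-type, indexing-types..., result-type |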
| void mlir::linalg::SliceOp::build(OpBuilder &b, OperationState &result, |
| Value base, ValueRange indexings) { |
| result.addOperands(base); |
| result.addOperands(indexings); |
| |
| auto memRefType = base.getType().cast<MemRefType>(); |
| int64_t offset; |
| SmallVector<int64_t, 4> strides; |
| auto res = getStridesAndOffset(memRefType, strides, offset); |
| assert(succeeded(res) && strides.size() == indexings.size()); |
| (void)res; |
| |
| unsigned rank = memRefType.getRank(); |
| // TODO: propagate static size and stride information when available. |
| SmallVector<int64_t, 4> sizes(rank, -1); // -1 encodes dynamic size. |
| result.addTypes({MemRefType::Builder(memRefType) |
| .setShape(sizes) |
| .setAffineMaps(makeStridedLinearLayoutMap( |
| strides, offset, b.getContext()))}); |
| } |
| |
| static void print(OpAsmPrinter &p, SliceOp op) { |
| auto indexings = op.indexings(); |
| p << SliceOp::getOperationName() << " " << op.view() << "[" << indexings |
| << "] "; |
| p.printOptionalAttrDict(op.getAttrs()); |
| p << " : " << op.getBaseViewType(); |
| if (!indexings.empty()) |
| p << ", " << op.indexings().getTypes(); |
| p << ", " << op.getType(); |
| } |
| |
| static ParseResult parseSliceOp(OpAsmParser &parser, OperationState &result) { |
| OpAsmParser::OperandType baseInfo; |
| SmallVector<OpAsmParser::OperandType, 8> operands; |
| SmallVector<Type, 8> types; |
| if (parser.parseOperand(baseInfo) || |
| parser.parseOperandList(operands, OpAsmParser::Delimiter::Square) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| parser.parseColonTypeList(types)) |
| return failure(); |
| |
| if (types.size() < 2) |
| return parser.emitError(parser.getCurrentLocation(), |
| "expected at least input and result view types"); |
| |
| ArrayRef<Type> indexingTypes = ArrayRef<Type>(types).drop_front().drop_back(); |
| return failure( |
| parser.resolveOperand(baseInfo, types.front(), result.operands) || |
| (!operands.empty() && |
| parser.resolveOperands(operands, indexingTypes, |
| operands.front().location, result.operands)) || |
| parser.addTypeToList(types.back(), result.types)); |
| } |
| |
| static LogicalResult verify(SliceOp op) { |
| unsigned rank = op.getBaseViewRank(); |
| if (rank != llvm::size(op.indexings())) |
| return op.emitOpError("expected ") |
| << rank << " indexings, got " << llvm::size(op.indexings()); |
| unsigned index = 0; |
| for (auto indexing : op.indexings()) { |
| if (indexing.getType().isa<IndexType>()) |
| --rank; |
| ++index; |
| } |
| if (op.getRank() != rank) |
| return op.emitOpError() << "expected rank of the view(" << op.getRank() |
| << ") to be the number of ranges(" << rank << ")"; |
| return success(); |
| } |
| |
| Value SliceOp::getViewSource() { return view(); } |
| |
| //===----------------------------------------------------------------------===// |
| // YieldOp |
| //===----------------------------------------------------------------------===// |
| |
| static void print(OpAsmPrinter &p, linalg::YieldOp op) { |
| p << op.getOperationName(); |
| if (op.getNumOperands() > 0) |
| p << ' ' << op.getOperands(); |
| p.printOptionalAttrDict(op.getAttrs()); |
| if (op.getNumOperands() > 0) |
| p << " : " << op.getOperandTypes(); |
| } |
| |
| static ParseResult parseYieldOp(OpAsmParser &parser, OperationState &result) { |
| SmallVector<OpAsmParser::OperandType, 2> opInfo; |
| SmallVector<Type, 2> types; |
| llvm::SMLoc loc = parser.getCurrentLocation(); |
| return failure(parser.parseOperandList(opInfo) || |
| parser.parseOptionalAttrDict(result.attributes) || |
| (!opInfo.empty() && parser.parseColonTypeList(types)) || |
| parser.resolveOperands(opInfo, types, loc, result.operands)); |
| } |
| |
| // Check that the operand number and types match the element types of the |
| // LinalgOp interface's shaped operands. |
| static LogicalResult verifyYield(linalg::YieldOp op, |
| LinalgOp linalgOpInterface) { |
| auto nOutputs = linalgOpInterface.getNumOutputs(); |
| if (op.getNumOperands() != nOutputs) |
| return op.emitOpError("expected number of yield values (") |
| << op.getNumOperands() << ") to match the number of outputs of the " |
| << "enclosing LinalgOp (" << nOutputs << ")"; |
| |
| for (unsigned i = 0; i != nOutputs; ++i) { |
| auto elementType = |
| linalgOpInterface.getOutputShapedType(i).getElementType(); |
| if (op.getOperand(i).getType() != elementType) |
| return op.emitOpError("type of yield operand ") |
| << (i + 1) << " (" << op.getOperand(i).getType() |
| << ") doesn't match " |
| << "the element type of the enclosing linalg.generic op (" |
| << elementType << ")"; |
| } |
| return success(); |
| } |
| |
| static LogicalResult verify(linalg::YieldOp op) { |
| auto *parentOp = op->getParentOp(); |
| if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty()) |
| return op.emitOpError("expected single non-empty parent region"); |
| |
| if (auto linalgOp = dyn_cast<LinalgOp>(parentOp)) |
| return verifyYield(op, linalgOp); |
| |
| return op.emitOpError("expected parent op with LinalgOp interface"); |
| } |
| |
| /////// Operations corresponding to library calls defined with Tablegen //////// |
| |
| void FillOp::getEffects( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| &effects) { |
| effects.emplace_back(MemoryEffects::Write::get(), output(), |
| SideEffects::DefaultResource::get()); |
| } |
| |
| static LogicalResult verify(FillOp op) { |
| auto viewType = op.getOutputShapedType(0); |
| auto fillType = op.value().getType(); |
| if (viewType.getElementType() != fillType) |
| return op.emitOpError("expects fill type to match view elemental type"); |
| return success(); |
| } |
| |
| void CopyOp::getEffects( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| &effects) { |
| effects.emplace_back(MemoryEffects::Read::get(), input(), |
| SideEffects::DefaultResource::get()); |
| effects.emplace_back(MemoryEffects::Write::get(), output(), |
| SideEffects::DefaultResource::get()); |
| } |
| |
| static LogicalResult verify(CopyOp op) { |
| auto outputViewType = op.getOutputShapedType(0); |
| auto inputViewType = op.getInputShapedType(0); |
| if (inputViewType.getElementType() != outputViewType.getElementType()) |
| return op.emitOpError("expects views of the same type"); |
| if (inputViewType.getRank() != outputViewType.getRank()) |
| return op.emitOpError("expects views of the same rank"); |
| auto rank = op.getNumParallelLoops(); |
| auto inputPermutationMap = op.inputPermutation(); |
| if (inputPermutationMap) { |
| if (inputPermutationMap->getNumInputs() != rank) |
| return op.emitOpError("expects optional input_permutation map of rank ") |
| << rank; |
| if (!inputPermutationMap->isPermutation()) |
| return op.emitOpError( |
| "expects optional input_permutation map to be a permutation"); |
| } |
| auto outputPermutationMap = op.outputPermutation(); |
| if (outputPermutationMap) { |
| if (outputPermutationMap->getNumInputs() != rank) |
| return op.emitOpError("expects optional output_permutation map of rank ") |
| << rank; |
| if (!outputPermutationMap->isPermutation()) |
| return op.emitOpError( |
| "expects optional output_permutation map to be a permutation"); |
| } |
| if (rank == 0 && inputPermutationMap) |
| return op.emitOpError("expected no input permutation when rank == 0"); |
| if (rank == 0 && outputPermutationMap) |
| return op.emitOpError("expected no output permutation when rank == 0"); |
| return success(); |
| } |
| |
| template <typename LinalgPoolingOp> |
| static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op, |
| ArrayRef<Attribute> attrs, |
| bool isStride) { |
| auto strideOrDilation = isStride ? "stride" : "dilation"; |
| if (attrs.size() != op.getNumWindowLoops()) |
| return op.emitOpError("expects num ") |
| << strideOrDilation |
| << "s equal to number of window dimensions: " << attrs.size() |
| << " vs " << op.getNumWindowLoops(); |
| return success(); |
| } |
| |
| void ConvOp::getEffects( |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> |
| &effects) { |
| effects.emplace_back(MemoryEffects::Read::get(), input(), |
| SideEffects::DefaultResource::get()); |
| effects.emplace_back(MemoryEffects::Read::get(), filter(), |
| SideEffects::DefaultResource::get()); |
| effects.emplace_back(MemoryEffects::Write::get(), output(), |
| SideEffects::DefaultResource::get()); |
| } |
| |
| static LogicalResult verify(ConvOp op) { |
| auto oType = op.output().getType().cast<MemRefType>(); |
| auto fType = op.filter().getType().cast<MemRefType>(); |
| auto iType = op.input().getType().cast<MemRefType>(); |
| if (oType.getElementType() != iType.getElementType() || |
| oType.getElementType() != fType.getElementType()) |
| return op.emitOpError("expects memref elemental types to match"); |
| if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank()) |
| return op.emitOpError("expects memref ranks to match"); |
| if (oType.getRank() <= 2) |
| return op.emitOpError("expects memref ranks to be greater than 2"); |
| if (auto strides = op.strides()) { |
| if (failed( |
| verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true))) |
| return failure(); |
| } |
| if (auto dilations = op.dilations()) { |
| if (failed(verifyStrideOrDilation(op, dilations->getValue(), |
| /*isStride=*/false))) |
| return failure(); |
| } |
| return success(); |
| } |
| |
| template <typename PoolingOp> |
| static LogicalResult verifySingleInputPoolingOp(PoolingOp op) { |
| auto inputType = op.input().getType().template cast<MemRefType>(); |
| auto outputType = op.output().getType().template cast<MemRefType>(); |
| if (outputType.getElementType() != inputType.getElementType()) |
| return op.emitOpError("expects memref elemental types to match"); |
| |
| auto windowDimsType = op.windowDims().getType().template cast<MemRefType>(); |
| if (outputType.getRank() != inputType.getRank() || |
| outputType.getRank() != windowDimsType.getRank()) |
| return op.emitOpError("expects memref ranks to match"); |
| |
| if (auto strides = op.strides()) { |
| if (failed( |
| verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true))) |
| return failure(); |
| } |
| if (auto dilations = op.dilations()) { |
| if (failed(verifyStrideOrDilation(op, dilations->getValue(), |
| /*isStride=*/false))) |
| return failure(); |
| } |
| return success(); |
| } |
| |
| #define DEFINE_POOLING_OP_GET_EFFECTS(OP_NAME) \ |
| void OP_NAME::getEffects( \ |
| SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> \ |
| &effects) { \ |
| effects.emplace_back(MemoryEffects::Read::get(), input(), \ |
| SideEffects::DefaultResource::get()); \ |
| effects.emplace_back(MemoryEffects::Write::get(), output(), \ |
| SideEffects::DefaultResource::get()); \ |
| } |
| |
| static LogicalResult verify(PoolingMaxOp op) { |
| return verifySingleInputPoolingOp(op); |
| } |
| static LogicalResult verify(PoolingMinOp op) { |
| return verifySingleInputPoolingOp(op); |
| } |
| static LogicalResult verify(PoolingSumOp op) { |
| return verifySingleInputPoolingOp(op); |
| } |
| |
| DEFINE_POOLING_OP_GET_EFFECTS(PoolingMaxOp) |
| DEFINE_POOLING_OP_GET_EFFECTS(PoolingMinOp) |
| DEFINE_POOLING_OP_GET_EFFECTS(PoolingSumOp) |
| |
| namespace { |
| struct EraseDeadLinalgOp; |
| struct FoldTensorCastOp; |
| } // namespace |
| |
| #include "mlir/Dialect/Linalg/IR/LinalgStructuredOpsInterfaces.cpp.inc" |
| |
| #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.cpp.inc" |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc" |
| |
| #define GET_OP_CLASSES |
| #include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc" |
| |
| /// Return the dims that are `iteratorTypeName` loops in the LinalgOp `op`. |
| /// Assumes `op` is a LinalgOp. |
| void mlir::linalg::getDimsOfType(Operation *op, StringRef iteratorTypeName, |
| SmallVectorImpl<AffineExpr> &res) { |
| if (!cast<LinalgOp>(op).iterator_types()) |
| return; |
| |
| unsigned dim = 0; |
| MLIRContext *ctx = op->getContext(); |
| for (auto tn : |
| cast<LinalgOp>(op).iterator_types().getAsValueRange<StringAttr>()) { |
| if (tn == iteratorTypeName) |
| res.push_back(getAffineDimExpr(dim, ctx)); |
| ++dim; |
| } |
| } |
| |
| AffineMap mlir::linalg::extractOrIdentityMap(Optional<AffineMap> maybeMap, |
| unsigned rank, |
| MLIRContext *context) { |
| if (maybeMap) |
| return maybeMap.getValue(); |
| if (rank == 0) |
| return AffineMap::get(context); |
| return AffineMap::getMultiDimIdentityMap(rank, context); |
| } |
| |
| SmallVector<AffineExpr, 4> |
| mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx, |
| MLIRContext *context) { |
| SmallVector<AffineExpr, 4> res; |
| res.reserve(num); |
| for (unsigned i = 0; i < num; ++i) |
| res.push_back(getAffineDimExpr(startIdx++, context)); |
| return res; |
| } |
| |
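| // For example, with stride 2, dilation 1, and low padding 1 in a given |
| // dimension, output index d_o and window index d_w map to the input index |
| // 2 * d_o + d_w - 1 (see weightedPoolingInputIndex below). |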
| template <typename PoolingOp> |
| SmallVector<AffineExpr, 4> |
| mlir::linalg::weightedPoolingInputIndex(PoolingOp op, |
| ArrayRef<AffineExpr> outputDims, |
| ArrayRef<AffineExpr> windowDims) { |
| assert(outputDims.size() == windowDims.size()); |
| SmallVector<AffineExpr, 4> res; |
| res.reserve(outputDims.size()); |
| for (unsigned i = 0, e = outputDims.size(); i < e; ++i) { |
| // TODO: add a level of indirection to linalg.generic. |
| auto expr = op.getStride(i) * outputDims[i] + |
| op.getDilation(i) * windowDims[i] - op.getLowPad(i); |
| res.push_back(expr); |
| } |
| return res; |
| } |
| |
| #define INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(OP_TYPE) \ |
| template SmallVector<AffineExpr, 4> \ |
| mlir::linalg::weightedPoolingInputIndex<OP_TYPE>( \ |
| OP_TYPE op, ArrayRef<AffineExpr> outputDims, \ |
| ArrayRef<AffineExpr> windowDims); |
| |
| INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(ConvOp) |
| INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMaxOp) |
| INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingMinOp) |
| INSTANTIATE_WEIGHTED_POOLING_INPUT_INDEX(PoolingSumOp) |
| |
| SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a, |
| ArrayRef<AffineExpr> b) { |
| auto rangeA = llvm::make_range(a.begin(), a.end()); |
| auto rangeB = llvm::make_range(b.begin(), b.end()); |
| auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB); |
| return llvm::to_vector<4>(concatRanges); |
| } |
| |
| static void appendMangledType(llvm::raw_string_ostream &ss, Type t) { |
| if (auto memref = t.dyn_cast<MemRefType>()) { |
| ss << "view"; |
| for (auto size : memref.getShape()) |
| if (size < 0) |
| ss << "sx"; |
| else |
| ss << size << "x"; |
| appendMangledType(ss, memref.getElementType()); |
| } else if (auto vec = t.dyn_cast<VectorType>()) { |
| ss << "vector"; |
| llvm::interleave( |
| vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; }); |
| appendMangledType(ss, vec.getElementType()); |
| } else if (t.isSignlessIntOrIndexOrFloat()) { |
| ss << t; |
| } else { |
| llvm_unreachable("Invalid type for linalg library name mangling"); |
| } |
| } |
| |
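| // For example (illustrative), a linalg.copy on two memref<?xf32> operands |
| // mangles to "linalg_copy_viewsxf32_viewsxf32": '.' becomes '_' and each |
| // operand type is appended as "view" + sizes + element type, with "s" |
| // standing in for a dynamic size. |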
| std::string mlir::linalg::generateLibraryCallName(Operation *op) { |
| assert(isa<LinalgOp>(op)); |
| std::string name(op->getName().getStringRef().str()); |
| name.reserve(128); |
| std::replace(name.begin(), name.end(), '.', '_'); |
| llvm::raw_string_ostream ss(name); |
| ss << "_"; |
| auto types = op->getOperandTypes(); |
| llvm::interleave( |
| types.begin(), types.end(), [&](Type t) { appendMangledType(ss, t); }, |
| [&]() { ss << "_"; }); |
| return ss.str(); |
| } |
| |
| // TODO: Consider making all this boilerplate easy to autogenerate |
| // with Tablegen. This seems a desirable property in the context of |
| // OpInterfaces where a Linalg "named" op **isa** LinalgOp. |
| OpFoldResult ReshapeOp::fold(ArrayRef<Attribute> operands) { |
| if (succeeded(foldMemRefCast(*this))) |
| return getResult(); |
| return foldReshapeOp(*this, operands); |
| } |
| OpFoldResult SliceOp::fold(ArrayRef<Attribute>) { |
| if (succeeded(foldMemRefCast(*this))) |
| return getResult(); |
| return {}; |
| } |
| OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) { |
| return foldReshapeOp(*this, operands); |
| } |
| |
| //===----------------------------------------------------------------------===// |
| // Auto-generated Linalg named ops. |
| //===----------------------------------------------------------------------===// |
| |
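/// Fills `region` with a single block whose arguments are the scalar element
/// types of the inputs, output buffers and results (init tensors are not
/// part of the block signature), then populates the block via the op's
/// `regionBuilder`. `errorHandler` is invoked when the block arity does not
/// match what the op expects.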
| template <typename NamedStructuredOpType> |
| static void buildNamedStructuredOpRegionAndAttributesImpl( |
| OpBuilder &opBuilder, Region ®ion, TypeRange inputTypes, |
| TypeRange outputBufferTypes, TypeRange initTensorTypes, |
| TypeRange resultTypes, |
| std::function<void(unsigned, unsigned)> errorHandler) { |
  // TODO: at the moment all operands go through getElementTypeOrSelf;
  // reconsider this once there is evidence it needs to change.
| SmallVector<Type, 8> argTypes; |
| for (auto containers : {inputTypes, outputBufferTypes, resultTypes}) |
| for (auto t : containers) |
| argTypes.push_back(getElementTypeOrSelf(t)); |
| |
| // RAII. |
| OpBuilder::InsertionGuard guard(opBuilder); |
| Block *body = opBuilder.createBlock(®ion, {}, argTypes); |
| unsigned actual = body->getNumArguments(); |
| unsigned expected = NamedStructuredOpType::getNumRegionArgs(); |
| if (expected != actual) |
| return errorHandler(expected, actual); |
| |
| opBuilder.setInsertionPointToStart(body); |
| mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc()); |
| NamedStructuredOpType::regionBuilder(*body); |
| |
  // The indexing_maps and iterator_types attributes are provided by
  // auto-generated methods.
| } |
| |
| template <typename NamedStructuredOpType> |
| void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder, |
| OperationState &result, |
| TypeRange inputTypes, |
| TypeRange outputBufferTypes, |
| TypeRange initTensorTypes, |
| TypeRange resultTypes) { |
| Region ®ion = *result.addRegion(); |
| buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>( |
| opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes, |
| resultTypes, [&](unsigned expected, unsigned actual) { |
        llvm::errs() << "region expects " << expected << " args, got "
                     << actual << "\n";
        llvm_unreachable("incorrect number of arguments");
| }); |
| } |
| |
| template <typename NamedStructuredOpType> |
| static ParseResult |
| parseNamedStructuredOpRegion(OpAsmParser &parser, Region ®ion, |
| TypeRange inputTypes, TypeRange outputBufferTypes, |
| TypeRange initTensorTypes, TypeRange resultTypes) { |
| ParseResult res = success(); |
| OpBuilder opBuilder(parser.getBuilder().getContext()); |
| buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>( |
| opBuilder, region, inputTypes, outputBufferTypes, initTensorTypes, |
| resultTypes, [&](unsigned expected, unsigned actual) { |
| res = parser.emitError(parser.getCurrentLocation(), |
| llvm::formatv("region expects {0} args, got {1}", |
| expected, actual)); |
| }); |
| return res; |
| } |
| |
| static ParseResult |
| parseNamedStructuredOpResults(OpAsmParser &parser, |
| SmallVectorImpl<Type> &resultTypes) { |
| if (succeeded(parser.parseOptionalArrow())) |
| if (parser.parseTypeList(resultTypes)) |
| return failure(); |
| return success(); |
| } |
| |
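/// Parses the attribute dictionary and the `ins`, `outs` and `init` operand
/// clauses shared by structured ops. A minimal example of the expected form
/// (illustrative):
///   linalg.matmul ins(%a, %b : memref<?x?xf32>, memref<?x?xf32>)
///                 outs(%c : memref<?x?xf32>)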
| static ParseResult |
| parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result, |
| SmallVectorImpl<Type> &inputTypes, |
| SmallVectorImpl<Type> &outputBufferTypes, |
| SmallVectorImpl<Type> &initTensorTypes) { |
| llvm::SMLoc inputsOperandsLoc, outputBuffersOperandsLoc, |
| initTensorsOperandsLoc; |
| SmallVector<OpAsmParser::OperandType, 4> inputsOperands, |
| outputBuffersOperands, initTensorsOperands; |
| |
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
| |
| if (succeeded(parser.parseOptionalKeyword("ins"))) { |
| if (parser.parseLParen()) |
| return failure(); |
| |
| inputsOperandsLoc = parser.getCurrentLocation(); |
| if (parser.parseOperandList(inputsOperands) || |
| parser.parseColonTypeList(inputTypes) || parser.parseRParen()) |
| return failure(); |
| } |
| |
| if (succeeded(parser.parseOptionalKeyword("outs"))) { |
| outputBuffersOperandsLoc = parser.getCurrentLocation(); |
| if (parser.parseLParen() || |
| parser.parseOperandList(outputBuffersOperands) || |
| parser.parseColonTypeList(outputBufferTypes) || parser.parseRParen()) |
| return failure(); |
| } |
| if (succeeded(parser.parseOptionalKeyword("init"))) { |
| initTensorsOperandsLoc = parser.getCurrentLocation(); |
| if (parser.parseLParen() || parser.parseOperandList(initTensorsOperands) || |
| parser.parseColonTypeList(initTensorTypes) || parser.parseRParen()) |
| return failure(); |
| } |
| |
| if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, |
| result.operands) || |
| parser.resolveOperands(outputBuffersOperands, outputBufferTypes, |
| outputBuffersOperandsLoc, result.operands) || |
| parser.resolveOperands(initTensorsOperands, initTensorTypes, |
| initTensorsOperandsLoc, result.operands)) |
| return failure(); |
| |
| result.addAttribute("operand_segment_sizes", |
| parser.getBuilder().getI32VectorAttr( |
| {static_cast<int32_t>(inputsOperands.size()), |
| static_cast<int32_t>(outputBuffersOperands.size()), |
| static_cast<int32_t>(initTensorsOperands.size())})); |
| return success(); |
| } |
| |
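/// Parses a full named structured op: the common ins/outs/init clauses, an
/// optional `-> type(s)` result list, and a region that is reconstructed
/// from the op's `regionBuilder` rather than parsed from the input.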
| template <typename NamedStructuredOpType> |
| static ParseResult parseNamedStructuredOp(OpAsmParser &parser, |
| OperationState &result) { |
| SmallVector<Type, 1> inputTypes, outputBufferTypes, initTensorTypes; |
| if (parseCommonStructuredOpParts(parser, result, inputTypes, |
| outputBufferTypes, initTensorTypes)) |
| return failure(); |
| |
| // TODO: consider merging results parsing into region parsing. |
| // Need to wait for declarative assembly resolution to decide. |
| SmallVector<Type, 1> outputTensorsTypes; |
| if (parseNamedStructuredOpResults(parser, outputTensorsTypes)) |
| return failure(); |
| result.addTypes(outputTensorsTypes); |
| |
| std::unique_ptr<Region> region = std::make_unique<Region>(); |
| if (parseNamedStructuredOpRegion<NamedStructuredOpType>( |
| parser, *region, inputTypes, outputBufferTypes, initTensorTypes, |
| outputTensorsTypes)) |
| return failure(); |
| result.addRegion(std::move(region)); |
| |
| return success(); |
| } |
| |
| static void printNamedStructuredOpResults(OpAsmPrinter &p, |
| TypeRange resultTypes) { |
| if (resultTypes.empty()) |
| return; |
| p.printOptionalArrowTypeList(resultTypes); |
| } |
| |
| template <typename NamedStructuredOpType> |
| static void printCommonStructuredOpParts(OpAsmPrinter &p, |
| NamedStructuredOpType op) { |
| if (!op.inputs().empty()) |
| p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")"; |
| if (!op.output_buffers().empty()) |
| p << " outs(" << op.output_buffers() << " : " |
| << op.output_buffers().getTypes() << ")"; |
  if (!op.init_tensors().empty())
    p << " init(" << op.init_tensors() << " : " << op.init_tensors().getTypes()
      << ")";
| } |
| |
| template <typename NamedStructuredOpType> |
| static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) { |
| p << op.getOperationName(); |
| p.printOptionalAttrDict(op.getAttrs(), |
| /*elidedAttrs=*/{"operand_segment_sizes"}); |
| |
| // Printing is shared with generic ops, except for the region and |
| // attributes. |
| printCommonStructuredOpParts(p, op); |
| |
| // Results printing. |
| printNamedStructuredOpResults(p, op.result_tensors().getTypes()); |
| |
| // Region is elided. |
| } |
| |
| template <typename NamedStructuredOpType> |
| static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) { |
| return verifyGenericOp<NamedStructuredOpType>(op); |
| } |
| |
| namespace { |
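// Erases linalg ops that provably perform no work because one of their
// memref operands has a 0 in its shape. For example (illustrative), a
// linalg.copy between two memref<0x8xf32> values iterates zero times and
// can be removed.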
| struct EraseDeadLinalgOp : public RewritePattern { |
| EraseDeadLinalgOp(PatternBenefit benefit = 1) |
| : RewritePattern(benefit, MatchAnyOpTypeTag()) {} |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| auto linalgOp = dyn_cast<LinalgOp>(op); |
| if (!linalgOp) |
| return failure(); |
| for (Value v : linalgOp.getInputsAndOutputBuffers()) { |
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations"; only erase when we see a memref<...x0x...>.
| auto mt = v.getType().dyn_cast<MemRefType>(); |
| if (!mt) |
| continue; |
| if (llvm::is_contained(mt.getShape(), 0)) { |
| rewriter.eraseOp(linalgOp); |
| return success(); |
| } |
| } |
| return failure(); |
| } |
| }; |
| |
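// Folds tensor_cast producers into linalg consumers when the cast only
// erases static shape information. For example (illustrative):
//   %t = tensor_cast %0 : tensor<4x8xf32> to tensor<?x?xf32>
// consumed as a linalg input can be replaced by %0 directly; for init
// tensors the corresponding result type is refined as well.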
| struct FoldTensorCastOp : public RewritePattern { |
| FoldTensorCastOp(PatternBenefit benefit = 1) |
| : RewritePattern(benefit, MatchAnyOpTypeTag()) {} |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| auto linalgOp = dyn_cast<LinalgOp>(op); |
| if (!linalgOp) |
| return failure(); |
| |
    // Fail if no operand is produced by a TensorCastOp that can be folded.
| bool hasTensorCastOperand = |
| llvm::any_of(linalgOp.getShapedOperands(), [&](Value v) { |
| if (v.isa<BlockArgument>()) |
| return false; |
| auto castOp = v.getDefiningOp<TensorCastOp>(); |
| return castOp && canFoldIntoConsumerOp(castOp); |
| }); |
| if (!hasTensorCastOperand) |
| return failure(); |
| |
| SmallVector<Type, 4> newResultTypes; |
| newResultTypes.reserve(op->getNumResults()); |
| SmallVector<Value, 4> newOperands; |
| newOperands.reserve(op->getNumOperands()); |
| // Inputs may fold. |
| for (Value v : linalgOp.getInputs()) { |
| auto tensorCastOp = v.getDefiningOp<TensorCastOp>(); |
| newOperands.push_back( |
| canFoldIntoConsumerOp(tensorCastOp) ? tensorCastOp.source() : v); |
| } |
    // Output buffers are memrefs; they do not fold.
| newOperands.append(linalgOp.getOutputBuffers().begin(), |
| linalgOp.getOutputBuffers().end()); |
| // Init tensors may fold, in which case the resultType must also change. |
| for (Value v : linalgOp.getInitTensors()) { |
| auto tensorCastOp = v.getDefiningOp<TensorCastOp>(); |
| bool fold = canFoldIntoConsumerOp(tensorCastOp); |
| newOperands.push_back(fold ? tensorCastOp.getOperand() : v); |
| newResultTypes.push_back(newOperands.back().getType()); |
| } |
| auto extraOperands = linalgOp.getAssumedNonShapedOperands(); |
| newOperands.append(extraOperands.begin(), extraOperands.end()); |
| // Clone op. |
| Operation *newOp = |
| linalgOp.clone(rewriter, op->getLoc(), newResultTypes, newOperands); |
| rewriter.replaceOp(op, newOp->getResults()); |
| |
| return success(); |
| } |
| }; |
| } // namespace |
| |
| namespace { |
| // Deduplicate redundant args of a linalg op. |
| // An arg is redundant if it has the same Value and indexing map as another. |
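// E.g. (illustrative): a linalg.generic with ins(%0, %0 : ...) and the same
// indexing map on both inputs is rewritten to read %0 once, with uses of the
// second payload block argument redirected to the first.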
| struct DeduplicateInputs : public RewritePattern { |
| DeduplicateInputs(PatternBenefit benefit = 1) |
| : RewritePattern(benefit, MatchAnyOpTypeTag()) {} |
| |
| LogicalResult matchAndRewrite(Operation *op, |
| PatternRewriter &rewriter) const override { |
| // This pattern reduces the number of arguments of an op, which breaks |
| // the invariants of semantically charged named ops. |
| if (!isa<GenericOp, IndexedGenericOp>(op)) |
| return failure(); |
| auto linalgOp = cast<LinalgOp>(op); |
| |
    // Associate each input with an equivalent "canonical" input that has the
    // same Value and indexing map.
| // |
| // In the non-duplicate case, input `i` will have canonical input `i`. But |
| // in the case of duplicated inputs, the canonical input could be some other |
| // input `< i`. That is, a later input will have some earlier input as its |
| // canonical input. |
| llvm::SmallDenseMap<std::pair<Value, AffineMap>, int> canonicalInput; |
| // For later remapping tasks like deduplicating payload block arguments, |
| // having a simple "inputIndex -> canonicalInputIndex" integer mapping is |
| // convenient. |
| SmallVector<int, 6> canonicalInputIndices; |
| for (int i = 0, e = linalgOp.getNumInputs(); i != e; i++) { |
| Value input = linalgOp.getInput(i); |
| AffineMap indexingMap = linalgOp.getInputIndexingMap(i); |
| // STL-like maps have a convenient behavior for our use case here. In the |
| // case of duplicate keys, the insertion is rejected, and the returned |
| // iterator gives access to the value already in the map. |
| auto pair = canonicalInput.insert({{input, indexingMap}, i}); |
| canonicalInputIndices.push_back(pair.first->second); |
| } |
| |
| // If there are no duplicate args, then bail out. |
| if (canonicalInput.size() == linalgOp.getNumInputs()) |
| return failure(); |
| |
| // The operands for the newly canonicalized op. |
| SmallVector<Value, 6> newOperands; |
| for (auto v : llvm::enumerate(linalgOp.getInputs())) |
| if (canonicalInputIndices[v.index()] == static_cast<int>(v.index())) |
| newOperands.push_back(v.value()); |
| llvm::append_range(newOperands, linalgOp.getOutputBuffers()); |
| llvm::append_range(newOperands, linalgOp.getInitTensors()); |
| llvm::append_range(newOperands, linalgOp.getAssumedNonShapedOperands()); |
| |
| // Clone the old op with new operands. |
| Operation *newOp = linalgOp.clone(rewriter, op->getLoc(), |
| op->getResultTypes(), newOperands); |
| auto newLinalgOp = cast<LinalgOp>(newOp); |
| |
| // Repair the indexing maps by filtering out the ones that have been |
| // eliminated. |
| SmallVector<AffineMap, 6> newIndexingMaps; |
| for (int i = 0, e = newLinalgOp.getNumInputs(); i != e; i++) |
| if (canonicalInputIndices[i] == i) |
| newIndexingMaps.push_back(newLinalgOp.getIndexingMap(i)); |
| for (int i = 0, e = newLinalgOp.getNumOutputs(); i != e; i++) |
| newIndexingMaps.push_back(newLinalgOp.getOutputIndexingMap(i)); |
| newOp->setAttr("indexing_maps", |
| rewriter.getAffineMapArrayAttr(newIndexingMaps)); |
| |
| // Set the number of inputs to the new value. The `clone` call above kept |
| // the value from the original op. |
| newLinalgOp.setNumInputs(canonicalInput.size()); |
| |
| // linalg.indexed_generic payloads have additional arguments prepended to |
| // the block arg list. The number of such args is one per dimension of the |
| // iteration space. |
| int bbArgBaseOffset = 0; |
| if (isa<IndexedGenericOp>(op)) |
| bbArgBaseOffset = newIndexingMaps[0].getNumInputs(); |
| |
| // Repair the payload entry block by RAUW'ing redundant arguments and |
| // erasing them. |
| Block &payload = newOp->getRegion(0).front(); |
| for (int i = 0, e = linalgOp.getNumInputs(); i < e; i++) { |
| // Iterate in reverse, so that we erase later args first, preventing the |
| // argument list from shifting unexpectedly and invalidating all our |
| // indices. |
| int reversed = e - i - 1; |
| int canonicalIndex = canonicalInputIndices[reversed]; |
| if (canonicalInputIndices[reversed] == reversed) |
| continue; |
| payload.getArgument(bbArgBaseOffset + reversed) |
| .replaceAllUsesWith( |
| payload.getArgument(bbArgBaseOffset + canonicalIndex)); |
| payload.eraseArgument(bbArgBaseOffset + reversed); |
| } |
| |
| rewriter.replaceOp(op, newOp->getResults()); |
| return success(); |
| } |
| }; |
| } // namespace |
| |
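// For each listed op, registers the three canonicalization patterns above
// and defines a `fold` hook that forwards to `foldMemRefCast`.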
| #define CANONICALIZERS_AND_FOLDERS(XXX) \ |
| void XXX::getCanonicalizationPatterns(OwningRewritePatternList &results, \ |
| MLIRContext *context) { \ |
| results.insert<EraseDeadLinalgOp>(); \ |
| results.insert<FoldTensorCastOp>(); \ |
| results.insert<DeduplicateInputs>(); \ |
| } \ |
| \ |
| LogicalResult XXX::fold(ArrayRef<Attribute>, \ |
| SmallVectorImpl<OpFoldResult> &) { \ |
| return foldMemRefCast(*this); \ |
| } |
| |
| CANONICALIZERS_AND_FOLDERS(ConvOp) |
| CANONICALIZERS_AND_FOLDERS(PoolingMaxOp) |
| CANONICALIZERS_AND_FOLDERS(PoolingMinOp) |
| CANONICALIZERS_AND_FOLDERS(PoolingSumOp) |
| CANONICALIZERS_AND_FOLDERS(CopyOp) |
| CANONICALIZERS_AND_FOLDERS(FillOp) |
| CANONICALIZERS_AND_FOLDERS(GenericOp) |
| CANONICALIZERS_AND_FOLDERS(IndexedGenericOp) |
| |
// Canonicalizers and folders for all named ops are auto-generated in the
// .cpp.inc.