// Source blob: 43c670a8582e29b97aa33e9ce195e52c7d78416b (repository viewer header, not part of the TableGen source).
include "mlir/Dialect/Shape/IR/ShapeOps.td"
include "mlir/Dialect/StandardOps/IR/Ops.td"
// Constraint on a bound range of values (`$0`): holds when every value in the
// range compares equal to the range's first value.
def AllInputShapesEq : Constraint<CPred<[{
  llvm::all_of($0, [&](mlir::Value candidate) {
    return candidate == $0[0];
  })
}]>>;
// Constraint on a bound range (`$0`): holds when it has exactly one element.
def HasSingleElement : Constraint<CPred<"$0.size() == 1">>;
// Canonicalization patterns.
// A Shape_AssumingAllOp whose operand list has exactly one element is
// replaced by that operand.
def AssumingAllOneOp : Pat<
    (Shape_AssumingAllOp $inputs),
    (replaceWithValue $inputs),
    [(HasSingleElement $inputs)]>;
// A Shape_CstrBroadcastableOp whose two operands are the same value is
// rewritten to a Shape_ConstWitnessOp built with ConstBoolAttrTrue.
def CstrBroadcastableEqOps : Pat<
    (Shape_CstrBroadcastableOp:$op $operand, $operand),
    (Shape_ConstWitnessOp ConstBoolAttrTrue)>;
// A Shape_CstrEqOp whose operands all satisfy AllInputShapesEq is rewritten
// to a Shape_ConstWitnessOp built with ConstBoolAttrTrue.
def CstrEqEqOps : Pat<
    (Shape_CstrEqOp:$op $inputs),
    (Shape_ConstWitnessOp ConstBoolAttrTrue),
    [(AllInputShapesEq $inputs)]>;
// Shape_SizeToIndexOp applied to the result of Shape_IndexToSizeOp is
// replaced by the inner op's operand, dropping the round trip.
def IndexToSizeToIndexCanonicalization : Pat<
    (Shape_SizeToIndexOp (Shape_IndexToSizeOp $index)),
    (replaceWithValue $index)>;
// Shape_IndexToSizeOp applied to the result of Shape_SizeToIndexOp is
// replaced by the inner op's operand, dropping the round trip.
def SizeToIndexToSizeCanonicalization : Pat<
    (Shape_IndexToSizeOp (Shape_SizeToIndexOp $size)),
    (replaceWithValue $size)>;
// A TensorCastOp whose operand is produced by a Shape_ConstShapeOp is
// replaced by the Shape_ConstShapeOp result itself.
// NOTE(review): the replacement reuses the constant's result as-is; confirm
// that its type always matches the cast's result type.
def TensorCastConstShape : Pat<
    (TensorCastOp (Shape_ConstShapeOp:$constShape $arg)),
    (replaceWithValue $constShape)>;