Skip to content

Commit

Permalink
[mlir][Arith] Generalize and improve -int-range-optimizations (llvm#9…
Browse files Browse the repository at this point in the history
…4712)

When the integer range analysis was first develop, a pass that did
integer range-based constant folding was developed and used as a test
pass. There was an intent to add such a folding to SCCP, but that hasn't
happened.

Meanwhile, -int-range-optimizations was added to the arith dialect's
transformations. The cmpi simplification in that pass is a strict subset
of the constant folding that lived in
-test-int-range-inference.

This commit moves the former test pass into -int-range-optimizaitons,
subsuming its previous contents. It also adds an optimization from
rocMLIR where `rem{s,u}i` operations that are noops are replaced by
their left operands.
  • Loading branch information
krzysz00 authored Jun 10, 2024
1 parent 3e39328 commit 4722911
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 256 deletions.
9 changes: 7 additions & 2 deletions mlir/include/mlir/Dialect/Arith/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ def ArithIntRangeOpts : Pass<"int-range-optimizations"> {
let summary = "Do optimizations based on integer range analysis";
let description = [{
This pass runs integer range analysis and apllies optimizations based on its
results. e.g. replace arith.cmpi with const if it can be inferred from
args ranges.
results. It replaces operations with known-constant results with said constants,
rewrites `(0 <= %x < D) mod D` to `%x`.
}];
// Explicitly depend on "arith" because this pass could create operations in
// `arith` out of thin air in some cases.
let dependentDialects = [
"::mlir::arith::ArithDialect"
];
}

def ArithEmulateUnsupportedFloats : Pass<"arith-emulate-unsupported-floats"> {
Expand Down
244 changes: 128 additions & 116 deletions mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@

#include <utility>

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"

#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Analysis/DataFlow/IntegerRangeAnalysis.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir::arith {
Expand All @@ -24,88 +30,50 @@ using namespace mlir;
using namespace mlir::arith;
using namespace mlir::dataflow;

/// Returns true if 2 integer ranges have intersection.
static bool intersects(const ConstantIntRanges &lhs,
const ConstantIntRanges &rhs) {
return !((lhs.smax().slt(rhs.smin()) || lhs.smin().sgt(rhs.smax())) &&
(lhs.umax().ult(rhs.umin()) || lhs.umin().ugt(rhs.umax())));
static std::optional<APInt> getMaybeConstantValue(DataFlowSolver &solver,
Value value) {
auto *maybeInferredRange =
solver.lookupState<IntegerValueRangeLattice>(value);
if (!maybeInferredRange || maybeInferredRange->getValue().isUninitialized())
return std::nullopt;
const ConstantIntRanges &inferredRange =
maybeInferredRange->getValue().getValue();
return inferredRange.getConstantValue();
}

static FailureOr<bool> handleEq(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (!intersects(lhs, rhs))
return false;

return failure();
}

static FailureOr<bool> handleNe(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (!intersects(lhs, rhs))
return true;

return failure();
}

static FailureOr<bool> handleSlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.smax().slt(rhs.smin()))
return true;

if (lhs.smin().sge(rhs.smax()))
return false;

return failure();
}

static FailureOr<bool> handleSle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.smax().sle(rhs.smin()))
return true;

if (lhs.smin().sgt(rhs.smax()))
return false;

return failure();
}

static FailureOr<bool> handleSgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleSlt(std::move(rhs), std::move(lhs));
}

static FailureOr<bool> handleSge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleSle(std::move(rhs), std::move(lhs));
}

static FailureOr<bool> handleUlt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.umax().ult(rhs.umin()))
return true;

if (lhs.umin().uge(rhs.umax()))
return false;

return failure();
}

static FailureOr<bool> handleUle(ConstantIntRanges lhs, ConstantIntRanges rhs) {
if (lhs.umax().ule(rhs.umin()))
return true;

if (lhs.umin().ugt(rhs.umax()))
return false;

return failure();
}

static FailureOr<bool> handleUgt(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleUlt(std::move(rhs), std::move(lhs));
}

static FailureOr<bool> handleUge(ConstantIntRanges lhs, ConstantIntRanges rhs) {
return handleUle(std::move(rhs), std::move(lhs));
/// Patterned after SCCP
static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver,
PatternRewriter &rewriter,
Value value) {
if (value.use_empty())
return failure();
std::optional<APInt> maybeConstValue = getMaybeConstantValue(solver, value);
if (!maybeConstValue.has_value())
return failure();

Operation *maybeDefiningOp = value.getDefiningOp();
Dialect *valueDialect =
maybeDefiningOp ? maybeDefiningOp->getDialect()
: value.getParentRegion()->getParentOp()->getDialect();
Attribute constAttr =
rewriter.getIntegerAttr(value.getType(), *maybeConstValue);
Operation *constOp = valueDialect->materializeConstant(
rewriter, constAttr, value.getType(), value.getLoc());
// Fall back to arith.constant if the dialect materializer doesn't know what
// to do with an integer constant.
if (!constOp)
constOp = rewriter.getContext()
->getLoadedDialect<ArithDialect>()
->materializeConstant(rewriter, constAttr, value.getType(),
value.getLoc());
if (!constOp)
return failure();

rewriter.replaceAllUsesWith(value, constOp->getResult(0));
return success();
}

namespace {
/// This class listens on IR transformations performed during a pass relying on
/// information from a `DataflowSolver`. It erases state associated with the
/// erased operation and its results from the `DataFlowSolver` so that Patterns
/// do not accidentally query old state information for newly created Ops.
class DataFlowListener : public RewriterBase::Listener {
public:
DataFlowListener(DataFlowSolver &s) : s(s) {}
Expand All @@ -120,52 +88,95 @@ class DataFlowListener : public RewriterBase::Listener {
DataFlowSolver &s;
};

struct ConvertCmpOp : public OpRewritePattern<arith::CmpIOp> {
/// Rewrite any results of `op` that were inferred to be constant integers to
/// and replace their uses with that constant. Return success() if all results
/// where thus replaced and the operation is erased. Also replace any block
/// arguments with their constant values.
struct MaterializeKnownConstantValues : public RewritePattern {
MaterializeKnownConstantValues(MLIRContext *context, DataFlowSolver &s)
: RewritePattern(Pattern::MatchAnyOpTypeTag(), /*benefit=*/1, context),
solver(s) {}

LogicalResult match(Operation *op) const override {
if (matchPattern(op, m_Constant()))
return failure();

ConvertCmpOp(MLIRContext *context, DataFlowSolver &s)
: OpRewritePattern<arith::CmpIOp>(context), solver(s) {}
auto needsReplacing = [&](Value v) {
return getMaybeConstantValue(solver, v).has_value() && !v.use_empty();
};
bool hasConstantResults = llvm::any_of(op->getResults(), needsReplacing);
if (op->getNumRegions() == 0)
return success(hasConstantResults);
bool hasConstantRegionArgs = false;
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
hasConstantRegionArgs |=
llvm::any_of(block.getArguments(), needsReplacing);
}
}
return success(hasConstantResults || hasConstantRegionArgs);
}

LogicalResult matchAndRewrite(arith::CmpIOp op,
void rewrite(Operation *op, PatternRewriter &rewriter) const override {
bool replacedAll = (op->getNumResults() != 0);
for (Value v : op->getResults())
replacedAll &=
(succeeded(maybeReplaceWithConstant(solver, rewriter, v)) ||
v.use_empty());
if (replacedAll && isOpTriviallyDead(op)) {
rewriter.eraseOp(op);
return;
}

PatternRewriter::InsertionGuard guard(rewriter);
for (Region &region : op->getRegions()) {
for (Block &block : region.getBlocks()) {
rewriter.setInsertionPointToStart(&block);
for (BlockArgument &arg : block.getArguments()) {
(void)maybeReplaceWithConstant(solver, rewriter, arg);
}
}
}
}

private:
DataFlowSolver &solver;
};

template <typename RemOp>
struct DeleteTrivialRem : public OpRewritePattern<RemOp> {
DeleteTrivialRem(MLIRContext *context, DataFlowSolver &s)
: OpRewritePattern<RemOp>(context), solver(s) {}

LogicalResult matchAndRewrite(RemOp op,
PatternRewriter &rewriter) const override {
auto *lhsResult =
solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getLhs());
if (!lhsResult || lhsResult->getValue().isUninitialized())
Value lhs = op.getOperand(0);
Value rhs = op.getOperand(1);
auto maybeModulus = getConstantIntValue(rhs);
if (!maybeModulus.has_value())
return failure();

auto *rhsResult =
solver.lookupState<dataflow::IntegerValueRangeLattice>(op.getRhs());
if (!rhsResult || rhsResult->getValue().isUninitialized())
int64_t modulus = *maybeModulus;
if (modulus <= 0)
return failure();

using HandlerFunc =
FailureOr<bool> (*)(ConstantIntRanges, ConstantIntRanges);
std::array<HandlerFunc, arith::getMaxEnumValForCmpIPredicate() + 1>
handlers{};
using Pred = arith::CmpIPredicate;
handlers[static_cast<size_t>(Pred::eq)] = &handleEq;
handlers[static_cast<size_t>(Pred::ne)] = &handleNe;
handlers[static_cast<size_t>(Pred::slt)] = &handleSlt;
handlers[static_cast<size_t>(Pred::sle)] = &handleSle;
handlers[static_cast<size_t>(Pred::sgt)] = &handleSgt;
handlers[static_cast<size_t>(Pred::sge)] = &handleSge;
handlers[static_cast<size_t>(Pred::ult)] = &handleUlt;
handlers[static_cast<size_t>(Pred::ule)] = &handleUle;
handlers[static_cast<size_t>(Pred::ugt)] = &handleUgt;
handlers[static_cast<size_t>(Pred::uge)] = &handleUge;

HandlerFunc handler = handlers[static_cast<size_t>(op.getPredicate())];
if (!handler)
auto *maybeLhsRange = solver.lookupState<IntegerValueRangeLattice>(lhs);
if (!maybeLhsRange || maybeLhsRange->getValue().isUninitialized())
return failure();

ConstantIntRanges lhsValue = lhsResult->getValue().getValue();
ConstantIntRanges rhsValue = rhsResult->getValue().getValue();
FailureOr<bool> result = handler(lhsValue, rhsValue);

if (failed(result))
const ConstantIntRanges &lhsRange = maybeLhsRange->getValue().getValue();
const APInt &min = isa<RemUIOp>(op) ? lhsRange.umin() : lhsRange.smin();
const APInt &max = isa<RemUIOp>(op) ? lhsRange.umax() : lhsRange.smax();
// The minima and maxima here are given as closed ranges, we must be
// strictly less than the modulus.
if (min.isNegative() || min.uge(modulus))
return failure();
if (max.isNegative() || max.uge(modulus))
return failure();
if (!min.ule(max))
return failure();

rewriter.replaceOpWithNewOp<arith::ConstantIntOp>(
op, static_cast<int64_t>(*result), /*width*/ 1);
// With all those conditions out of the way, we know thas this invocation of
// a remainder is a noop because the input is strictly within the range
// [0, modulus), so get rid of it.
rewriter.replaceOp(op, ValueRange{lhs});
return success();
}

Expand Down Expand Up @@ -201,7 +212,8 @@ struct IntRangeOptimizationsPass

void mlir::arith::populateIntRangeOptimizationsPatterns(
RewritePatternSet &patterns, DataFlowSolver &solver) {
patterns.add<ConvertCmpOp>(patterns.getContext(), solver);
patterns.add<MaterializeKnownConstantValues, DeleteTrivialRem<RemSIOp>,
DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
}

std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Arith/int-range-interface.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s

// CHECK-LABEL: func @add_min_max
// CHECK: %[[c3:.*]] = arith.constant 3 : index
Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Dialect/Arith/int-range-opts.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,39 @@ func.func @test() -> i8 {
return %1: i8
}

// -----

// CHECK-LABEL: func @trivial_rem
// CHECK: [[val:%.+]] = test.with_bounds
// CHECK: return [[val]]
func.func @trivial_rem() -> i8 {
%c64 = arith.constant 64 : i8
%val = test.with_bounds { umin = 0 : ui8, umax = 63 : ui8, smin = 0 : si8, smax = 63 : si8 } : i8
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}

// -----

// CHECK-LABEL: func @non_const_rhs
// CHECK: [[mod:%.+]] = arith.remui
// CHECK: return [[mod]]
func.func @non_const_rhs() -> i8 {
%c64 = arith.constant 64 : i8
%val = test.with_bounds { umin = 0 : ui8, umax = 2 : ui8, smin = 0 : si8, smax = 2 : si8 } : i8
%rhs = test.with_bounds { umin = 63 : ui8, umax = 64 : ui8, smin = 63 : si8, smax = 64 : si8 } : i8
%mod = arith.remui %val, %rhs : i8
return %mod : i8
}

// -----

// CHECK-LABEL: func @wraps
// CHECK: [[mod:%.+]] = arith.remsi
// CHECK: return [[mod]]
func.func @wraps() -> i8 {
%c64 = arith.constant 64 : i8
%val = test.with_bounds { umin = 63 : ui8, umax = 65 : ui8, smin = 63 : si8, smax = 65 : si8 } : i8
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/GPU/int-range-interface.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-int-range-inference -split-input-file %s | FileCheck %s
// RUN: mlir-opt -int-range-optimizations -split-input-file %s | FileCheck %s

// CHECK-LABEL: func @launch_func
func.func @launch_func(%arg0 : index) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/Index/int-range-inference.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-int-range-inference -canonicalize %s | FileCheck %s
// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s

// Most operations are covered by the `arith` tests, which use the same code
// Here, we add a few tests to ensure the "index can be 32- or 64-bit" handling
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: mlir-opt -test-int-range-inference %s | FileCheck %s
// RUN: mlir-opt -int-range-optimizations %s | FileCheck %s

// CHECK-LABEL: func @constant
// CHECK: %[[cst:.*]] = "test.constant"() <{value = 3 : index}
Expand Down Expand Up @@ -103,13 +103,11 @@ func.func @func_args_unbound(%arg0 : index) -> index {

// CHECK-LABEL: func @propagate_across_while_loop_false()
func.func @propagate_across_while_loop_false() -> index {
// CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
// CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
// CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index } : index
%1 = scf.while : () -> index {
%false = arith.constant false
// CHECK: scf.condition(%{{.*}}) %[[C0]]
scf.condition(%false) %0 : index
} do {
^bb0(%i1: index):
Expand All @@ -122,12 +120,10 @@ func.func @propagate_across_while_loop_false() -> index {

// CHECK-LABEL: func @propagate_across_while_loop
func.func @propagate_across_while_loop(%arg0 : i1) -> index {
// CHECK-DAG: %[[C0:.*]] = "test.constant"() <{value = 0
// CHECK-DAG: %[[C1:.*]] = "test.constant"() <{value = 1
// CHECK: %[[C1:.*]] = "test.constant"() <{value = 1
%0 = test.with_bounds { umin = 0 : index, umax = 0 : index,
smin = 0 : index, smax = 0 : index } : index
%1 = scf.while : () -> index {
// CHECK: scf.condition(%{{.*}}) %[[C0]]
scf.condition(%arg0) %0 : index
} do {
^bb0(%i1: index):
Expand Down
Loading

0 comments on commit 4722911

Please sign in to comment.