Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 6, 2025
1 parent fa93aab commit ab6f030
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
1 change: 0 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/RemovalUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ mlir::enzyme::CacheInfo::merge(mlir::enzyme::CacheInfo other) {
other.initOp->erase();
}

enzyme::PushOp newPushOp = pushOp;
other.pushOp->erase();

enzyme::PopOp newPopOp;
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Passes/RemovalUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ struct CacheInfo {

Value pushedValue() { return pushOp.getValue(); }
Type cachedType() {
return initOp.getResult().getType().cast<enzyme::CacheType>().getType();
return cast<enzyme::CacheType>(initOp.getResult().getType()).getType();
}

// Pushed values must be the same
Expand Down
8 changes: 7 additions & 1 deletion enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,18 @@ struct RemoveUnusedEnzymeOpsPass

applyPatterns(op);

bool failed = false;
op->walk([&](FunctionOpInterface func) {
func->walk([&](enzyme::EnzymeOpsRemoverOpInterface iface) {
iface.removeEnzymeOps();
auto result = iface.removeEnzymeOps();
if (!result.succeeded())
failed = true;
});
});

if (failed)
return signalPassFailure();

applyPatterns(op);
}
};
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ForwardMode/batched_scalar.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ module {
// CHECK-NEXT: %[[s1:.+]] = "enzyme.broadcast"(%[[arg0]]) <{shape = array<i64: 2>}> : (f64) -> tensor<2xf64>
// CHECK-NEXT: %[[i1:.+]] = arith.mulf %[[arg1]], %[[s1]] : tensor<2xf64>
// CHECK-NEXT: %[[i2:.+]] = arith.addf %[[i0]], %[[i1]] : tensor<2xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]] : tensor<2xf64>
// CHECK-NEXT: %[[i3:.+]] = arith.mulf %[[arg0]], %[[arg0]]
// CHECK-NEXT: return %[[i2]] : tensor<2xf64>
// CHECK-NEXT: }

0 comments on commit ab6f030

Please sign in to comment.