From 20e583834ee8e8dc6e4d0c43d2eb86cfe9428f58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Mon, 23 Oct 2023 12:36:20 -0700 Subject: [PATCH] [Mosaic] apply_vector_layout C++: Be consistent about using "Not implemented" as a prefix for error messages I want to rely on this in the Python bindings to raise `NotImplementedError` exceptions. PiperOrigin-RevId: 575897758 --- .../tpu/transforms/apply_vector_layout.cc | 76 +++++++++++-------- 1 file changed, 44 insertions(+), 32 deletions(-) diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 5685f56cb083..2b179542da5e 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -64,10 +64,6 @@ // TODO(tlongeri): Prefer returning failure over CHECKs. In particular, be more // consistent about this for layout null checks in rules. -#define NYI(msg) \ - op->emitOpError("not implemented: " msg); \ - return failure(); - namespace mlir::tpu { // TODO(tlongeri): Maybe just roll our own multi-dimensional array instead of // using XLA's? There's too much glue for going from/to ArrayRef. @@ -354,7 +350,8 @@ FailureOr appendConstant(RewriteContext &ctx, Block &entry_block = ctx.func.getBody().front(); auto value_ty = cast(value.getType()); if (value_ty.getElementType().getIntOrFloatBitWidth() != 32) { - return ctx.func.emitOpError("Only 32-bit constants supported"); + return ctx.func.emitOpError( + "Not implemented: Only 32-bit constants supported"); } if (ctx.func->getAttr("scratch_operands")) { return ctx.func.emitOpError( @@ -514,7 +511,8 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op, if (!llvm::all_of(layouts_in, [&](const Layout &l) { return l->generalizes(layout_out, out_ty.getShape(), ctx.target_shape); })) { - return op.emitOpError("Incompatible layouts in elementwise operation"); + return op.emitOpError( + "Not implemented: Incompatible layouts in elementwise operation"); } const unsigned num_operands = op.getNumOperands(); SmallVector> in_vreg_arrays; @@ -585,7 +583,8 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); auto result_ty = cast(op.getResult().getType()); if (layout_out.bitwidth() != 32) { - return op.emitOpError("Only extensions to 32-bit supported"); + return op.emitOpError( + "Not implemented: Only extensions to 32-bit supported"); } FAILUREOR_ASSIGN_OR_RETURN(const xla::Array input_vregs, disassemble(ctx, builder, layout_in, op.getIn())); @@ -600,7 +599,8 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, switch (layout_in.implicit_dim()) { case VectorLayout::ImplicitDim::kNone: { if (layout_in.tiling() != layout_out.tiling()) { - return op.emitOpError("Changing tiling during extension"); + return op.emitOpError( + "Not implemented: Changing tiling during the cast"); } auto tiling = layout_in.tiling(); if (ctx.target_shape[0] % tiling[0] != 0 || @@ -648,7 +648,8 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op, auto extf_op = cast(op); if (layouts_in.front()->bitwidth() != 16 || layouts_out.front()->bitwidth() != 32) { - return op.emitOpError("Only 16-bit to 32-bit conversion supported"); + return op.emitOpError( + "Not implemented: Only 16-bit to 32-bit conversion supported"); } return ext_op_rule_impl(ctx, extf_op, *layouts_in.front(), *layouts_out.front()); @@ -677,7 +678,7 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, xla::Array output_vregs( layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape)); if (layout_in.bitwidth() != 32) { - return op.emitOpError("Only 32-bit truncation supported"); + return op.emitOpError("Not implemented: Only 32-bit truncation supported"); } FAILUREOR_ASSIGN_OR_RETURN( VectorType res_vreg_ty, @@ -1257,8 +1258,8 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, if (!layout.hasNaturalTopology(ctx.target_shape) || layout.offsets() != LayoutOffsets{0, 0}) { return op.emitOpError( - "Only native tiling with offset (0, 0) is supported when " - "concatenation along tiling dims."); + "Not implemented: Only native tiling with offset (0, 0) is supported " + "when concatenation along tiling dims."); } // Check if shapes of src and res are aligned to native tiling. auto check_aligned = [&](const VectorType &vty) { @@ -1274,8 +1275,8 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, } if (!is_aligned) { return op.emitOpError( - "Only aligned shapes are supported when concatenation along tiling " - "dims"); + "Not implemented: Only aligned shapes are supported when " + "concatenation along tiling dims"); } } @@ -1582,12 +1583,14 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, AffineMap load_map; arith::ConstantOp padding; if (offsets[1] == std::nullopt) { - return op.emitOpError("Load replicated along lanes is unsupported"); + return op.emitOpError( + "Not implemented: Load replicated along lanes is unsupported"); } if (offsets[0] == std::nullopt) { if (ss != 1) { return op.emitOpError( - "Sublane-replicated load with size > 1 is unsupported"); + "Not implemented: Sublane-replicated load with size > 1 is " + "unsupported"); } if (!layout_out.hasNativeTiling(ctx.target_shape)) { return op.emitOpError("Not implemented"); @@ -1692,7 +1695,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op, getNativeVregType(vty.getElementType(), ctx.target_shape)); if (value.isSplat()) { if (layout_out.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) { - return op.emitOpError("Non-replicated splat constants"); + return op.emitOpError( + "Not implemented: Non-replicated splat constants"); } auto new_value = DenseElementsAttr::get(target_vty, value.getSplatValue()); @@ -1708,7 +1712,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op, } // !value.isSplat() if (getTypeBitwidth(vty.getElementType()) != 32) { - return op.emitOpError("Only 32-bit non-splat constants are supported"); + return op.emitOpError( + "Not implemented: Only 32-bit non-splat constants are supported"); } FAILUREOR_ASSIGN_OR_RETURN(const BlockArgument ref, appendConstant(ctx, value)); @@ -1722,7 +1727,7 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op, {VectorLayout(/*bitwidth=*/32, /*offsets=*/{0, 0}, /*tiling=*/ctx.target_shape)}); } - return op.emitOpError("Unsupported arith.const type: ") + return op.emitOpError("Not implemented: Unsupported arith.const type: ") << op.getResult(0).getType(); } @@ -1959,7 +1964,7 @@ LogicalResult vector_contract_rule(RewriteContext &ctx, Operation &op, if (indexing_maps != matmul_indexing_maps && indexing_maps != matmul_indexing_maps_transposed) { return vector_contract_op->emitOpError( - "Non-matmul or unsupported indexing_maps"); + "Not implemented: Non-matmul or unsupported indexing_maps"); } const bool transpose_rhs = indexing_maps == matmul_indexing_maps_transposed; const ArrayAttr matmul_iterator_types = @@ -2125,7 +2130,8 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, "Not implemented: unsupported kind"); } if (val != neutral.getValueAsDouble()) { - return multi_reduction_op.emitOpError("Only neutral accumulator supported"); + return multi_reduction_op.emitOpError( + "Not implemented: Only neutral accumulator supported"); } if (src_layout.implicit_dim() == VectorLayout::ImplicitDim::kNone && @@ -2133,7 +2139,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, auto [sublane_offset, lane_offset] = src_layout.offsets(); if (dim < 0) { return multi_reduction_op.emitOpError( - "Negative reduction dimension unsupported"); + "Not implemented: Negative reduction dimension unsupported"); } int64_t vdim; Direction reduce_over; @@ -2252,7 +2258,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, multi_reduction_op->erase(); return success(); } - return multi_reduction_op->emitOpError("Unsupported layout: ") << src_layout; + return multi_reduction_op->emitOpError( + "Not implemented: Unsupported layout: ") + << src_layout; } LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op, @@ -2826,7 +2834,7 @@ FailureOr> disassemble(RewriteContext &ctx, val.getLoc(), SmallVector(num_vectors, vreg_ty), val); return XlaArrayFromShapeAndValues(layout_shape, u->getResults()); } - return op->emitOpError("unimplemented: ") << val; + return op->emitOpError("Not implemented: ") << val; } // Assembles a destination tile using partial data from rotated vregs using a @@ -3337,7 +3345,8 @@ FailureOr relayout(RewriteContext &ctx, OpBuilder &builder, Value v, const auto &tiling = src.tiling(); // TODO(apaszke): Changing an offset might add or remove one vreg. if (dst_tiles_shape != src_tiles.dimensions()) { - return emitError(v.getLoc(), "Offsets changing the vreg array shape"); + return emitError( + v.getLoc(), "Not implemented: Offsets changing the vreg array shape"); } xla::Array dst_tiles = src_tiles; @@ -3346,7 +3355,7 @@ FailureOr relayout(RewriteContext &ctx, OpBuilder &builder, Value v, if (!src.offsets()[0].has_value()) { row_diff = 0; } else if (!dst.offsets()[0].has_value()) { - return emitError(v.getLoc(), "Sublane broadcast not implemented"); + return emitError(v.getLoc(), "Not implemented: Sublane broadcast"); } else { row_diff = *dst.offsets()[0] - *src.offsets()[0]; } @@ -3356,7 +3365,8 @@ FailureOr relayout(RewriteContext &ctx, OpBuilder &builder, Value v, const SmallVector implicit_shape = src.implicitShape(vty.getShape()); if (implicit_shape[implicit_shape.size() - 2] != 1) { - return emitError(v.getLoc(), "Row shifts for multi-row values"); + return emitError(v.getLoc(), + "Not implemented: Row shifts for multi-row values"); } const int64_t src_sublane = *src.offsets()[0] / packing; const int64_t dst_sublane = *dst.offsets()[0] / packing; @@ -3454,7 +3464,8 @@ FailureOr relayout(RewriteContext &ctx, OpBuilder &builder, Value v, return assemble(ctx, builder, vty, dst, std::move(dst_tiles)).getResult(); } // TODO(apaszke): Implement general relayout - return emitError(v.getLoc(), "unsupported layout change for ") + return emitError(v.getLoc(), + "Not implemented: Unsupported layout change for ") << vty << ": " << src << " -> " << dst; } @@ -3497,7 +3508,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { auto vty = dyn_cast(operand.getType()); if ((vty == nullptr) == li.has_value()) { return op.emitError( - "layout should be none iff operand is not a vector"); + "Layout should be none iff operand is not a vector"); } if (vty == nullptr) { continue; @@ -3508,7 +3519,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { // arguments. auto op_result = dyn_cast(operand); if (op_result == nullptr) { - return op.emitError("expected operand to be an operation result"); + return op.emitError("Expected operand to be an operation result"); } Operation *const def_op = op_result.getOwner(); CHECK(def_op); @@ -3517,7 +3528,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { getOutLayout(*def_op)); const Layout lo = def_layouts[res_idx]; if (!lo.has_value()) { - return op.emitError() << "vector result should have a defined layout"; + return op.emitError() << "Vector result should have a defined layout"; } if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) { continue; @@ -3553,7 +3564,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { if (OpTrait::hasElementwiseMappableTraits(&op)) { return elementwise_op_rule(ctx, op, layout_in, layout_out); } - return op.emitError("Unsupported operation: ") << op.getName(); + return op.emitError("Not implemented: Unsupported operation: ") + << op.getName(); } LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block) {