Skip to content

Commit

Permalink
[Mosaic] apply_vector_layout C++: Be consistent about using "Not impl…
Browse files Browse the repository at this point in the history
…emented" as a prefix for error messages

I want to rely on this in the Python bindings to raise `NotImplementedError` exceptions.

PiperOrigin-RevId: 575897758
  • Loading branch information
tlongeri authored and jax authors committed Oct 23, 2023
1 parent 9bc0439 commit 20e5838
Showing 1 changed file with 44 additions and 32 deletions.
76 changes: 44 additions & 32 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@
// TODO(tlongeri): Prefer returning failure over CHECKs. In particular, be more
// consistent about this for layout null checks in rules.

#define NYI(msg) \
op->emitOpError("not implemented: " msg); \
return failure();

namespace mlir::tpu {
// TODO(tlongeri): Maybe just roll our own multi-dimensional array instead of
// using XLA's? There's too much glue for going from/to ArrayRef.
Expand Down Expand Up @@ -354,7 +350,8 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx,
Block &entry_block = ctx.func.getBody().front();
auto value_ty = cast<VectorType>(value.getType());
if (value_ty.getElementType().getIntOrFloatBitWidth() != 32) {
return ctx.func.emitOpError("Only 32-bit constants supported");
return ctx.func.emitOpError(
"Not implemented: Only 32-bit constants supported");
}
if (ctx.func->getAttr("scratch_operands")) {
return ctx.func.emitOpError(
Expand Down Expand Up @@ -514,7 +511,8 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
if (!llvm::all_of(layouts_in, [&](const Layout &l) {
return l->generalizes(layout_out, out_ty.getShape(), ctx.target_shape);
})) {
return op.emitOpError("Incompatible layouts in elementwise operation");
return op.emitOpError(
"Not implemented: Incompatible layouts in elementwise operation");
}
const unsigned num_operands = op.getNumOperands();
SmallVector<xla::Array<Value>> in_vreg_arrays;
Expand Down Expand Up @@ -585,7 +583,8 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation());
auto result_ty = cast<VectorType>(op.getResult().getType());
if (layout_out.bitwidth() != 32) {
return op.emitOpError("Only extensions to 32-bit supported");
return op.emitOpError(
"Not implemented: Only extensions to 32-bit supported");
}
FAILUREOR_ASSIGN_OR_RETURN(const xla::Array<Value> input_vregs,
disassemble(ctx, builder, layout_in, op.getIn()));
Expand All @@ -600,7 +599,8 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op,
switch (layout_in.implicit_dim()) {
case VectorLayout::ImplicitDim::kNone: {
if (layout_in.tiling() != layout_out.tiling()) {
return op.emitOpError("Changing tiling during extension");
return op.emitOpError(
"Not implemented: Changing tiling during the cast");
}
auto tiling = layout_in.tiling();
if (ctx.target_shape[0] % tiling[0] != 0 ||
Expand Down Expand Up @@ -648,7 +648,8 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op,
auto extf_op = cast<arith::ExtFOp>(op);
if (layouts_in.front()->bitwidth() != 16 ||
layouts_out.front()->bitwidth() != 32) {
return op.emitOpError("Only 16-bit to 32-bit conversion supported");
return op.emitOpError(
"Not implemented: Only 16-bit to 32-bit conversion supported");
}
return ext_op_rule_impl(ctx, extf_op, *layouts_in.front(),
*layouts_out.front());
Expand Down Expand Up @@ -677,7 +678,7 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
xla::Array<Value> output_vregs(
layout_out.tileArrayShape(result_ty.getShape(), ctx.target_shape));
if (layout_in.bitwidth() != 32) {
return op.emitOpError("Only 32-bit truncation supported");
return op.emitOpError("Not implemented: Only 32-bit truncation supported");
}
FAILUREOR_ASSIGN_OR_RETURN(
VectorType res_vreg_ty,
Expand Down Expand Up @@ -1257,8 +1258,8 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
if (!layout.hasNaturalTopology(ctx.target_shape) ||
layout.offsets() != LayoutOffsets{0, 0}) {
return op.emitOpError(
"Only native tiling with offset (0, 0) is supported when "
"concatenation along tiling dims.");
"Not implemented: Only native tiling with offset (0, 0) is supported "
"when concatenation along tiling dims.");
}
// Check if shapes of src and res are aligned to native tiling.
auto check_aligned = [&](const VectorType &vty) {
Expand All @@ -1274,8 +1275,8 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op,
}
if (!is_aligned) {
return op.emitOpError(
"Only aligned shapes are supported when concatenation along tiling "
"dims");
"Not implemented: Only aligned shapes are supported when "
"concatenation along tiling dims");
}
}

Expand Down Expand Up @@ -1582,12 +1583,14 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
AffineMap load_map;
arith::ConstantOp padding;
if (offsets[1] == std::nullopt) {
return op.emitOpError("Load replicated along lanes is unsupported");
return op.emitOpError(
"Not implemented: Load replicated along lanes is unsupported");
}
if (offsets[0] == std::nullopt) {
if (ss != 1) {
return op.emitOpError(
"Sublane-replicated load with size > 1 is unsupported");
"Not implemented: Sublane-replicated load with size > 1 is "
"unsupported");
}
if (!layout_out.hasNativeTiling(ctx.target_shape)) {
return op.emitOpError("Not implemented");
Expand Down Expand Up @@ -1692,7 +1695,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
getNativeVregType(vty.getElementType(), ctx.target_shape));
if (value.isSplat()) {
if (layout_out.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) {
return op.emitOpError("Non-replicated splat constants");
return op.emitOpError(
"Not implemented: Non-replicated splat constants");
}
auto new_value =
DenseElementsAttr::get(target_vty, value.getSplatValue<Attribute>());
Expand All @@ -1708,7 +1712,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
}
// !value.isSplat()
if (getTypeBitwidth<true>(vty.getElementType()) != 32) {
return op.emitOpError("Only 32-bit non-splat constants are supported");
return op.emitOpError(
"Not implemented: Only 32-bit non-splat constants are supported");
}
FAILUREOR_ASSIGN_OR_RETURN(const BlockArgument ref,
appendConstant(ctx, value));
Expand All @@ -1722,7 +1727,7 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op,
{VectorLayout(/*bitwidth=*/32, /*offsets=*/{0, 0},
/*tiling=*/ctx.target_shape)});
}
return op.emitOpError("Unsupported arith.const type: ")
return op.emitOpError("Not implemented: Unsupported arith.const type: ")
<< op.getResult(0).getType();
}

Expand Down Expand Up @@ -1959,7 +1964,7 @@ LogicalResult vector_contract_rule(RewriteContext &ctx, Operation &op,
if (indexing_maps != matmul_indexing_maps &&
indexing_maps != matmul_indexing_maps_transposed) {
return vector_contract_op->emitOpError(
"Non-matmul or unsupported indexing_maps");
"Not implemented: Non-matmul or unsupported indexing_maps");
}
const bool transpose_rhs = indexing_maps == matmul_indexing_maps_transposed;
const ArrayAttr matmul_iterator_types =
Expand Down Expand Up @@ -2125,15 +2130,16 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
"Not implemented: unsupported kind");
}
if (val != neutral.getValueAsDouble()) {
return multi_reduction_op.emitOpError("Only neutral accumulator supported");
return multi_reduction_op.emitOpError(
"Not implemented: Only neutral accumulator supported");
}

if (src_layout.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
src_layout.hasNaturalTopology(ctx.target_shape)) {
auto [sublane_offset, lane_offset] = src_layout.offsets();
if (dim < 0) {
return multi_reduction_op.emitOpError(
"Negative reduction dimension unsupported");
"Not implemented: Negative reduction dimension unsupported");
}
int64_t vdim;
Direction reduce_over;
Expand Down Expand Up @@ -2252,7 +2258,9 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
multi_reduction_op->erase();
return success();
}
return multi_reduction_op->emitOpError("Unsupported layout: ") << src_layout;
return multi_reduction_op->emitOpError(
"Not implemented: Unsupported layout: ")
<< src_layout;
}

LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
Expand Down Expand Up @@ -2826,7 +2834,7 @@ FailureOr<xla::Array<Value>> disassemble(RewriteContext &ctx,
val.getLoc(), SmallVector<Type>(num_vectors, vreg_ty), val);
return XlaArrayFromShapeAndValues<Value>(layout_shape, u->getResults());
}
return op->emitOpError("unimplemented: ") << val;
return op->emitOpError("Not implemented: ") << val;
}

// Assembles a destination tile using partial data from rotated vregs using a
Expand Down Expand Up @@ -3337,7 +3345,8 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
const auto &tiling = src.tiling();
// TODO(apaszke): Changing an offset might add or remove one vreg.
if (dst_tiles_shape != src_tiles.dimensions()) {
return emitError(v.getLoc(), "Offsets changing the vreg array shape");
return emitError(
v.getLoc(), "Not implemented: Offsets changing the vreg array shape");
}
xla::Array<Value> dst_tiles = src_tiles;

Expand All @@ -3346,7 +3355,7 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
if (!src.offsets()[0].has_value()) {
row_diff = 0;
} else if (!dst.offsets()[0].has_value()) {
return emitError(v.getLoc(), "Sublane broadcast not implemented");
return emitError(v.getLoc(), "Not implemented: Sublane broadcast");
} else {
row_diff = *dst.offsets()[0] - *src.offsets()[0];
}
Expand All @@ -3356,7 +3365,8 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
const SmallVector<int64_t> implicit_shape =
src.implicitShape(vty.getShape());
if (implicit_shape[implicit_shape.size() - 2] != 1) {
return emitError(v.getLoc(), "Row shifts for multi-row values");
return emitError(v.getLoc(),
"Not implemented: Row shifts for multi-row values");
}
const int64_t src_sublane = *src.offsets()[0] / packing;
const int64_t dst_sublane = *dst.offsets()[0] / packing;
Expand Down Expand Up @@ -3454,7 +3464,8 @@ FailureOr<Value> relayout(RewriteContext &ctx, OpBuilder &builder, Value v,
return assemble(ctx, builder, vty, dst, std::move(dst_tiles)).getResult();
}
// TODO(apaszke): Implement general relayout
return emitError(v.getLoc(), "unsupported layout change for ")
return emitError(v.getLoc(),
"Not implemented: Unsupported layout change for ")
<< vty << ": " << src << " -> " << dst;
}

Expand Down Expand Up @@ -3497,7 +3508,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
auto vty = dyn_cast<VectorType>(operand.getType());
if ((vty == nullptr) == li.has_value()) {
return op.emitError(
"layout should be none iff operand is not a vector");
"Layout should be none iff operand is not a vector");
}
if (vty == nullptr) {
continue;
Expand All @@ -3508,7 +3519,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
// arguments.
auto op_result = dyn_cast<OpResult>(operand);
if (op_result == nullptr) {
return op.emitError("expected operand to be an operation result");
return op.emitError("Expected operand to be an operation result");
}
Operation *const def_op = op_result.getOwner();
CHECK(def_op);
Expand All @@ -3517,7 +3528,7 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
getOutLayout(*def_op));
const Layout lo = def_layouts[res_idx];
if (!lo.has_value()) {
return op.emitError() << "vector result should have a defined layout";
return op.emitError() << "Vector result should have a defined layout";
}
if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) {
continue;
Expand Down Expand Up @@ -3553,7 +3564,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
if (OpTrait::hasElementwiseMappableTraits(&op)) {
return elementwise_op_rule(ctx, op, layout_in, layout_out);
}
return op.emitError("Unsupported operation: ") << op.getName();
return op.emitError("Not implemented: Unsupported operation: ")
<< op.getName();
}

LogicalResult applyLayoutBlock(RewriteContext &ctx, Block &block) {
Expand Down

0 comments on commit 20e5838

Please sign in to comment.