From f96339be1ec333636a8365acf5a28d445dfb8251 Mon Sep 17 00:00:00 2001
From: Adam Paszke <apaszke@google.com>
Date: Wed, 8 Jan 2025 06:29:46 -0800
Subject: [PATCH] [Mosaic TPU] Be much more aggressive in inferring large 2nd
 minor layouts for 16-bit types on v6

This often lets us avoid ambiguities between selecting the (8, 128) and
(16, 128) tiling, by biasing the layout inference to prefer the latter.

PiperOrigin-RevId: 713270421
---
 jax/_src/tpu_custom_call.py                   |  1 +
 jaxlib/mosaic/dialect/tpu/tpu.td              |  1 +
 jaxlib/mosaic/dialect/tpu/tpu_dialect.h       |  4 +-
 .../tpu/transforms/apply_vector_layout.cc     |  3 +-
 .../tpu/transforms/infer_memref_layout.cc     | 32 ++++++++----
 .../tpu/transforms/infer_memref_layout.h      |  1 +
 .../tpu/transforms/infer_vector_layout.cc     | 52 +++++++++++++++----
 7 files changed, 73 insertions(+), 21 deletions(-)

diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py
index bb92afebe8e9..b9645cbefb5e 100644
--- a/jax/_src/tpu_custom_call.py
+++ b/jax/_src/tpu_custom_call.py
@@ -377,6 +377,7 @@ def _lower_tpu_kernel(
   pipeline = [
       (
           "func.func(tpu-infer-vector-layout{"
+          f" hardware-generation={hardware_generation}"
           f" sublane-count={sl_cnt} lane-count={l_cnt}"
           "})"
       ),
diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td
index 6ef809c4cb6a..a486d8fef84d 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu.td
+++ b/jaxlib/mosaic/dialect/tpu/tpu.td
@@ -847,6 +847,7 @@ def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncOp"> {
   ];
   let constructor = "::mlir::tpu::createInferVectorLayoutPass()";
   let options = [
+    Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
    Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
    Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
  ];
diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
index 307f3582f007..0156798ca88d 100644
--- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
+++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -79,7 +79,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
     int hardware_generation = -1);
 
 std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
-    std::array<int64_t, 2> target_shape = {8, 128});
+    int hardware_generation = -1,
+    std::array<int64_t, 2> target_shape = {8, 128},
+    const TpuTilingFlags &tpu_tiling_flags = {});
 
 std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
     std::array<int64_t, 2> target_shape = {8, 128});
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
index 84156c22611a..4c6353f8c504 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -429,7 +429,8 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
       MemRefType arg_type,
       inferMemref(
           MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
-          ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{}));
+          ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{},
+          /*is_kernel_argument=*/true));
   const BlockArgument argument = entry_block.insertArgument(
       entry_block.getNumArguments() - 1, arg_type, UnknownLoc::get(mlir_ctx));
   const FunctionType func_ty = func.getFunctionType();
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
index 046b642f98a3..05667a847691 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
@@ -43,10 +43,12 @@ namespace mlir::tpu {
 // tpu_tiling_flags: A struct of flags indicating which large tiling modes are
 //   enabled by XLA for memrefs.
 // bitwidth: The bitwidth of the element type of the operand.
+// is_kernel_argument: Whether the operand is a kernel argument.
 int getTilingFactor(const int num_lanes, const int hardware_generation,
                     const int64_t sublane_count,
                     const TpuTilingFlags &tpu_tiling_flags,
-                    const int8_t bitwidth) {
+                    const int8_t bitwidth,
+                    const bool is_kernel_argument) {
   CHECK(llvm::isPowerOf2_32(bitwidth));
   CHECK_LE(4, bitwidth);
   CHECK_LE(bitwidth, 32);
@@ -61,7 +63,11 @@ int getTilingFactor(const int num_lanes, const int hardware_generation,
   if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) {
     return sublane_count * 4;
   }
-  if (bitwidth == 16 && tpu_tiling_flags.use_x16_large_second_minor) {
+  // 16-bit values are generally always possible to relayout on the fly in v6,
+  // so we allow large 2nd minor tiling whenever possible. We can't do this
+  // for kernel arguments, because the layout of those is controlled by XLA.
+  if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor ||
+                         (!is_kernel_argument && hardware_generation >= 6))) {
     return sublane_count * 2;
   }
   return sublane_count;
@@ -84,6 +90,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
                                        const int hardware_generation,
                                        std::array<int64_t, 2> target_shape,
                                        const TpuTilingFlags &tpu_tiling_flags,
+                                       bool is_kernel_argument,
                                        int64_t leading_tile_rows = 0) {
   if (auto tiled_layout_attr =
           dyn_cast<tpu::TiledLayoutAttr>(memref_ty.getLayout())) {
@@ -119,7 +126,8 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
     const int64_t leading_tile =
         getTilingFactor(
            llvm::divideCeil(memref_ty.getShape().back(), lane_count),
-            hardware_generation, sublane_count, tpu_tiling_flags, bitwidth) *
+            hardware_generation, sublane_count, tpu_tiling_flags, bitwidth,
+            is_kernel_argument) *
        lane_count;
     SmallVector<xla::Tile> tiles{xla::Tile({leading_tile})};
     if (bitwidth != 32) {
@@ -139,7 +147,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
    if (leading_tile_rows == 0) {
      leading_tile_rows =
          getTilingFactor(second_minor, hardware_generation, sublane_count,
-                         tpu_tiling_flags, bitwidth);
+                         tpu_tiling_flags, bitwidth, is_kernel_argument);
    }
    SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, lane_count})};
    if (bitwidth != 32) {
@@ -186,6 +194,7 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
                                  const int hardware_generation,
                                  std::array<int64_t, 2> target_shape,
                                  const TpuTilingFlags &tpu_tiling_flags,
+                                 bool is_kernel_argument,
                                  int64_t leading_tile_rows) {
   if (isa<SemaphoreType>(memref.getElementType())) {
     const Attribute semaphore_mem = tpu::MemorySpaceAttr::get(
@@ -209,7 +218,7 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
   FAILUREOR_ASSIGN_OR_RETURN(
       const TiledLayoutAttr layout,
       inferLayout(memref, hardware_generation, target_shape, tpu_tiling_flags,
-                  leading_tile_rows));
+                  is_kernel_argument, leading_tile_rows));
 
   const ArrayRef<xla::Tile> tiles = layout.getTiles();
   if (failed(checkTiles(memref.getContext(), tiles))) {
@@ -248,7 +257,8 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
     FAILUREOR_ASSIGN_OR_RETURN(
         const MemRefType new_memref_ty,
         inferMemref(memref_ty, hardware_generation, target_shape,
-                    tpu_tiling_flags, leading_tile_rows));
+                    tpu_tiling_flags, /*is_kernel_argument=*/false,
+                    leading_tile_rows));
     alloca_op.getResult().setType(new_memref_ty);
     if (memref_ty != new_memref_ty) {
       OpBuilder builder(alloca_op->getContext());
@@ -265,9 +275,10 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
   } else if (auto alloca_op = dyn_cast<tpu::AllocaSemaphoreOp>(op)) {
     TypedValue<MemRefType> arg = alloca_op.getResult();
     const MemRefType memref_ty = alloca_op.getResult().getType();
-    FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
-                               inferMemref(memref_ty, hardware_generation,
-                                           target_shape, tpu_tiling_flags));
+    FAILUREOR_ASSIGN_OR_RETURN(
+        const MemRefType new_memref_ty,
+        inferMemref(memref_ty, hardware_generation, target_shape,
+                    tpu_tiling_flags, /*is_kernel_argument=*/false));
     alloca_op.getResult().setType(new_memref_ty);
     if (memref_ty != new_memref_ty) {
       OpBuilder builder(alloca_op->getContext());
@@ -320,7 +331,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
     FAILUREOR_ASSIGN_OR_RETURN(
         MemRefType new_memref_ty,
         inferMemref(memref_ty, hardware_generation, target_shape,
-                    tpu_tiling_flags, leading_tile_rows));
+                    tpu_tiling_flags, /*is_kernel_argument=*/true,
+                    leading_tile_rows));
     arg.setType(new_memref_ty);
     new_arg_types.push_back(arg.getType());
     if (memref_ty != new_memref_ty) {
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h
index ed2a34793536..f2ab7c624eb1 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h
@@ -14,6 +14,7 @@ namespace mlir::tpu {
 FailureOr<MemRefType> inferMemref(MemRefType memref, int hardware_generation,
                                   std::array<int64_t, 2> target_shape,
                                   const TpuTilingFlags& tpu_tiling_flags,
+                                  bool is_kernel_argument,
                                   int64_t leading_tile_rows = 0);
 
 const std::string_view kLeadingTileRows = "leading_tile_rows";
diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
index 4e68f5558325..d189994d9564 100644
--- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
+++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -92,9 +92,13 @@ LogicalResult verifyDivisibleIndex(Value tiled_index, int64_t tiling, int dim,
 // have corresponding native instructions.
 class VectorLayoutInferer {
  public:
-  explicit VectorLayoutInferer(std::array<int64_t, 2> target_shape)
-      : target_shape_({target_shape[0], target_shape[1]}),
-        default_tiling_(target_shape) {}
+  explicit VectorLayoutInferer(int hardware_generation,
+                               std::array<int64_t, 2> target_shape,
+                               const TpuTilingFlags &tpu_tiling_flags)
+      : hardware_generation_(hardware_generation),
+        target_shape_({target_shape[0], target_shape[1]}),
+        default_tiling_(target_shape),
+        tpu_tiling_flags_(tpu_tiling_flags) {}
 
 #define TPU_CHECK_OP(cond, msg) \
   if (!(cond)) {                \
@@ -1703,6 +1707,21 @@ class VectorLayoutInferer {
     }
     auto &layout = *some_layout;
     bool select_native = allUsersRequireNativeTiling(op->getResult(0));
+    // We might want to reconsider enabling native this aggressively in cases
+    // when it would introduce a lot of padding (e.g. when the value only has
+    // a small second minor size, but large minor size).
+    if (dst_ty.getElementTypeBitWidth() == 16) {
+      // TPUv6 has good support for compute in 16-bit and cheap retiling between
+      // large 2nd minor and the default tiling, so we bias towards large tiles.
+      select_native |= hardware_generation_ >= 6 ||
+                       tpu_tiling_flags_.use_x16_large_second_minor;
+    } else if (dst_ty.getElementTypeBitWidth() == 8) {
+      select_native |= tpu_tiling_flags_.use_x8_large_second_minor;
+    } else if (dst_ty.getElementTypeBitWidth() == 4) {
+      select_native |= tpu_tiling_flags_.use_x4_large_second_minor;
+    } else {
+      return op->emitOpError("Unsupported target bitwidth for truncation");
+    }
     auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
                                    layout.implicit_dim());
     auto dst_layout = VectorLayout(
@@ -2017,15 +2036,15 @@ class VectorLayoutInferer {
                            default_tiling_[1]};
   }
 
+  int hardware_generation_;
   std::array<int64_t, 2> target_shape_;
   std::array<int64_t, 2> default_tiling_;
+  TpuTilingFlags tpu_tiling_flags_;
 
   // TODO(b/342235360): Deprecate force_first_tile_offsets_ once we fully
   // remove the restriction that offsets must fall within the first tile.
   bool force_first_tile_offsets_ = false;
 
-  // Address alignment requirement, counted in 32-bit increments.
-  static constexpr int64_t kVmemAlignment32 = 128;
   // TODO(apaszke): This is not really native on newer generations of TPUs.
   // Get rid of this temporary stopgap.
   static constexpr int8_t kNativeBitwidth = 32;
@@ -2033,24 +2052,39 @@
 
 struct InferVectorLayoutPass
     : public impl::InferVectorLayoutPassBase<InferVectorLayoutPass> {
-  InferVectorLayoutPass(std::array<int64_t, 2> target_shape) {
+  InferVectorLayoutPass(int hardware_generation,
+                        std::array<int64_t, 2> target_shape,
+                        TpuTilingFlags tpu_tiling_flags) {
+    this->hardware_generation = hardware_generation;
     this->sublane_count = target_shape[0];
     this->lane_count = target_shape[1];
+    this->tpu_tiling_flags = tpu_tiling_flags;
   }
   void runOnOperation() override {
+    // Fail if hardware_generation has not been set from the default value.
+    if (hardware_generation < 0) {
+      getOperation().emitError("hardware_generation must be set") << hardware_generation;
+      signalPassFailure();
+      return;
+    }
     func::FuncOp func = getOperation();
-    VectorLayoutInferer run({sublane_count, lane_count});
+    VectorLayoutInferer run(hardware_generation, {sublane_count, lane_count},
+                            tpu_tiling_flags);
     if (run.infer(func).failed()) {
       signalPassFailure();
    }
  }
+
+  TpuTilingFlags tpu_tiling_flags;
 };
 
 }  // namespace
 
 std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
-    std::array<int64_t, 2> target_shape) {
-  return std::make_unique<InferVectorLayoutPass>(target_shape);
+    int hardware_generation, std::array<int64_t, 2> target_shape,
+    const TpuTilingFlags &tpu_tiling_flags) {
+  return std::make_unique<InferVectorLayoutPass>(
+      hardware_generation, target_shape, tpu_tiling_flags);
 }
 
 }  // namespace mlir::tpu