[Mosaic TPU] Be much more aggressive in inferring large 2nd minor layouts for 16-bit types on v6

This often lets us avoid ambiguities between the (8, 128) and (16, 128) tilings by biasing layout inference to prefer the latter.

PiperOrigin-RevId: 713270421
apaszke authored and Google-ML-Automation committed Jan 8, 2025
1 parent 5fd1b2f commit f96339b
Showing 7 changed files with 73 additions and 21 deletions.
1 change: 1 addition & 0 deletions jax/_src/tpu_custom_call.py
@@ -377,6 +377,7 @@ def _lower_tpu_kernel(
pipeline = [
(
"func.func(tpu-infer-vector-layout{"
f" hardware-generation={hardware_generation}"
f" sublane-count={sl_cnt} lane-count={l_cnt}"
"})"
),
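
As an illustration (not part of this commit), the textual pipeline entry assembled above can also be parsed from C++; the option values below are placeholders, and the sketch assumes the Mosaic TPU passes have already been registered with MLIR's pass registry.

#include "mlir/Pass/PassManager.h"
#include "mlir/Pass/PassRegistry.h"

// Sketch: parse the same textual entry that _lower_tpu_kernel builds above.
// The option values (hardware-generation=6, 8x128 target shape) are
// illustrative, not values taken from the commit.
mlir::LogicalResult addInferVectorLayout(mlir::PassManager &pm) {
  return mlir::parsePassPipeline(
      "func.func(tpu-infer-vector-layout{"
      " hardware-generation=6 sublane-count=8 lane-count=128})",
      pm);
}
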
1 change: 1 addition & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
@@ -847,6 +847,7 @@ def InferVectorLayoutPass : Pass<"tpu-infer-vector-layout", "::mlir::func::FuncO
];
let constructor = "::mlir::tpu::createInferVectorLayoutPass()";
let options = [
Option<"hardware_generation", "hardware-generation", "int", /*default=*/"-1", "">,
Option<"lane_count", "lane-count", "int", /*default=*/"128", "">,
Option<"sublane_count", "sublane-count", "int", /*default=*/"8", "">,
];
4 changes: 3 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu_dialect.h
@@ -79,7 +79,9 @@ std::unique_ptr<OperationPass<func::FuncOp>> createCanonicalizeMosaicPass(
int hardware_generation = -1);

std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
std::array<int64_t, 2> target_shape = {8, 128});
int hardware_generation = -1,
std::array<int64_t, 2> target_shape = {8, 128},
const TpuTilingFlags &tpu_tiling_flags = {});

std::unique_ptr<OperationPass<func::FuncOp>> createRelayoutInsertionPass(
std::array<int64_t, 2> target_shape = {8, 128});
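
As a usage sketch of the updated declaration above (assumptions: TPUv6, default target shape and tiling flags; the wrapper function is hypothetical), a caller would now pass the hardware generation explicitly:

#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"

// Sketch: construct the layout-inference pass with the new
// hardware_generation parameter; 6 stands for TPUv6 here, and the remaining
// arguments keep the defaults from the declaration above.
void addInferVectorLayoutPass(mlir::OpPassManager &pm) {
  pm.addNestedPass<mlir::func::FuncOp>(mlir::tpu::createInferVectorLayoutPass(
      /*hardware_generation=*/6, /*target_shape=*/{8, 128},
      /*tpu_tiling_flags=*/{}));
}
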
3 changes: 2 additions & 1 deletion jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
@@ -429,7 +429,8 @@ FailureOr<BlockArgument> appendConstant(RewriteContext &ctx, func::FuncOp func,
MemRefType arg_type,
inferMemref(
MemRefType::get(value_ty.getShape(), value_ty.getElementType()),
ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{}));
ctx.hardware_generation, ctx.target_shape, /*tpu_tiling_flags=*/{},
/*is_kernel_argument=*/true));
const BlockArgument argument = entry_block.insertArgument(
entry_block.getNumArguments() - 1, arg_type, UnknownLoc::get(mlir_ctx));
const FunctionType func_ty = func.getFunctionType();
32 changes: 22 additions & 10 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc
@@ -43,10 +43,12 @@ namespace mlir::tpu {
// tpu_tiling_flags: A struct of flags indicating which large tiling modes are
// enabled by XLA for memrefs.
// bitwidth: The bitwidth of the element type of the operand.
// is_kernel_argument: Whether the operand is a kernel argument.
int getTilingFactor(const int num_lanes, const int hardware_generation,
const int64_t sublane_count,
const TpuTilingFlags &tpu_tiling_flags,
const int8_t bitwidth) {
const int8_t bitwidth,
const bool is_kernel_argument) {
CHECK(llvm::isPowerOf2_32(bitwidth));
CHECK_LE(4, bitwidth);
CHECK_LE(bitwidth, 32);
@@ -61,7 +63,11 @@ int getTilingFactor(const int num_lanes, const int hardware_generation,
if (bitwidth == 8 && tpu_tiling_flags.use_x8_large_second_minor) {
return sublane_count * 4;
}
if (bitwidth == 16 && tpu_tiling_flags.use_x16_large_second_minor) {
// 16-bit values can generally be relayouted on the fly on v6, so we allow
// large 2nd minor tiling whenever possible. We can't do this for kernel
// arguments, because their layout is controlled by XLA.
if (bitwidth == 16 && (tpu_tiling_flags.use_x16_large_second_minor ||
(!is_kernel_argument && hardware_generation >= 6))) {
return sublane_count * 2;
}
return sublane_count;
@@ -84,6 +90,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
const int hardware_generation,
std::array<int64_t, 2> target_shape,
const TpuTilingFlags &tpu_tiling_flags,
bool is_kernel_argument,
int64_t leading_tile_rows = 0) {
if (auto tiled_layout_attr =
dyn_cast<TiledLayoutAttr>(memref_ty.getLayout())) {
@@ -119,7 +126,8 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
const int64_t leading_tile =
getTilingFactor(
llvm::divideCeil(memref_ty.getShape().back(), lane_count),
hardware_generation, sublane_count, tpu_tiling_flags, bitwidth) *
hardware_generation, sublane_count, tpu_tiling_flags, bitwidth,
is_kernel_argument) *
lane_count;
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile})};
if (bitwidth != 32) {
@@ -139,7 +147,7 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
if (leading_tile_rows == 0) {
leading_tile_rows =
getTilingFactor(second_minor, hardware_generation, sublane_count,
tpu_tiling_flags, bitwidth);
tpu_tiling_flags, bitwidth, is_kernel_argument);
}
SmallVector<xla::Tile> tiles{xla::Tile({leading_tile_rows, lane_count})};
if (bitwidth != 32) {
@@ -186,6 +194,7 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
const int hardware_generation,
std::array<int64_t, 2> target_shape,
const TpuTilingFlags &tpu_tiling_flags,
bool is_kernel_argument,
int64_t leading_tile_rows) {
if (isa<SemaphoreType, DMASemaphoreType>(memref.getElementType())) {
const Attribute semaphore_mem = tpu::MemorySpaceAttr::get(
@@ -209,7 +218,7 @@ FailureOr<MemRefType> inferMemref(MemRefType memref,
FAILUREOR_ASSIGN_OR_RETURN(
const TiledLayoutAttr layout,
inferLayout(memref, hardware_generation, target_shape, tpu_tiling_flags,
leading_tile_rows));
is_kernel_argument, leading_tile_rows));

const ArrayRef<xla::Tile> tiles = layout.getTiles();
if (failed(checkTiles(memref.getContext(), tiles))) {
@@ -248,7 +257,8 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
FAILUREOR_ASSIGN_OR_RETURN(
const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, target_shape,
tpu_tiling_flags, leading_tile_rows));
tpu_tiling_flags, /*is_kernel_argument=*/false,
leading_tile_rows));
alloca_op.getResult().setType(new_memref_ty);
if (memref_ty != new_memref_ty) {
OpBuilder builder(alloca_op->getContext());
@@ -265,9 +275,10 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
} else if (auto alloca_op = dyn_cast<tpu::AllocaSemaphoreOp>(op)) {
TypedValue<MemRefType> arg = alloca_op.getResult();
const MemRefType memref_ty = alloca_op.getResult().getType();
FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation,
target_shape, tpu_tiling_flags));
FAILUREOR_ASSIGN_OR_RETURN(
const MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, target_shape,
tpu_tiling_flags, /*is_kernel_argument=*/false));
alloca_op.getResult().setType(new_memref_ty);
if (memref_ty != new_memref_ty) {
OpBuilder builder(alloca_op->getContext());
@@ -320,7 +331,8 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
FAILUREOR_ASSIGN_OR_RETURN(
MemRefType new_memref_ty,
inferMemref(memref_ty, hardware_generation, target_shape,
tpu_tiling_flags, leading_tile_rows));
tpu_tiling_flags, /*is_kernel_argument=*/true,
leading_tile_rows));
arg.setType(new_memref_ty);
new_arg_types.push_back(arg.getType());
if (memref_ty != new_memref_ty) {
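
To make the new memref-tiling rule concrete, the following standalone sketch (not part of the commit; the function and parameter names are hypothetical) restates the 16-bit branch of getTilingFactor after this change, with the checks for other bitwidths omitted:

#include <cstdint>

// Hypothetical simplification of the 16-bit case above. Large 2nd-minor
// tiling (sublane_count * 2 rows, e.g. a (16, 128) leading tile for an 8x128
// target) is chosen when XLA enables it via the tiling flags, or, as of this
// change, whenever we target TPUv6+ and the memref is not a kernel argument,
// since kernel-argument layouts are controlled by XLA.
int64_t tilingFactor16Bit(int hardware_generation, int64_t sublane_count,
                          bool use_x16_large_second_minor,
                          bool is_kernel_argument) {
  if (use_x16_large_second_minor ||
      (!is_kernel_argument && hardware_generation >= 6)) {
    return sublane_count * 2;  // e.g. 16 rows -> (16, 128) tiles
  }
  return sublane_count;        // e.g. 8 rows -> (8, 128) tiles
}
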
1 change: 1 addition & 0 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.h
@@ -14,6 +14,7 @@ namespace mlir::tpu {
FailureOr<MemRefType> inferMemref(MemRefType memref, int hardware_generation,
std::array<int64_t, 2> target_shape,
const TpuTilingFlags& tpu_tiling_flags,
bool is_kernel_argument,
int64_t leading_tile_rows = 0);

const std::string_view kLeadingTileRows = "leading_tile_rows";
52 changes: 43 additions & 9 deletions jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc
@@ -92,9 +92,13 @@ LogicalResult verifyDivisibleIndex(Value tiled_index, int64_t tiling, int dim,
// have corresponding native instructions.
class VectorLayoutInferer {
public:
explicit VectorLayoutInferer(std::array<int64_t, 2> target_shape)
: target_shape_({target_shape[0], target_shape[1]}),
default_tiling_(target_shape) {}
explicit VectorLayoutInferer(int hardware_generation,
std::array<int64_t, 2> target_shape,
const TpuTilingFlags &tpu_tiling_flags)
: hardware_generation_(hardware_generation),
target_shape_({target_shape[0], target_shape[1]}),
default_tiling_(target_shape),
tpu_tiling_flags_(tpu_tiling_flags) {}

#define TPU_CHECK_OP(cond, msg) \
if (!(cond)) { \
@@ -1703,6 +1707,21 @@
}
auto &layout = *some_layout;
bool select_native = allUsersRequireNativeTiling(op->getResult(0));
// We might want to reconsider enabling native tiling this aggressively in
// cases where it would introduce a lot of padding (e.g. when the value only
// has a small second minor size but a large minor size).
if (dst_ty.getElementTypeBitWidth() == 16) {
// TPUv6 has good support for compute in 16-bit and cheap retiling between
// large 2nd minor and the default tiling, so we bias towards large tiles.
select_native |= hardware_generation_ >= 6 ||
tpu_tiling_flags_.use_x16_large_second_minor;
} else if (dst_ty.getElementTypeBitWidth() == 8) {
select_native |= tpu_tiling_flags_.use_x8_large_second_minor;
} else if (dst_ty.getElementTypeBitWidth() == 4) {
select_native |= tpu_tiling_flags_.use_x4_large_second_minor;
} else {
return op->emitOpError("Unsupported target bitwidth for truncation");
}
auto src_layout = VectorLayout(32, layout.offsets(), default_tiling_,
layout.implicit_dim());
auto dst_layout = VectorLayout(
@@ -2017,40 +2036,55 @@
default_tiling_[1]};
}

int hardware_generation_;
std::array<int64_t, 2> target_shape_;
std::array<int64_t, 2> default_tiling_;
TpuTilingFlags tpu_tiling_flags_;

// TODO(b/342235360): Deprecate force_first_tile_offsets_ once we fully
// remove the restriction that offsets must fall within the first tile.
bool force_first_tile_offsets_ = false;

// Address alignment requirement, counted in 32-bit increments.
static constexpr int64_t kVmemAlignment32 = 128;
// TODO(apaszke): This is not really native on newer generations of TPUs.
// Get rid of this temporary stopgap.
static constexpr int8_t kNativeBitwidth = 32;
};

struct InferVectorLayoutPass
: public impl::InferVectorLayoutPassBase<InferVectorLayoutPass> {
InferVectorLayoutPass(std::array<int64_t, 2> target_shape) {
InferVectorLayoutPass(int hardware_generation,
std::array<int64_t, 2> target_shape,
TpuTilingFlags tpu_tiling_flags) {
this->hardware_generation = hardware_generation;
this->sublane_count = target_shape[0];
this->lane_count = target_shape[1];
this->tpu_tiling_flags = tpu_tiling_flags;
}
void runOnOperation() override {
// Fail if hardware_generation was not set (it still has its default value).
if (hardware_generation < 0) {
getOperation().emitError("hardware_generation must be set") << hardware_generation;
signalPassFailure();
return;
}
func::FuncOp func = getOperation();
VectorLayoutInferer run({sublane_count, lane_count});
VectorLayoutInferer run(hardware_generation, {sublane_count, lane_count},
tpu_tiling_flags);
if (run.infer(func).failed()) {
signalPassFailure();
}
}

TpuTilingFlags tpu_tiling_flags;
};

} // namespace

std::unique_ptr<OperationPass<func::FuncOp>> createInferVectorLayoutPass(
std::array<int64_t, 2> target_shape) {
return std::make_unique<InferVectorLayoutPass>(target_shape);
int hardware_generation, std::array<int64_t, 2> target_shape,
const TpuTilingFlags &tpu_tiling_flags) {
return std::make_unique<InferVectorLayoutPass>(
hardware_generation, target_shape, tpu_tiling_flags);
}

} // namespace mlir::tpu
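
The vector-layout side of the change (the truncation rule in infer_vector_layout.cc above) can be summarized by this standalone sketch. It is a simplified restatement, assuming TpuTilingFlags exposes the three use_*_large_second_minor flags shown in the diff; selectNativeTiling is a hypothetical helper, not code from the commit.

#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"  // TpuTilingFlags

// Hypothetical summary of how the result layout of a truncation is biased
// towards native (large 2nd-minor) tiling per element bitwidth. On TPUv6+,
// 16-bit results prefer native tiling even without the XLA flag, because
// retiling 16-bit data between (8, 128) and (16, 128) is cheap there.
bool selectNativeTiling(int bitwidth, bool all_users_require_native,
                        int hardware_generation,
                        const mlir::tpu::TpuTilingFlags &flags) {
  bool select_native = all_users_require_native;
  if (bitwidth == 16) {
    select_native |=
        hardware_generation >= 6 || flags.use_x16_large_second_minor;
  } else if (bitwidth == 8) {
    select_native |= flags.use_x8_large_second_minor;
  } else if (bitwidth == 4) {
    select_native |= flags.use_x4_large_second_minor;
  }
  return select_native;
}
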
