Skip to content

Commit

Permalink
[Mosaic] Add the core type enum
Browse files Browse the repository at this point in the history
The new attribute allows differentiating compilation by target core.

PiperOrigin-RevId: 691531726
  • Loading branch information
naummo authored and Google-ML-Automation committed Oct 30, 2024
1 parent af14c43 commit 242e663
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
18 changes: 18 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def TPU_Dialect : Dialect {
let cppNamespace = "::mlir::tpu";
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;
let extraClassDeclaration = [{
static StringRef GetCoreTypeKey() { return "tpu.core_type"; }

static std::optional<CoreType> GetCoreTypeAttr(Operation *op);
}];
}

class TPU_Attr<string name, string mnemonic_, list<Trait> traits = []>
Expand All @@ -46,6 +51,19 @@ class TPU_Type<string name, string mnemonic_, list<Trait> traits = []>
let mnemonic = mnemonic_;
}

def TPU_CoreType : I32EnumAttr<"CoreType", "Core type", [
I32EnumAttrCase<"kTc", 0, "tc">,
I32EnumAttrCase<"kScScalarSubcore", 1, "sc_scalar_subcore">,
I32EnumAttrCase<"kScVectorSubcore", 2, "sc_vector_subcore">
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tpu";
}

def TPU_CoreTypeEnum : EnumAttr<TPU_Dialect, TPU_CoreType, "core_type"> {
let assemblyFormat = "`<` $value `>`";
}

def TPU_SemaphoreType : TPU_Type<"Semaphore", "semaphore", [MemRefElementTypeInterface]>;
def TPU_DMASemaphoreType : TPU_Type<"DMASemaphore", "dma_semaphore", [MemRefElementTypeInterface]>;
def TPU_SomeSemaphoreType : AnyTypeOf<[TPU_SemaphoreType, TPU_DMASemaphoreType]>;
Expand Down
13 changes: 13 additions & 0 deletions jaxlib/mosaic/dialect/tpu/tpu_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <array>
#include <cstdint>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -68,6 +69,18 @@ void TPUDialect::initialize() {
>();
}

/* static */ std::optional<CoreType> TPUDialect::GetCoreTypeAttr(
Operation *op) {
Attribute attr = op->getAttr(GetCoreTypeKey());
if (attr == nullptr) {
return std::nullopt;
}
if (!mlir::isa<CoreTypeAttr>(attr)) {
return std::nullopt;
}
return mlir::cast<CoreTypeAttr>(attr).getValue();
}

void VectorLayoutAttr::print(AsmPrinter &printer) const {
printer << '<';
printer << getLayout();
Expand Down

0 comments on commit 242e663

Please sign in to comment.