Commit
Add a register_custom_type_id function to the GPU plugins.
This enables dynamic registration of custom FFI types on the appropriate platform via PJRT.

PiperOrigin-RevId: 712904085
dfm authored and Google-ML-Automation committed Jan 7, 2025
1 parent 853af56 commit a7f384c
Showing 8 changed files with 76 additions and 13 deletions.
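
In user code, the new entry point is consumed roughly as follows. This is a minimal sketch, not part of the commit: the my_ffi_lib module and its state_type_id_capsule export are hypothetical stand-ins for an extension that exposes a PyCapsule wrapping a pointer to its XLA_FFI_TypeId.

import jax

# Hypothetical extension module, assumed to export a PyCapsule holding a
# pointer to an XLA_FFI_TypeId describing its user-defined state type.
from my_ffi_lib import state_type_id_capsule

# Register the type ID under a process-unique name for the CUDA platform.
# This routes through the handler installed by the cuda plugin (see below).
jax.ffi.register_ffi_type_id(
    "my_ffi_lib.State", state_type_id_capsule, platform="CUDA"
)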
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* Support added for user-defined state in the FFI via the new
  {func}`jax.ffi.register_ffi_type_id` function.

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
2 changes: 1 addition & 1 deletion docs/jax.ffi.rst
@@ -10,6 +10,7 @@
ffi_lowering
pycapsule
register_ffi_target
register_ffi_type_id


``jax.extend.ffi`` module (deprecated)
@@ -28,4 +29,3 @@ the legacy import is being deprecated.
ffi_lowering
pycapsule
register_ffi_target

15 changes: 15 additions & 0 deletions jax/_src/ffi.py
@@ -72,6 +72,21 @@ def register_ffi_target(
**kwargs)


def register_ffi_type_id(
    name: str,
    obj: Any,
    platform: str = "cpu",
) -> None:
  """Registers a custom type ID for an FFI target.

  Args:
    name: the name of the type ID. This name must be unique within the process.
    obj: a ``PyCapsule`` object encapsulating a pointer to the type ID.
    platform: the target platform.
  """
  return xla_client.register_custom_type_id(name, obj, platform=platform)


def pycapsule(funcptr):
"""Wrap a ctypes function pointer in a PyCapsule.
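The obj argument is a raw PyCapsule. If an extension does not already export one, a capsule can be built with ctypes, much as pycapsule above does for function pointers. A sketch, assuming type_id_addr is the address of an XLA_FFI_TypeId owned by compiled code that outlives the registration:

import ctypes

# Standard CPython capsule constructor, exposed via ctypes.
_PyCapsule_New = ctypes.pythonapi.PyCapsule_New
_PyCapsule_New.restype = ctypes.py_object
_PyCapsule_New.argtypes = (ctypes.c_void_p, ctypes.c_char_p, ctypes.c_void_p)

def type_id_capsule(type_id_addr: int):
  # Wrap the extension-owned XLA_FFI_TypeId address in a nameless capsule
  # suitable for passing to register_ffi_type_id.
  return _PyCapsule_New(ctypes.c_void_p(type_id_addr), None, None)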
1 change: 1 addition & 0 deletions jax/ffi.py
@@ -21,4 +21,5 @@
include_dir as include_dir,
pycapsule as pycapsule,
register_ffi_target as register_ffi_target,
register_ffi_type_id as register_ffi_type_id,
)
6 changes: 6 additions & 0 deletions jax_plugins/cuda/__init__.py
Original file line number Diff line number Diff line change
@@ -93,5 +93,11 @@ def initialize():
    )
    for _name, _value in cuda_plugin_extension.registrations().items():
      xla_client.register_custom_call_target(_name, _value, platform="CUDA")
    xla_client.register_custom_type_id_handler(
        "CUDA",
        functools.partial(
            cuda_plugin_extension.register_custom_type_id, c_api
        ),
    )
  else:
    logger.warning('cuda_plugin_extension is not found.')
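
Note that the plugin does not register any type IDs eagerly; it installs a per-platform handler, with its PJRT c_api pre-bound via functools.partial, which xla_client invokes when a type ID is later registered for that platform. A toy model of that dispatch, with illustrative names only (the real table lives inside xla_client):

# Toy model of xla_client's per-platform handler table; names are
# illustrative, not the actual xla_client internals.
_custom_type_id_handlers = {}

def register_custom_type_id_handler(platform, handler):
  # Called once at plugin initialization, as above.
  _custom_type_id_handlers[platform.upper()] = handler

def register_custom_type_id(type_name, type_id, platform="cpu"):
  # Called later via jax.ffi.register_ffi_type_id; routes to the plugin.
  _custom_type_id_handlers[platform.upper()](type_name, type_id)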
6 changes: 6 additions & 0 deletions jax_plugins/rocm/__init__.py
@@ -94,5 +94,11 @@ def initialize():
    )
    for _name, _value in rocm_plugin_extension.registrations().items():
      xla_client.register_custom_call_target(_name, _value, platform="ROCM")
    xla_client.register_custom_type_id_handler(
        "ROCM",
        functools.partial(
            rocm_plugin_extension.register_custom_type_id, c_api
        ),
    )
  else:
    logger.warning('rocm_plugin_extension is not found.')
1 change: 1 addition & 0 deletions jaxlib/BUILD
@@ -238,6 +238,7 @@ cc_library(
"@xla//xla:util",
"@xla//xla/ffi/api:c_api",
"@xla//xla/pjrt:status_casters",
"@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_hdrs",
"@xla//xla/pjrt/c:pjrt_c_api_helpers",
56 changes: 44 additions & 12 deletions jaxlib/gpu_plugin_extension.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "jaxlib/kernel_nanobind_helpers.h"
#include "xla/ffi/api/c_api.h"
#include "xla/pjrt/c/pjrt_c_api.h"
#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h"
#include "xla/pjrt/c/pjrt_c_api_gpu_extension.h"
#include "xla/pjrt/c/pjrt_c_api_helpers.h"
#include "xla/pjrt/status_casters.h"
@@ -44,21 +45,14 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
size_t fn_name_size, nb::object fn,
int api_version,
XLA_FFI_Handler_Traits traits) {
-  if (c_api->extension_start == nullptr) {
-    return Unimplemented("The plugin does not have extension.");
-  }
-  const PJRT_Extension_Base* next =
-      reinterpret_cast<const PJRT_Extension_Base*>(c_api->extension_start);
-  while (next != nullptr &&
-         next->type !=
-             PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) {
-    next = next->next;
-  }
-  if (next == nullptr) {
+  const PJRT_Gpu_Custom_Call* custom_call_ext =
+      pjrt::FindExtension<PJRT_Gpu_Custom_Call>(
+          c_api, PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call);
+  if (custom_call_ext == nullptr) {
     return Unimplemented("The plugin does not have a custom call extension.");
   }
   PJRT_Gpu_Register_Custom_Call* register_custom_call =
-      reinterpret_cast<const PJRT_Gpu_Custom_Call*>(next)->custom_call;
+      custom_call_ext->custom_call;

if (traits != 0) {
return Unimplemented("The plugin does not support custom call traits.");
@@ -137,6 +131,34 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api,
#endif
}

absl::Status RegisterCustomTypeId(const PJRT_Api* c_api,
                                  const char* type_name_c_str,
                                  size_t type_name_size, nb::object type_id) {
  const PJRT_FFI_Extension* ffi_ext = pjrt::FindExtension<PJRT_FFI_Extension>(
      c_api, PJRT_Extension_Type::PJRT_Extension_Type_FFI);
  if (ffi_ext == nullptr) {
    return Unimplemented("The plugin does not have the FFI extension.");
  }

  PJRT_FFI_TypeID_Register_Args args;
  args.struct_size = PJRT_FFI_TypeID_Register_Args_STRUCT_SIZE;
  args.type_name = type_name_c_str;
  args.type_name_size = type_name_size;
  RETURN_STATUS_IF_PJRT_ERROR(ffi_ext->type_id_register(&args), c_api);

  nb::capsule capsule;
  if (!nb::try_cast<nb::capsule>(type_id, capsule)) {
    return absl::InvalidArgumentError(
        "The type_id argument to register_custom_type_id must be a "
        "PyCapsule object holding a pointer to an XLA_FFI_TypeId.");
  }
  XLA_FFI_TypeId* type_id_ptr =
      reinterpret_cast<XLA_FFI_TypeId*>(static_cast<void*>(capsule.data()));
  type_id_ptr->type_id = args.type_id;

  return absl::OkStatus();
}

nb::dict Registrations() {
nb::dict dict;
dict["xla_python_gpu_callback"] =
@@ -171,6 +193,16 @@ void BuildGpuPluginExtension(nanobind::module_& m) {
      nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
      nb::arg("xla_platform_name"), nb::arg("api_version") = 0,
      nb::arg("traits") = 0);
  m.def(
      "register_custom_type_id",
      [](nb::capsule c_api, nb::str type_name_py, nb::object type_id) {
        const char* type_name_c_str = type_name_py.c_str();
        size_t type_name_size = nb::len(type_name_py);
        xla::ThrowIfError(RegisterCustomTypeId(
            static_cast<const PJRT_Api*>(c_api.data()), type_name_c_str,
            type_name_size, std::move(type_id)));
      },
      nb::arg("c_api"), nb::arg("type_name"), nb::arg("type_id"));
  m.def("registrations", &Registrations);
}

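On the C++ side above, registration is a two-step handshake: the plugin asks PJRT to assign an ID for type_name, then writes the assigned value back through the caller's capsule into the XLA_FFI_TypeId struct. A sketch of reading that value back from Python, assuming the struct layout from XLA's FFI C API (a single int64_t field):

import ctypes

class XLA_FFI_TypeId(ctypes.Structure):
  # Assumed layout, mirroring xla/ffi/api/c_api.h.
  _fields_ = [("type_id", ctypes.c_int64)]

# Standard CPython capsule accessor, exposed via ctypes.
_PyCapsule_GetPointer = ctypes.pythonapi.PyCapsule_GetPointer
_PyCapsule_GetPointer.restype = ctypes.c_void_p
_PyCapsule_GetPointer.argtypes = (ctypes.py_object, ctypes.c_char_p)

def assigned_type_id(capsule) -> int:
  # After registration, the plugin has filled in the PJRT-assigned ID.
  addr = _PyCapsule_GetPointer(capsule, None)
  return XLA_FFI_TypeId.from_address(addr).type_id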
