From a7f384cc6ee9ebcbcfcf90b9aca405f7e91e676e Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Tue, 7 Jan 2025 07:28:21 -0800 Subject: [PATCH] Add a register_custom_type_id function to the GPU plugins. This enables dynamic registration of custom FFI types on the appropriate platform via PJRT. PiperOrigin-RevId: 712904085 --- CHANGELOG.md | 2 ++ docs/jax.ffi.rst | 2 +- jax/_src/ffi.py | 15 +++++++++ jax/ffi.py | 1 + jax_plugins/cuda/__init__.py | 6 ++++ jax_plugins/rocm/__init__.py | 6 ++++ jaxlib/BUILD | 1 + jaxlib/gpu_plugin_extension.cc | 56 ++++++++++++++++++++++++++-------- 8 files changed, 76 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d36ee67ec0dd..346c399b3332 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. {func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support transforms in more than 3 dimensions, which was previously the limit. See {jax-issue}`#25606` for more details. + * Support added for user defined state in the FFI via the new + {func}`jax.ffi.register_ffi_type_id` function. * Deprecations * From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings` diff --git a/docs/jax.ffi.rst b/docs/jax.ffi.rst index cc16b1d5b768..dc2c6f8ac873 100644 --- a/docs/jax.ffi.rst +++ b/docs/jax.ffi.rst @@ -10,6 +10,7 @@ ffi_lowering pycapsule register_ffi_target + register_ffi_type_id ``jax.extend.ffi`` module (deprecated) @@ -28,4 +29,3 @@ the legacy import is being deprecated. ffi_lowering pycapsule register_ffi_target - diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index eef8b1ca99d6..46b6543b8fff 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -72,6 +72,21 @@ def register_ffi_target( **kwargs) +def register_ffi_type_id( + name: str, + obj: Any, + platform: str = "cpu", +) -> None: + """Registers a custom type ID for a FFI target. + + Args: + name: the name of the type ID. 
This name must be unique within the process. + obj: a ``PyCapsule`` object encapsulating a pointer to the type ID. + platform: the target platform. + """ + return xla_client.register_custom_type_id(name, obj, platform=platform) + + def pycapsule(funcptr): """Wrap a ctypes function pointer in a PyCapsule. diff --git a/jax/ffi.py b/jax/ffi.py index 529818ff59da..1f9be6e8c4b3 100644 --- a/jax/ffi.py +++ b/jax/ffi.py @@ -21,4 +21,5 @@ include_dir as include_dir, pycapsule as pycapsule, register_ffi_target as register_ffi_target, + register_ffi_type_id as register_ffi_type_id, ) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index f7d4f9a191cb..a09e21c6dd77 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -93,5 +93,11 @@ def initialize(): ) for _name, _value in cuda_plugin_extension.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="CUDA") + xla_client.register_custom_type_id_handler( + "CUDA", + functools.partial( + cuda_plugin_extension.register_custom_type_id, c_api + ), + ) else: logger.warning('cuda_plugin_extension is not found.') diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index 8b176b675b88..b16806e396fc 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -94,5 +94,11 @@ def initialize(): ) for _name, _value in rocm_plugin_extension.registrations().items(): xla_client.register_custom_call_target(_name, _value, platform="ROCM") + xla_client.register_custom_type_id_handler( + "ROCM", + functools.partial( + rocm_plugin_extension.register_custom_type_id, c_api + ), + ) else: logger.warning('rocm_plugin_extension is not found.') diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 4ef60ac0fcdc..e69432e89384 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -238,6 +238,7 @@ cc_library( "@xla//xla:util", "@xla//xla/ffi/api:c_api", "@xla//xla/pjrt:status_casters", + "@xla//xla/pjrt/c:pjrt_c_api_ffi_extension_hdrs", 
"@xla//xla/pjrt/c:pjrt_c_api_gpu_extension_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_hdrs", "@xla//xla/pjrt/c:pjrt_c_api_helpers", diff --git a/jaxlib/gpu_plugin_extension.cc b/jaxlib/gpu_plugin_extension.cc index 9eecc1d1d637..46263bdcd40c 100644 --- a/jaxlib/gpu_plugin_extension.cc +++ b/jaxlib/gpu_plugin_extension.cc @@ -25,6 +25,7 @@ limitations under the License. #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/ffi/api/c_api.h" #include "xla/pjrt/c/pjrt_c_api.h" +#include "xla/pjrt/c/pjrt_c_api_ffi_extension.h" #include "xla/pjrt/c/pjrt_c_api_gpu_extension.h" #include "xla/pjrt/c/pjrt_c_api_helpers.h" #include "xla/pjrt/status_casters.h" @@ -44,21 +45,14 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, size_t fn_name_size, nb::object fn, int api_version, XLA_FFI_Handler_Traits traits) { - if (c_api->extension_start == nullptr) { - return Unimplemented("The plugin does not have extension."); - } - const PJRT_Extension_Base* next = - reinterpret_cast(c_api->extension_start); - while (next != nullptr && - next->type != - PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call) { - next = next->next; - } - if (next == nullptr) { + const PJRT_Gpu_Custom_Call* custom_call_ext = + pjrt::FindExtension( + c_api, PJRT_Extension_Type::PJRT_Extension_Type_Gpu_Custom_Call); + if (custom_call_ext == nullptr) { return Unimplemented("The plugin does not have a custom call extension."); } PJRT_Gpu_Register_Custom_Call* register_custom_call = - reinterpret_cast(next)->custom_call; + custom_call_ext->custom_call; if (traits != 0) { return Unimplemented("The plugin does not support custom call traits."); @@ -137,6 +131,34 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, #endif } +absl::Status RegisterCustomTypeId(const PJRT_Api* c_api, + const char* type_name_c_str, + size_t type_name_size, nb::object type_id) { + const PJRT_FFI_Extension* ffi_ext = pjrt::FindExtension( + c_api, PJRT_Extension_Type::PJRT_Extension_Type_FFI); + if (ffi_ext == 
nullptr) {
    return Unimplemented("The plugin does not have the FFI extension.");
  }

  PJRT_FFI_TypeID_Register_Args args;
  args.struct_size = PJRT_FFI_TypeID_Register_Args_STRUCT_SIZE;
  args.type_name = type_name_c_str;
  args.type_name_size = type_name_size;
  RETURN_STATUS_IF_PJRT_ERROR(ffi_ext->type_id_register(&args), c_api);

  nb::capsule capsule;
  if (!nb::try_cast(type_id, capsule)) {
    return absl::InvalidArgumentError(
        "The type_id argument to register_custom_type_id must be a "
        "PyCapsule object holding a pointer to a XLA_FFI_TypeId.");
  }
  XLA_FFI_TypeId* type_id_ptr =
      reinterpret_cast(static_cast(capsule.data()));
  type_id_ptr->type_id = args.type_id;

  return absl::OkStatus();
}

nb::dict Registrations() {
  nb::dict dict;
  dict["xla_python_gpu_callback"] =
@@ -171,6 +193,16 @@ void BuildGpuPluginExtension(nanobind::module_& m) {
      nb::arg("c_api"), nb::arg("fn_name"), nb::arg("fn"),
      nb::arg("xla_platform_name"), nb::arg("api_version") = 0,
      nb::arg("traits") = 0);
  m.def(
      "register_custom_type_id",
      [](nb::capsule c_api, nb::str type_name_py, nb::object type_id) {
        const char* type_name_c_str = type_name_py.c_str();
        size_t type_name_size = nb::len(type_name_py);
        xla::ThrowIfError(RegisterCustomTypeId(
            static_cast(c_api.data()), type_name_c_str,
            type_name_size, std::move(type_id)));
      },
      nb::arg("c_api"), nb::arg("type_name"), nb::arg("type_id"));
  m.def("registrations", &Registrations);
}