Add basic OpenReg module scaffolding with autograd (pytorch#131708)
Pull Request resolved: pytorch#131708
Approved by: https://github.com/ezyang
albanD authored and pytorchmergebot committed Aug 5, 2024
1 parent df59084 commit 3d87dfc
Showing 8 changed files with 435 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .flake8
Expand Up @@ -55,6 +55,9 @@ per-file-ignores =
torch/distributed/_functional_collectives.py: TOR901
torch/distributed/_spmd/data_parallel.py: TOR901
torch/distributed/_tensor/_collective_utils.py: TOR901
# This is a full package that happens to live within the test
# folder, so it is ok to skip
test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py: TOR901
optional-ascii-coding = True
exclude =
./.git,
13 changes: 13 additions & 0 deletions test/cpp_extensions/open_registration_extension/README.md
@@ -0,0 +1,13 @@
This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend in core.

## How to use
Install it as a standalone package with `python setup.py develop` (or `install`) from this folder.
You can run the tests via `python test/test_openreg.py`.

## Design principles
For simplicity, anything that can be implemented from Python is implemented there.
A real implementation will most likely want to call these APIs from C++ directly.

The current version sends everything back to Python but is missing most of the Python implementations. The only one available is the one used by the autograd engine to check how many workers to spawn.

The next step is to create the device daemon so we can actually provide an allocator and create memory, then start using more features and re-route all missing methods to the daemon as appropriate.
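
Once installed, a quick sanity check is to construct a device with the renamed backend (a hedged sketch; it only exercises the backend registration done in `pytorch_openreg/__init__.py`, not real tensor allocation, since most implementations are still missing):

```python
import torch
import pytorch_openreg  # noqa: F401  (importing renames PrivateUse1 to "openreg")

# Constructing a device with the renamed backend should work even before
# any real kernels exist.
d = torch.device("openreg", 0)
print(d)
```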
45 changes: 45 additions & 0 deletions test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py
@@ -0,0 +1,45 @@
import torch


# Global properties of our device
NUM_DEVICES = 7

# Create our python implementation dict so that the C++ module
# can access it during its initialization
_IMPL_REGISTRY = {}

# Load the C++ Module
import pytorch_openreg._C  # noqa: F401


# Define all the implementations in the registry
def register(fn):
    _IMPL_REGISTRY[fn.__name__[1:]] = fn
    return fn


@register
def _deviceCount():
    return NUM_DEVICES


# Module used for our backend
class _OpenRegMod:
    pass


# Set all the appropriate state on PyTorch
torch.utils.rename_privateuse1_backend("openreg")
torch._register_device_module("openreg", _OpenRegMod())

_openreg_lib = torch.library.Library("_", "IMPL")  # ignore TOR901


def _openreg_kernel_fallback(op, *args, **kwargs):
    print("Calling ", op)
    assert op is torch.ops.aten.empty.memory_format
    # FIXME: this returns a cpu Tensor which is NOT ok.
    return torch.empty(args[0])


_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")
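
The C++ device guard below looks up names such as "getDevice" and "exchangeDevice" in this same `_IMPL_REGISTRY`. A hypothetical sketch (not part of this commit) of how those hooks could be added with the same `register` decorator, tracking the current device in a module-level variable:

```python
# Hypothetical extension (not in this commit): back the device-guard hooks
# that the C++ side's get_method() looks up by name.
_current_device_index = 0

@register
def _getDevice():
    return _current_device_index

@register
def _exchangeDevice(new_index):
    global _current_device_index
    old_index = _current_device_index
    _current_device_index = new_index
    return old_index
```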
@@ -0,0 +1,17 @@
#include <OpenReg.h>

// Make this a proper CPython module
static struct PyModuleDef openreg_C_module = {
    PyModuleDef_HEAD_INIT,
    .m_name = "pytorch_openreg._C",
};

PyMODINIT_FUNC PyInit__C(void) {
  PyObject* mod = PyModule_Create(&openreg_C_module);

  py::object openreg_mod = py::module_::import("pytorch_openreg");
  // Only borrowed from the python side!
  openreg::set_impl_registry(openreg_mod.attr("_IMPL_REGISTRY").ptr());

  return mod;
}
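
The README installs the package with `python setup.py develop`. The actual setup.py is one of the 8 changed files but is not shown here; below is a minimal, illustrative sketch of how a `_C` extension like this is typically built with torch's C++ extension helpers (file names and source paths are assumptions, not taken from the PR):

```python
# Illustrative sketch only -- the real setup.py from this PR is not shown above.
from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="pytorch_openreg",
    packages=find_packages(),
    ext_modules=[
        CppExtension(
            name="pytorch_openreg._C",
            # Hypothetical locations for the C++ files shown in this diff.
            sources=["pytorch_openreg/csrc/Module.cpp", "pytorch_openreg/csrc/OpenReg.cpp"],
            include_dirs=["pytorch_openreg/csrc"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)
```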
@@ -0,0 +1,10 @@
#pragma once
// Shared header for OpenReg module

#include <torch/csrc/utils/pybind.h>

namespace openreg {

void set_impl_registry(PyObject* registry);

}
@@ -0,0 +1,246 @@
#include <OpenReg.h>

#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>

#include <iostream>

namespace openreg {

namespace {
// Python dictionary where real implementations can be found
PyObject* py_registry;

py::function get_method(const char* name) {
  return py::cast<py::dict>(py_registry)[name];
}

// C++ hooks implementation
struct OpenRegHooksArgs : public at::PrivateUse1HooksArgs {};

struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
  OpenRegHooksInterface(OpenRegHooksArgs) {};
  ~OpenRegHooksInterface() override = default;

  bool hasPrimaryContext(c10::DeviceIndex device_index) const override {
    return get_method("hasPrimaryContext")(device_index).cast<bool>();
  }
};

TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, OpenRegHooksInterface, OpenRegHooksArgs);
C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, OpenRegHooksInterface, OpenRegHooksArgs);
// Use the Create function to get a PrivateUse1HooksInterface pointer from the PrivateUse1HooksRegistry class.
C10_REGISTER_TYPED_CLASS(PrivateUse1HooksRegistry, "OpenRegHooks", OpenRegHooksInterface);

// Device guard registration
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1;

  OpenRegGuardImpl() = default;
  explicit OpenRegGuardImpl(c10::DeviceType t) {
    TORCH_INTERNAL_ASSERT(t == static_type);
  }

  /**
   * Return the type of device managed by this guard implementation.
   */
  c10::DeviceType type() const override {
    return static_type;
  }

  /**
   * Set the current device to Device, and return the previous c10::Device.
   */
  c10::Device exchangeDevice(c10::Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_privateuseone());
    py::gil_scoped_acquire acquire;
    auto old_device_index = get_method("exchangeDevice")(d.index()).cast<c10::DeviceIndex>();
    return c10::Device(static_type, old_device_index);
  }

  /**
   * Get the current device.
   */
  c10::Device getDevice() const override {
    py::gil_scoped_acquire acquire;
    auto device = get_method("getDevice")().cast<c10::DeviceIndex>();
    return c10::Device(static_type, device);
  }

  /**
   * Set the current device to c10::Device.
   */
  void setDevice(c10::Device d) const override {
    TORCH_INTERNAL_ASSERT(d.is_privateuseone());
    py::gil_scoped_acquire acquire;
    auto device = get_method("setDevice")(d.index());
  }

  /**
   * Set the current device to c10::Device, without checking for errors
   * (so, e.g., this can be called from a destructor).
   */
  void uncheckedSetDevice(c10::Device d) const noexcept override {
    py::gil_scoped_acquire acquire;
    auto device = get_method("uncheckedSetDevice")(d.index());
  }

  /**
   * Get the current stream for a given device.
   */
  c10::Stream getStream(c10::Device d) const noexcept override {
    py::gil_scoped_acquire acquire;
    return get_method("getStream")(d.index()).cast<c10::Stream>();
  }

  /**
   * Get the default stream for a given device.
   */
  c10::Stream getDefaultStream(c10::Device d) const override {
    py::gil_scoped_acquire acquire;
    return get_method("getDefaultStream")(d.index()).cast<c10::Stream>();
  }

  /**
   * Get a stream from the global pool for a given device.
   */
  c10::Stream getStreamFromGlobalPool(c10::Device d, bool isHighPriority = false) const override {
    py::gil_scoped_acquire acquire;
    return get_method("getStreamFromGlobalPool")(d.index(), isHighPriority).cast<c10::Stream>();
  }

  /**
   * Return a new stream for a given device and priority. The stream will be
   * copied and shared around; the device backend should be able to correctly
   * handle the lifetime of the stream.
   */
  c10::Stream getNewStream(c10::Device d, int priority = 0) const override {
    py::gil_scoped_acquire acquire;
    return get_method("getNewStream")(d.index(), priority).cast<c10::Stream>();
  }

  /**
   * Set a stream to be the thread local current stream for its device.
   * Return the previous stream for that device. You are NOT required
   * to set the current device to match the device of this stream.
   */
  c10::Stream exchangeStream(c10::Stream s) const noexcept override {
    py::gil_scoped_acquire acquire;
    return get_method("exchangeStream")(s).cast<c10::Stream>();
  }

  /**
   * Destroys the given event.
   */
  void destroyEvent(void* event, const c10::DeviceIndex device_index)
      const noexcept override {
    py::gil_scoped_acquire acquire;
    get_method("destroyEvent")(event, device_index);
  }

  /**
   * Increments the event's version and enqueues a job with this version
   * in the stream's work queue. When the stream processes that job
   * it notifies all streams waiting on / blocked by that version of the
   * event to continue and marks that version as recorded.
   */
  void record(
      void** event,
      const c10::Stream& stream,
      const c10::DeviceIndex device_index,
      const c10::EventFlag flag) const override {
    py::gil_scoped_acquire acquire;
    get_method("record")(event, stream, device_index, flag);
  }

  /**
   * Does nothing if the event has not been scheduled to be recorded.
   * If the event was previously enqueued to be recorded, a command
   * to wait for the version of the event that exists at the time of this call
   * is inserted in the stream's work queue.
   * When the stream reaches this command it will stop processing
   * additional commands until that version of the event is marked as recorded.
   */
  void block(void* event, const c10::Stream& stream) const override {
    py::gil_scoped_acquire acquire;
    get_method("block")(event, stream);
  }

  /**
   * Returns true if (and only if)
   * (1) the event has never been scheduled to be recorded
   * (2) the current version is marked as recorded.
   * Returns false otherwise.
   */
  bool queryEvent(void* event) const override {
    py::gil_scoped_acquire acquire;
    return get_method("queryEvent")(event).cast<bool>();
  }

  /**
   * Get the number of devices. WARNING: This is REQUIRED to not raise
   * an exception. If there is some sort of problem, e.g., driver error,
   * you should report that there are zero available devices.
   */
  c10::DeviceIndex deviceCount() const noexcept override {
    py::gil_scoped_acquire acquire;
    return get_method("deviceCount")().cast<c10::DeviceIndex>();
  }

  /**
   * Return true if all the work previously enqueued on the stream for
   * asynchronous execution has completed running on the device.
   */
  bool queryStream(const c10::Stream& stream) const override {
    py::gil_scoped_acquire acquire;
    return get_method("queryStream")(stream).cast<bool>();
  }

  /**
   * Wait (by blocking the calling thread) until all the work previously
   * enqueued on the stream has completed running on the device.
   */
  virtual void synchronizeStream(const c10::Stream& stream) const {
    py::gil_scoped_acquire acquire;
    get_method("synchronizeStream")(stream);
  }

  /**
   * Wait (by blocking the calling thread) until all the work previously
   * recorded on the event has completed running on the device.
   */
  void synchronizeEvent(void* event) const override {
    py::gil_scoped_acquire acquire;
    get_method("synchronizeEvent")(event);
  }

  /**
   * Ensure the caching allocator (if any) is aware that the given DataPtr is
   * being used on the given stream, and that it should thus avoid recycling the
   * DataPtr until all work on that stream is done.
   */
  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const c10::Stream& stream)
      const override {
    py::gil_scoped_acquire acquire;
    get_method("recordDataPtrOnStream")(data_ptr, stream);
  }

  /**
   * Fetch the elapsed time between two recorded events.
   */
  double elapsedTime(void* event1, void* event2, const c10::DeviceIndex device_index)
      const override {
    py::gil_scoped_acquire acquire;
    return get_method("elapsedTime")(event1, event2, device_index).cast<double>();
  }
};

// Register our device guard
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl);

} // anonymous namespace

// Setter for the python dictionary with implementations
void set_impl_registry(PyObject* registry) {
  py_registry = registry;
}
} // namespace openreg
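
As the README notes, the only Python implementation registered so far is `_deviceCount`, which the autograd engine uses to decide how many device worker threads to spawn. A hedged sketch of exercising that path (assuming the package builds and imports cleanly):

```python
import torch
import pytorch_openreg  # noqa: F401

# Running any backward pass starts the autograd engine, which asks each
# registered device guard (including the PrivateUse1/"openreg" one above)
# for its deviceCount() when deciding how many worker threads to spawn.
x = torch.ones(3, requires_grad=True)
x.sum().backward()
```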