forked from ROCm/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add basic OpenReg module scaffolding with autograd (pytorch#131708)
Pull Request resolved: pytorch#131708 Approved by: https://github.com/ezyang
- Loading branch information
1 parent
df59084
commit 3d87dfc
Showing
8 changed files
with
435 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend in core. | ||
|
||
## How to use | ||
Install as standalone with `python setup.py develop` (or install) from this folder. | ||
You can run test via `python test/test_openreg.py`. | ||
|
||
## Design principles | ||
For simplicity anything that can be implemented from python is done so. | ||
A real implementation will most likely want to call these different APIs from c++ directly. | ||
|
||
The current version send everything back to python and is missing most implementations in python. The only one available is the one used by the autograd engine to check how many workers to spawn. | ||
|
||
Next step is to create the device daemon so we can actually provide and allocator and create memory, then start using features and re-route all missing methods to daemon as appropriate. |
45 changes: 45 additions & 0 deletions
45
test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
import torch | ||
|
||
|
||
# Global properties of our device | ||
NUM_DEVICES = 7 | ||
|
||
# Create our python implementation dict so that the C++ module | ||
# can access it during its initialization | ||
_IMPL_REGISTRY = {} | ||
|
||
# Load the C++ Module | ||
import pytorch_openreg._C # noqa: F401 | ||
|
||
|
||
# Define all the implementations in the registry | ||
def register(fn): | ||
_IMPL_REGISTRY[fn.__name__[1:]] = fn | ||
return fn | ||
|
||
|
||
@register | ||
def _deviceCount(): | ||
return NUM_DEVICES | ||
|
||
|
||
# Module used for our backend | ||
class _OpenRegMod: | ||
pass | ||
|
||
|
||
# Set all the appropriate state on PyTorch | ||
torch.utils.rename_privateuse1_backend("openreg") | ||
torch._register_device_module("openreg", _OpenRegMod()) | ||
|
||
_openreg_lib = torch.library.Library("_", "IMPL") # ignore TOR901 | ||
|
||
|
||
def _openreg_kernel_fallback(op, *args, **kwargs): | ||
print("Calling ", op) | ||
assert op is torch.ops.aten.empty.memory_format | ||
# FIXME: this returns a cpu Tensor which is NOT ok. | ||
return torch.empty(args[0]) | ||
|
||
|
||
_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1") |
17 changes: 17 additions & 0 deletions
17
test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/Module.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#include <OpenReg.h> | ||
|
||
// Make this a proper CPython module | ||
static struct PyModuleDef openreg_C_module = { | ||
PyModuleDef_HEAD_INIT, | ||
.m_name = "pytorch_openreg._C", | ||
}; | ||
|
||
PyMODINIT_FUNC PyInit__C(void) { | ||
PyObject* mod = PyModule_Create(&openreg_C_module); | ||
|
||
py::object openreg_mod = py::module_::import("pytorch_openreg"); | ||
// Only borrowed from the python side! | ||
openreg::set_impl_registry(openreg_mod.attr("_IMPL_REGISTRY").ptr()); | ||
|
||
return mod; | ||
} |
10 changes: 10 additions & 0 deletions
10
test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenReg.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#pragma once | ||
// Shared header for OpenReg module | ||
|
||
#include <torch/csrc/utils/pybind.h> | ||
|
||
namespace openreg { | ||
|
||
void set_impl_registry(PyObject* registry); | ||
|
||
} |
246 changes: 246 additions & 0 deletions
246
test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
#include <OpenReg.h> | ||
|
||
#include <c10/core/impl/DeviceGuardImplInterface.h> | ||
#include <ATen/detail/PrivateUse1HooksInterface.h> | ||
|
||
#include <iostream> | ||
|
||
namespace openreg { | ||
|
||
namespace { | ||
// Python dictionary where real implementations can be found | ||
PyObject* py_registry; | ||
|
||
py::function get_method(const char* name) { | ||
return py::cast<py::dict>(py_registry)[name]; | ||
} | ||
|
||
// C++ hooks implementation | ||
struct OpenRegHooksArgs : public at::PrivateUse1HooksArgs {}; | ||
|
||
struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { | ||
OpenRegHooksInterface(OpenRegHooksArgs) {}; | ||
~OpenRegHooksInterface() override = default; | ||
|
||
bool hasPrimaryContext(c10::DeviceIndex device_index) const override { | ||
return get_method("hasPrimaryContext")(device_index).cast<bool>(); | ||
} | ||
}; | ||
|
||
TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, OpenRegHooksInterface, OpenRegHooksArgs); | ||
C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, OpenRegHooksInterface, OpenRegHooksArgs); | ||
// Using Create function to get PrivateUse1HooksInterface point from PrivateUse1HooksRegistry class. | ||
C10_REGISTER_TYPED_CLASS(PrivateUse1HooksRegistry, "OpenRegHooks", OpenRegHooksInterface); | ||
|
||
// Device guard registration | ||
struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { | ||
static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1; | ||
|
||
OpenRegGuardImpl() = default; | ||
explicit OpenRegGuardImpl(c10::DeviceType t) { | ||
TORCH_INTERNAL_ASSERT(t == static_type); | ||
} | ||
|
||
/** | ||
* Return the type of device managed by this guard implementation. | ||
*/ | ||
c10::DeviceType type() const override { | ||
return static_type; | ||
} | ||
|
||
/** | ||
* Set the current device to Device, and return the previous c10::Device. | ||
*/ | ||
c10::Device exchangeDevice(c10::Device d) const override { | ||
TORCH_INTERNAL_ASSERT(d.is_privateuseone()); | ||
py::gil_scoped_acquire acquire; | ||
auto old_device_index = get_method("exchangeDevice")(d.index()).cast<c10::DeviceIndex>(); | ||
return c10::Device(static_type, old_device_index); | ||
} | ||
|
||
/** | ||
* Get the current device. | ||
*/ | ||
c10::Device getDevice() const override { | ||
py::gil_scoped_acquire acquire; | ||
auto device = get_method("getDevice")().cast<c10::DeviceIndex>(); | ||
return c10::Device(static_type, device); | ||
} | ||
|
||
/** | ||
* Set the current device to c10::Device. | ||
*/ | ||
void setDevice(c10::Device d) const override { | ||
TORCH_INTERNAL_ASSERT(d.is_privateuseone()); | ||
py::gil_scoped_acquire acquire; | ||
auto device = get_method("setDevice")(d.index()); | ||
} | ||
|
||
/** | ||
* Set the current device to c10::Device, without checking for errors | ||
* (so, e.g., this can be called from a destructor). | ||
*/ | ||
void uncheckedSetDevice(c10::Device d) const noexcept override { | ||
py::gil_scoped_acquire acquire; | ||
auto device = get_method("uncheckedSetDevice")(d.index()); | ||
} | ||
|
||
/** | ||
* Get the current stream for a given device. | ||
*/ | ||
c10::Stream getStream(c10::Device d) const noexcept override { | ||
py::gil_scoped_acquire acquire; | ||
return get_method("getStream")(d.index()).cast<c10::Stream>(); | ||
} | ||
|
||
/** | ||
* Get the default stream for a given device. | ||
*/ | ||
c10::Stream getDefaultStream(c10::Device d) const override { | ||
py::gil_scoped_acquire acquire; | ||
return get_method("getDefaultStream")(d.index()).cast<c10::Stream>(); | ||
} | ||
|
||
/** | ||
* Get a stream from the global pool for a given device. | ||
*/ | ||
c10::Stream getStreamFromGlobalPool(c10::Device d, bool isHighPriority = false) const override { | ||
py::gil_scoped_acquire acquire; | ||
return get_method("getStreamFromGlobalPool")(d.index(), isHighPriority).cast<c10::Stream>(); | ||
} | ||
|
||
/** | ||
* Return a new stream for a given device and priority. The stream will be | ||
* copied and shared around, device backend should be able to correctly handle | ||
* the lifetime of the stream. | ||
*/ | ||
c10::Stream getNewStream(c10::Device d, int priority = 0) const override { | ||
py::gil_scoped_acquire acquire; | ||
return get_method("getNewStream")(d.index(), priority).cast<c10::Stream>(); | ||
} | ||
|
||
/** | ||
* Set a stream to be the thread local current stream for its device. | ||
* Return the previous stream for that device. You are NOT required | ||
* to set the current device to match the device of this stream. | ||
*/ | ||
c10::Stream exchangeStream(c10::Stream s) const noexcept override { | ||
py::gil_scoped_acquire acquire; | ||
return get_method("exchangeStream")(s).cast<c10::Stream>(); | ||
} | ||
|
||
/** | ||
* Destroys the given event. | ||
*/ | ||
void destroyEvent(void* event, const c10::DeviceIndex device_index) | ||
const noexcept override { | ||
py::gil_scoped_acquire acquire; | ||
get_method("destroyEvent")(event, device_index); | ||
} | ||
|
||
/** | ||
* Increments the event's version and enqueues a job with this version | ||
* in the stream's work queue. When the stream process that job | ||
* it notifies all streams waiting on / blocked by that version of the | ||
* event to continue and marks that version as recorded. | ||
* */ | ||
void record( | ||
void** event, | ||
const c10::Stream& stream, | ||
const c10::DeviceIndex device_index, | ||
const c10::EventFlag flag) const override { | ||
py::gil_scoped_acquire acquire; | ||
get_method("record")(event, stream, device_index, flag); | ||
} | ||
|
||
/** | ||
* Does nothing if the event has not been scheduled to be recorded. | ||
* If the event was previously enqueued to be recorded, a command | ||
* to wait for the version of the event that exists at the time of this call | ||
* is inserted in the stream's work queue. | ||
* When the stream reaches this command it will stop processing | ||
* additional commands until that version of the event is marked as recorded. | ||
*/ | ||
void block(void* event, const c10::Stream& stream) const override { | ||
py::gil_scoped_acquire acquire; | ||
get_method("block")(event, stream); | ||
} | ||
|
||
/** | ||
* Returns true if (and only if) | ||
* (1) the event has never been scheduled to be recorded | ||
* (2) the current version is marked as recorded. | ||
* Returns false otherwise. | ||
*/ | ||
bool queryEvent(void* event) const override { | ||
py::gil_scoped_acquire acquire; | ||
return get_method("queryEvent")(event).cast<bool>(); | ||
} | ||
|
||
/** | ||
* Get the number of devices. WARNING: This is REQUIRED to not raise | ||
* an exception. If there is some sort of problem, e.g., driver error, | ||
* you should report that there are zero available devices. | ||
*/ | ||
c10::DeviceIndex deviceCount() const noexcept override { | ||
py::gil_scoped_acquire acquire; | ||
return get_method("deviceCount")().cast<c10::DeviceIndex>(); | ||
} | ||
/** | ||
* Return true if all the work previously enqueued on the stream for | ||
* asynchronous execution has completed running on the device. | ||
*/ | ||
bool queryStream(const c10::Stream& stream) const override { | ||
py::gil_scoped_acquire acquire; | ||
return get_method("queryStream")(stream).cast<bool>(); | ||
} | ||
|
||
/** | ||
* Wait (by blocking the calling thread) until all the work previously | ||
* enqueued on the stream has completed running on the device. | ||
*/ | ||
virtual void synchronizeStream(const c10::Stream& stream) const { | ||
py::gil_scoped_acquire acquire; | ||
get_method("synchronizeStream")(stream); | ||
} | ||
|
||
/** | ||
* Wait (by blocking the calling thread) until all the work previously | ||
* recorded on the event has completed running on the device. | ||
*/ | ||
void synchronizeEvent(void* event) const override { | ||
py::gil_scoped_acquire acquire; | ||
get_method("synchronizeEvent")(event); | ||
} | ||
|
||
/** | ||
* Ensure the caching allocator (if any) is aware that the given DataPtr is | ||
* being used on the given stream, and that it should thus avoid recycling the | ||
* DataPtr until all work on that stream is done. | ||
*/ | ||
void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const c10::Stream& stream) | ||
const override { | ||
py::gil_scoped_acquire acquire; | ||
get_method("recordDataPtrOnStream")(data_ptr, stream); | ||
} | ||
|
||
/** | ||
* Fetch the elapsed time between two recorded events. | ||
*/ | ||
double elapsedTime(void* event1, void* event2, const c10::DeviceIndex device_index) | ||
const override { | ||
py::gil_scoped_acquire acquire; | ||
return get_method("elapsedTime")(event1, event2, device_index).cast<double>(); | ||
} | ||
}; | ||
|
||
// Register our device guard | ||
C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); | ||
|
||
} // anonymous namspaces | ||
|
||
// Setter for the python dictionary with implementations | ||
void set_impl_registry(PyObject* registry) { | ||
py_registry = registry; | ||
} | ||
} // openreg |
Oops, something went wrong.