Developers Guide
Torch-MIGraphX is a Python library that aims to integrate MIGraphX into PyTorch workflows as seamlessly as possible. In practice this means being able to consume models that have been built and trained in PyTorch and, through our APIs, produce a compiled MIGraphX program.
The high-level workflow is: trace or export the torch model into a graph representation, translate that graph into a MIGraphX program, then compile the program and wrap it in a module that can be called like any other torch model.
Below we will use a simple example to explore this full process. The best way to follow along is to set up an environment and run the provided code.
Generally, a good starting point for working with torch_migraphx is a base docker image from rocm/pytorch or rocm/pytorch-nightly.
For development, there is also a dev.Dockerfile in the docker directory of the torch_migraphx repo for convenience. Follow the steps under Development to set up a container with torch_migraphx installed in develop mode and all other prerequisites (including MIGraphX) already in place.
From a user's perspective, there are two main APIs that convert a torch model to a Torch-MIGraphX model in a single call.
# 'model' is a torch.nn.Module object, and 'sample_inputs' is a list of input tensors with the shapes of real inputs.
# lower_to_mgx is exported by the torch_migraphx package (see the FX Examples below for the exact import).
mgx_model = lower_to_mgx(model, sample_inputs)
For full usage examples refer to: FX Examples
# Requires torch_migraphx to be installed so that the "migraphx" backend is registered with torch.compile.
mgx_model = torch.compile(model, backend="migraphx")
# Note that the compilation actually happens when the model is first executed, and the model is recompiled any time it is called with different input sizes.
result = mgx_model(*sample_inputs)
It's good to know how the library is intended to be used, but that doesn't tell us much about how it works or how to develop on it. For that, let's break down and understand each step in the above workflow using real code.
Before diving into the core torch_migraphx codebase, take some time to understand what a torch model is, how it's created and used. Below is a definition of a simple custom module that we will use to understand the full workflow. Refer to Official PyTorch Docs for a more detailed explanation on the fundamental data structures used in torch.
import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param1 = torch.nn.Parameter(torch.rand(3, 4))
        self.param2 = torch.nn.Parameter(torch.rand(3, 5))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        x = x + self.param1
        x = self.linear(x)
        x = x.mul(self.param2)
        return torch.nn.functional.relu(x)
Key things to note here are:
- Generally speaking, a "PyTorch Model" is something that is of type (or inherits from) torch.nn.Module. A model class can consist of many submodules/layers that are themselves of torch.nn.Module type. In our example, MyModule contains a torch.nn.Linear layer, which is itself derived from the parent class torch.nn.Module (the second example below also uses a torch.nn.ReLU layer).
- torch.nn.Module requires that a forward method be defined, which determines which operations are performed and in what order. In our example we have an input x to which we add a constant param1; the result is passed through a Linear layer, that result is elementwise multiplied by another constant param2, and finally the relu function is applied to produce the model output.
- This torch.nn.Module is callable, meaning that if you want to run the model in eager mode, you can simply call it with an input x, which invokes the forward method.
Create a model using this class definition and run it in eager mode using some random input. "Eager mode" in this context means that the lines in the forward method are executed in sequence, as written, when the model is called with an input x. Note that in reality, a model created like this would first have to be trained, but training is outside the scope of MIGraphX, so we will assume this is a pretrained model.
# It's good practice to set torch models to eval mode for our purposes, since some layers behave differently in eval mode vs training mode
mod = MyModule().eval()
in_x = torch.randn(3, 4)
out = mod(in_x)
The model we created in the previous section is nothing more than a Python object. There is no meaningful way a graph optimizer can consume the model in this form, so we first need to transform it into a graph format. For this we use APIs provided by the torch library. Torch provides two methods for transforming a model into a graphical representation, and the data structure used to implement this representation is torch.fx.GraphModule.
The first method for generating a graph is to use the Tracer provided by the FX Toolkit. Let's start by tracing our custom module with the base symbolic tracer provided by this toolkit.
from torch.fx import symbolic_trace
symbolic_traced = symbolic_trace(mod)
print(symbolic_traced.graph)
# symbolic_traced.graph.print_tabular() # Feel free to use this style of print if you find it easier to read
graph():
%x : [num_users=1] = placeholder[target=x]
%param1 : [num_users=1] = get_attr[target=param1]
%add : [num_users=1] = call_function[target=operator.add](args = (%x, %param1), kwargs = {})
%linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {})
%param2 : [num_users=1] = get_attr[target=param2]
%mul : [num_users=1] = call_method[target=mul](args = (%linear, %param2), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%mul,), kwargs = {inplace: False})
return relu
Take a moment to relate this graph back to the original module and convince yourself that it is indeed performing the same set of operations on input x just written in graph format. There are a few key things to understand here:
- There are 6 types of nodes here (referred to as opcodes; see the snippet after this list for inspecting them programmatically)
  - placeholder: These are model inputs
  - get_attr: These are generally constants (torch parameters are constant in eval mode) or any other model attributes, with the name defined by target
  - call_function: Calls the target function with args and kwargs
  - call_method: Calls the target method with args and kwargs
  - call_module: Calls the target torch.nn.Module (i.e. the forward method of that module) with args and kwargs
  - output: This defines the model output(s). In the above format this just shows up as a return statement, but if you use the tabular print function, you will see the opcode listed as output
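Each of these fields is also available programmatically on the node objects, which is often easier than reading the printed graph. This snippet uses only standard torch.fx node attributes:

for node in symbolic_traced.graph.nodes:
    # node.op is the opcode (placeholder, get_attr, call_function, ...),
    # node.target identifies what is called or fetched, and node.args/node.kwargs
    # hold the inputs to that node
    print(node.op, node.target, node.args, node.kwargs)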
We will see how we deal with each of these node types when translating to MIGraphX IR, but before that let's consider the model below.
class MyModule2(torch.nn.Module):
    def __init__(self, w, b, param1, param2):
        super().__init__()
        self.param1 = torch.nn.Parameter(param1)
        self.param2 = torch.nn.Parameter(param2)
        self.w = torch.nn.Parameter(w)
        self.b = torch.nn.Parameter(b)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = torch.add(x, self.param1)
        x = torch.nn.functional.linear(x, self.w, self.b)
        x = x * self.param2
        return self.relu(x)
If you look carefully, this model is functionally identical to our original module. Convince yourself further by feeding both models the same input and printing the respective outputs.
mod = MyModule().eval()
# Feed in the same constants to make sure the outputs are comparable
mod2 = MyModule2(mod.linear.weight, mod.linear.bias, mod.param1, mod.param2).eval()
in_x = torch.randn(3, 4)
print(mod(in_x))
print(mod2(in_x))
Let's look at the graph for this second model.
symbolic_traced2 = symbolic_trace(mod2)
print(symbolic_traced2.graph)
# symbolic_traced2.graph.print_tabular()
graph():
%x : [num_users=1] = placeholder[target=x]
%param1 : [num_users=1] = get_attr[target=param1]
%add : [num_users=1] = call_function[target=torch.add](args = (%x, %param1), kwargs = {})
%w : [num_users=1] = get_attr[target=w]
%b : [num_users=1] = get_attr[target=b]
%linear : [num_users=1] = call_function[target=torch._C._nn.linear](args = (%add, %w, %b), kwargs = {})
%param2 : [num_users=1] = get_attr[target=param2]
%mul : [num_users=1] = call_function[target=operator.mul](args = (%linear, %param2), kwargs = {})
%relu : [num_users=1] = call_module[target=relu](args = (%mul,), kwargs = {})
return relu
Notice some key differences:
- The target for the add node changed from operator.add to torch.add
- The linear layer is now a call_function rather than a call_module
- Similarly, the mul and relu nodes have different opcodes
Yet the model, mathematically, is exactly the same. Variations like these would cause a lot of duplication in our translation layer, so to deal with them we implement our own derived tracer that normalizes torch models into a consistent form. We call this the acc_tracer, and the relevant code can be found here. Let's trace both of these models using our derived tracer.
import torch_migraphx.fx.tracer.acc_tracer.acc_tracer as acc_tracer
acc_traced = acc_tracer.trace(mod, [in_x])
print(acc_traced.graph)
# acc_traced.graph.print_tabular()
graph():
%x : [num_users=1] = placeholder[target=x]
%param1 : [num_users=1] = get_attr[target=param1]
%add_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.add](args = (), kwargs = {input: %x, other: %param1})
%linear_weight : [num_users=1] = get_attr[target=linear.weight]
%linear_bias : [num_users=1] = get_attr[target=linear.bias]
%linear_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.linear](args = (), kwargs = {input: %add_1, weight: %linear_weight, bias: %linear_bias})
%param2 : [num_users=1] = get_attr[target=param2]
%mul_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.mul](args = (), kwargs = {input: %linear_1, other: %param2})
%relu_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %mul_1, inplace: False})
return relu_1
As an exercise, try and use this tracer on the second model and see if the graph is the same.
Key things to note here about the normalization we do:
- There are no more call_module and call_method nodes
  - call_module nodes are replaced by their corresponding functional call because we configure our tracer to trace inside of nested modules, i.e. acc_tracer traces into the torch.nn.Linear module and finds the torch.nn.functional.linear call inside its forward method.
  - call_method nodes are mapped to their corresponding functional call
- All of the call_function targets now point to functions in the acc_ops namespace
- All of the args are empty and all function parameters are passed as keyword arguments in kwargs
Take some time to examine a few of the function mappings defined in acc_ops.py.
- Mappings are defined using decorators, specifically register_acc_op_mapping. This tells our tracer that any node targeting a particular torch method or function should instead target our acc_ops function. Note that these functions are merely wrappers: we do not want to modify their core functionality, we only want our graph to use these wrappers as the call_function targets. A rough sketch of this pattern is shown after this list.
- In some cases, instead of defining a new acc_op, we want to simply replace an op by rewriting it in terms of another op (usually a more generalized version of the same op). For example, examine the transpose_mapper: instead of creating a new acc_op called transpose, we remap the node to a permute node that performs the same operation. Now we only have to worry about implementing a translator for the permute operation. Such remappings are defined using the decorator register_custom_acc_mapper_fn.
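To give a feel for what these registrations look like, here is a rough sketch of the wrapper-plus-decorator pattern. The decorator arguments and import path are illustrative; see acc_ops.py in the repo for the real definitions:

import torch
# Illustrative import path; the decorators live in the acc_tracer package
from torch_migraphx.fx.tracer.acc_tracer.acc_normalizer import (
    register_acc_op, register_acc_op_mapping)

@register_acc_op_mapping(op_and_target=("call_function", torch.add))
@register_acc_op
def add(*, input, other):
    # Thin wrapper: the torch behaviour is unchanged, we only want traced
    # graphs to target this function so converters have a single entry point
    return input + other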
At this point the torch graph is ready to be translated to MIGraphX instructions. Before we explore the translation process, we will look at the other method provided for generating functional graphs in torch.
This method is available as of PyTorch 2.0 and is under heavy active development. It is intended to support model compilation natively in torch, but we can use the machinery it provides to perform the compilation with MIGraphX instead. The main advantage of this method over FX tracing is that it can handle models with data-dependent control flow (i.e. models that need to execute different sets of ops depending on tensor values at runtime); FX tracing fails when it encounters such models. Read more about this approach in the official docs.
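As a quick illustration of the kind of model FX tracing cannot handle, consider this toy module (not part of the torch_migraphx codebase, just an example):

class ControlFlowModule(torch.nn.Module):
    def forward(self, x):
        # The branch depends on a tensor value, so symbolic_trace cannot
        # record a single static graph for this forward
        if x.sum() > 0:
            return torch.relu(x)
        return -x

# torch.fx.symbolic_trace(ControlFlowModule()) raises a tracing error here,
# while torch.compile handles the model by splitting it into subgraphs.

We will use the same model from the previous section to explore this approach. Let's start by defining a custom backend that we can use to understand how the torch.compile API works.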
import torch._dynamo as dynamo
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch._guards import TracingContext

@dynamo.register_backend(name="my_backend")
def test_backend(gm, example_inputs, **kwargs):
    TracingContext.get().fake_mode.allow_non_fake_inputs = True
    print(example_inputs)
    print(gm.graph)
    aten_gm = aot_export_joint_simple(gm, example_inputs, trace_joint=False)
    print(aten_gm.graph)
    return aten_gm
There is a lot to understand in terms of all the underlying mechanisms employed by a torch.compile call, but for our purposes we can focus on how a backend is defined. The dynamo.register_backend decorator is what tells dynamo where to look for the definition of "my_backend" in the call below.
mod = MyModule().eval()
in_x = torch.randn(3, 4)
mod_dynamo = torch.compile(mod, backend="my_backend")
mod_dynamo(in_x) # This line is when the test_backend function is invoked
[tensor([[ 1.0750, -0.4972, 0.7909, 0.1489],
[-0.6334, -0.4037, -0.3144, 1.3126],
[ 0.6391, -1.0924, 0.6623, 0.5520]])]
graph():
%l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
%l__self___param1 : [num_users=1] = get_attr[target=L__self___param1]
%add : [num_users=1] = call_function[target=operator.add](args = (%l_x_, %l__self___param1), kwargs = {})
%l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%add,), kwargs = {})
%l__self___param2 : [num_users=1] = get_attr[target=L__self___param2]
%mul : [num_users=1] = call_method[target=mul](args = (%l__self___linear, %l__self___param2), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%mul,), kwargs = {})
return (relu,)
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %_param_constant0), kwargs = {})
%_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
%t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant1,), kwargs = {})
%_param_constant2 : [num_users=1] = get_attr[target=_param_constant2]
%addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant2, %add, %t), kwargs = {})
%_param_constant3 : [num_users=1] = get_attr[target=_param_constant3]
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %_param_constant3), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mul,), kwargs = {})
return (relu,)
Let's highlight some key observations:
- example_inputs is a list of tensors (even though our model call was mod_dynamo(in_x), the tensor in_x is added to the example input list)
  - If our model had multiple inputs (e.g. mod_dynamo(in_x, in_y, in_z)), then example_inputs would be a list of 3 tensors corresponding to the 3 inputs
- The dynamo machinery passes a GraphModule object to our backend (which we reference as gm)
  - This object consists of nodes that point to internal methods and modules and is not really in a format that can be translated
- We use the aot_export_joint_simple API provided by the functorch toolkit to export a graph that is reduced to function calls in the torch.ops.aten namespace
  - This is similar to the acc_ops normalization we did in FX tracing, where all operations are written as call_function nodes with targets in a single namespace
- In this toy example, our backend just does this export, prints the graphs, and returns the exported torch GraphModule. In an actual backend implementation, this return value is expected to be a different Python callable that takes inputs identical to example_inputs in shape and datatype.
We highly recommend going through this specific section of the official torch tutorial, as it shows what happens when there is data-dependent control flow in the model and how this backend function can be invoked multiple times, once for each subgraph that the control flow can feed into.
To return something meaningful (i.e. a compiled program) from this backend, we need to actually translate these graphs into MIGraphX programs.
So far, we've merely used torch APIs to get to a point where we have graphical representations of torch models in a consistent format. Now we can dive into the actual conversion mechanism for generating MIGraphX programs from these torch graphs. At this point it is more difficult to explore with toy code and so we will examine relevant code in our real codebase.
The way we translate the torch graph is relatively straightforward, but it requires a good understanding of the torch.fx.Interpreter class. This class is designed to traverse a GraphModule node by node and perform any required "transformations". In our case these "transformations" just mean adding instructions to our migraphx program that implement functionality equivalent to that of the torch node. This will become clear as we walk through the conversion of our example model below.
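As a minimal, torch-only illustration of the Interpreter pattern, here is a subclass that logs every call_function node while still executing the graph normally. MGXInterpreter overrides the same hooks, but emits migraphx instructions instead of running the torch ops:

from torch.fx import Interpreter, symbolic_trace

class LoggingInterpreter(Interpreter):
    def call_function(self, target, args, kwargs):
        # Invoked once for every call_function node during run()
        print(f"call_function -> {target}")
        return super().call_function(target, args, kwargs)

traced = symbolic_trace(MyModule().eval())
out = LoggingInterpreter(traced).run(torch.randn(3, 4))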
Our FX Interpreter is defined in fx2mgx.py. Keep in mind that this class will be used to iterate through torch graphs node by node in order. Note some key things about what's happening in this class:
- On initialization, we create an empty migraphx program (using migraphx's python API). We will be adding instructions to this program as we traverse the nodes
- Calling the run method initiates the node traversal
- The methods placeholder, call_module, call_function, call_method, get_attr, and output define what happens when each of these node types is encountered during the traversal
The placeholder, get_attr, and output methods are straightforward (a standalone sketch of the underlying migraphx API calls follows this list):
- placeholder adds an input to the migraphx program (using the add_parameter migraphx function)
- get_attr adds literals to the migraphx program (using the add_literal migraphx function)
- output adds outputs to the migraphx program (using the add_return migraphx function)
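To make those calls concrete, here is a small standalone sketch that uses the migraphx Python API directly to build a tiny program. Shapes and values are arbitrary, and the exact signatures should be checked against the MIGraphX Python API docs for your installed version:

import numpy as np
import migraphx

prog = migraphx.program()
mm = prog.get_main_module()

# placeholder -> add_parameter: declare a named graph input
x = mm.add_parameter("x", migraphx.shape(type="float_type", lens=[3, 4]))

# get_attr -> add_literal: bake a constant (e.g. a torch parameter) into the program
const = mm.add_literal(np.random.rand(3, 4).astype(np.float32))

# an elementwise instruction, added the same way converters do
added = mm.add_instruction(migraphx.op("add"), [x, const])

# output -> add_return: mark the program outputs
mm.add_return([added])
print(prog)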
The call_module, call_function, and call_method methods perform math/tensor operations that need to be translated to operations defined in MIGraphX. To keep track of the converters we have implemented, we maintain a CONVERTERS dictionary (a simplified sketch of the dispatch follows this list):
- The keys of this dictionary are functions (specifically the acc_ops and aten functions that we saw as targets of call_function nodes in the previous section).
- The values of this dictionary are also functions. These are the converter functions that define how an acc_ops or aten function should be translated to migraphx instructions.
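Conceptually, the call_function handler in the interpreter does something like the following. This is simplified pseudocode of the dispatch only; the real implementation and the exact converter signature are in fx2mgx.py:

def call_function(self, target, args, kwargs):
    # Look up the converter registered for this acc_ops/aten function
    converter = CONVERTERS.get(target)
    if converter is None:
        raise RuntimeError(f"No converter registered for {target}")
    # The converter appends the equivalent migraphx instruction(s) to the
    # program being built and returns a reference to the resulting instruction
    return converter(self.program, target, args, kwargs)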
Converter functions are defined in acc_ops_converters.py and aten_ops_converters.py. Note that there are some cases where functions/modules are not normalized to acc_ops and so there are some additional converters defined in the converters directory. For now we will look at the FX traced graph and focus on acc_ops_converters.py. The CONVERTERS dictionary is populated by the decorators that are applied to each of the converter functions in these files.
Here we will pass the acc_traced module from the previous section through the interpreter and understand the migraphx program that is generated.
mod = MyModule().eval()
in_x = torch.randn(3, 4)
import torch_migraphx.fx.tracer.acc_tracer.acc_tracer as acc_tracer
acc_traced = acc_tracer.trace(mod, [in_x])
print(acc_traced.graph)
# acc_traced.graph.print_tabular()
from torch_migraphx.fx.fx2mgx import MGXInterpreter
interp = MGXInterpreter(acc_traced, [in_x])
interp.run()
print(interp.program)
graph():
- %x : [num_users=1] = placeholder[target=x]
%param1 : [num_users=1] = get_attr[target=param1]
! %add_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.add](args = (), kwargs = {input: %x, other: %param1})
%linear_weight : [num_users=1] = get_attr[target=linear.weight]
%linear_bias : [num_users=1] = get_attr[target=linear.bias]
+ %linear_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.linear](args = (), kwargs = {input: %add_1, weight: %linear_weight, bias: %linear_bias})
%param2 : [num_users=1] = get_attr[target=param2]
# %mul_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.mul](args = (), kwargs = {input: %linear_1, other: %param2})
@@ %relu_1 : [num_users=1] = call_function[target=torch_migraphx.fx.tracer.acc_tracer.acc_ops.relu](args = (), kwargs = {input: %mul_1, inplace: False}) @@
return relu_1
module: "main"
@0 = @literal{ ... } -> float_type, {3, 5}, {5, 1}
@1 = @literal{-0.0148876, -0.224566, 0.488997, -0.449717, 0.486762} -> float_type, {5}, {1}
@2 = @literal{ ... } -> float_type, {5, 4}, {4, 1}
@3 = @literal{ ... } -> float_type, {3, 4}, {4, 1}
- x = @param:x -> float_type, {3, 4}, {4, 1}
! @5 = add(x,@3) -> float_type, {3, 4}, {4, 1}
+ @6 = transpose[permutation={1, 0}](@2) -> float_type, {4, 5}, {1, 4}
+ @7 = multibroadcast[out_lens={4, 5},out_dyn_dims={}](@6) -> float_type, {4, 5}, {1, 4}
+ @8 = dot(@5,@7) -> float_type, {3, 5}, {5, 1}
+ @9 = multibroadcast[out_lens={3, 5},out_dyn_dims={}](@1) -> float_type, {3, 5}, {0, 1}
+ @10 = add(@8,@9) -> float_type, {3, 5}, {5, 1}
# @11 = mul(@10,@0) -> float_type, {3, 5}, {5, 1}
@@ @12 = relu(@11) -> float_type, {3, 5}, {5, 1} @@
@13 = @return(@12)
Take some time to understand which converters are called and how each of the instructions is added in the generated migraphx program. The output above is manually color coded so make sure you fully understand how each torch node is translated to an migraphx instruction, or a set of migraphx instructions. Understanding this is key to implementing converters and contributing to this codebase.
Studying the implemented converters, you can note a few key things:
- Converters can be trivial, where there is simply a one-to-one mapping to a migraphx op (e.g. relu)
- Some are one-to-one mappings but need to account for the fact that torch allows implicit broadcasting (e.g. add, mul); a sketch of this is shown after this list
  - In our particular example none of the operands for add and mul actually need broadcasting, but in general torch allows it, so it must be handled
- Some converters require a series of migraphx ops to implement the equivalent functionality (e.g. linear)
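Here is roughly what such an elementwise converter looks like. The decorator name, converter signature, and the compute_broadcast_shape helper are all illustrative (see acc_ops_converters.py for the real ones); the point is the shape of the logic: broadcast both operands to a common shape, then add a single migraphx instruction:

import migraphx

@migraphx_converter(acc_ops.add)  # illustrative registration decorator
def acc_ops_add(mgx_module, node, args, kwargs):
    inp, other = kwargs["input"], kwargs["other"]
    # torch allows implicit broadcasting but migraphx does not, so insert
    # explicit multibroadcast instructions when the operand shapes differ
    out_lens = compute_broadcast_shape(inp, other)  # hypothetical helper
    inp = mgx_module.add_instruction(
        migraphx.op("multibroadcast", out_lens=out_lens), [inp])
    other = mgx_module.add_instruction(
        migraphx.op("multibroadcast", out_lens=out_lens), [other])
    return mgx_module.add_instruction(migraphx.op("add"), [inp, other])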
Here is a modified backend definition that adds the interpreter into the mix.
import torch._dynamo as dynamo
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch._guards import TracingContext
from torch_migraphx.fx.fx2mgx import MGXInterpreter

@dynamo.register_backend(name="my_backend")
def test_backend(gm, example_inputs, **kwargs):
    TracingContext.get().fake_mode.allow_non_fake_inputs = True
    aten_gm = aot_export_joint_simple(gm, example_inputs, trace_joint=False)
    print(aten_gm.graph)
    interp = MGXInterpreter(aten_gm, example_inputs)
    interp.run()
    print(interp.program)
    return aten_gm
mod_dynamo = torch.compile(mod, backend="my_backend")
mod_dynamo(in_x)
graph():
%arg0_1 : [num_users=1] = placeholder[target=arg0_1]
%_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %_param_constant0), kwargs = {})
%_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
%t : [num_users=1] = call_function[target=torch.ops.aten.t.default](args = (%_param_constant1,), kwargs = {})
%_param_constant2 : [num_users=1] = get_attr[target=_param_constant2]
%addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant2, %add, %t), kwargs = {})
%_param_constant3 : [num_users=1] = get_attr[target=_param_constant3]
%mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%addmm, %_param_constant3), kwargs = {})
%relu : [num_users=1] = call_function[target=torch.ops.aten.relu.default](args = (%mul,), kwargs = {})
return (relu,)
module: "main"
@0 = @literal{ ... } -> float_type, {3, 5}, {5, 1}
@1 = @literal{-0.298655, -0.145471, -0.311266, -0.279463, -0.389797} -> float_type, {5}, {1}
@2 = @literal{ ... } -> float_type, {5, 4}, {4, 1}
@3 = @literal{ ... } -> float_type, {3, 4}, {4, 1}
arg0_1 = @param:arg0_1 -> float_type, {3, 4}, {4, 1}
@5 = add(arg0_1,@3) -> float_type, {3, 4}, {4, 1}
@6 = transpose[permutation={1, 0}](@2) -> float_type, {4, 5}, {1, 4}
@7 = multibroadcast[out_lens={3, 4},out_dyn_dims={}](@5) -> float_type, {3, 4}, {4, 1}
@8 = multibroadcast[out_lens={4, 5},out_dyn_dims={}](@6) -> float_type, {4, 5}, {1, 4}
@9 = dot(@7,@8) -> float_type, {3, 5}, {5, 1}
@10 = multibroadcast[out_lens={3, 5},out_dyn_dims={}](@1) -> float_type, {3, 5}, {0, 1}
@11 = multibroadcast[out_lens={3, 5},out_dyn_dims={}](@9) -> float_type, {3, 5}, {5, 1}
@12 = add(@10,@11) -> float_type, {3, 5}, {5, 1}
@13 = mul(@12,@0) -> float_type, {3, 5}, {5, 1}
@14 = relu(@13) -> float_type, {3, 5}, {5, 1}
@15 = @return(@14)
Here, the nodes have not been color coded so it is a very good exercise to go through the torch graph and identify which migraphx instructions correspond to each node by examining the associated converters. Some notable things to highlight about aten converters are:
- They ALWAYS point to an acc_ops converter (a sketch of this delegation pattern follows this list)
  - This is because in PyTorch, aten ops are the low-level implementations of the high-level torch API functions. For example, the high-level function torch.add will actually call the aten.add function under the hood. This means that if we have support for a high-level op (i.e. a converter exists for it), we can just point these low-level ops to the existing converter with the right arguments.
- In general, aten ops are implemented in C++ in torch and don't support keyword arguments in the same way as the high-level ops, so in aten converters we have to rely on args, and the order of the input arguments is important
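The delegation pattern looks roughly like this; the names and signatures are illustrative (see aten_ops_converters.py for the real ones). The aten converter just repackages the positional aten arguments into the keyword arguments that the existing acc_ops converter expects:

@migraphx_converter(torch.ops.aten.add.Tensor)  # illustrative registration decorator
def aten_ops_add(mgx_module, node, args, kwargs):
    # aten ops pass their operands positionally: args = (input, other)
    acc_kwargs = {"input": args[0], "other": args[1]}
    # Reuse the existing acc_ops converter rather than re-implementing it
    return acc_ops_add(mgx_module, node, (), acc_kwargs)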
We're almost at the finish line. We have a migraphx program; now all that's left is to let MIGraphX do its magic on this parsed program and generate a compiled version of it that we can execute. Before diving into the implementation, let's complete the workflow for MyModule and see how to execute our custom model using migraphx.
mod = MyModule().eval()
in_x = torch.randn(3, 4)
out = mod(in_x)
import torch_migraphx.fx.tracer.acc_tracer.acc_tracer as acc_tracer
acc_traced = acc_tracer.trace(mod, [in_x])
# print(acc_traced.graph)
# acc_traced.graph.print_tabular()
from torch_migraphx.fx.fx2mgx import MGXInterpreter
interp = MGXInterpreter(acc_traced, [in_x])
interp.run()
# print(interp.program)
from torch_migraphx.fx.mgx_module import MGXModule
mgx_mod = MGXModule(interp.program, interp.get_input_names())
mgx_out = mgx_mod(in_x)
print(out)
print(mgx_out)
Run this and verify that the two outputs are the same. At this point, also circle back to the usage section and examine the two entrypoints listed there. The code block above is a simplified version of what happens when those entrypoints (lower_to_mgx and torch.compile) are used. There are a number of details that we ignore in this walkthrough for the sake of simplicity, but you now understand the core components of the pipeline invoked by those calls.
Let's examine the implementation of this final piece of the workflow. The MGXModule class is implemented in mgx_module.py. Here are the core features of this class that you should look for in the code:
- MGXModule inherits from torch.nn.Module, which allows objects of this class to be executed in the same manner as normal torch models (notice how we invoke mgx_mod in the same way as mod)
- When an MGXModule object is initialized with a program (from the interpreter):
  - The program is compiled using the program.compile call provided by the migraphx Python API
  - Output buffers are also allocated so that they can be passed as parameters when we run the program (this allows us to keep the output tensor on the GPU)
- The code for actually running the model resides in the forward method, as for all torch.nn.Module objects. Important details to note:
  - The run_async call is used to avoid unnecessary synchronizations
  - The stream used for this async call is the default PyTorch stream. This is an important detail that prevents race conditions when inputs to the MGXModule are outputs from other torch models, or vice versa. E.g. if a user runs a workflow with a series of models mod1 → mod2 → mod3, where mod1 and mod3 are regular torch models but mod2 is an MGXModule, then an async call on a different stream might not wait for the outputs from mod1, and similarly mod3 would not wait for migraphx to finish writing to the output buffers defined in mod2.
- There are some additional functions implemented to allow MGXModule objects to be saved in the same manner as normal torch models (a usage sketch follows this list)
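For example, once compiled, an MGXModule can be round-tripped much like a regular torch model. This is just a usage sketch; consult mgx_module.py for the exact serialization mechanisms that are supported:

# Save the compiled module like any other torch.nn.Module
torch.save(mgx_mod, "mgx_model.pt")

# Load it back and run it without re-tracing or re-interpreting the model
reloaded = torch.load("mgx_model.pt")
print(reloaded(in_x))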