forked from pytorch/extension-cpp
Commit c081299 (0 parents)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 17 changed files with 721 additions and 0 deletions.
README.md
@@ -0,0 +1,3 @@
# pytorch-cpp-extension

An example of writing a C++ extension for PyTorch.
benchmark.py
@@ -0,0 +1,72 @@
from __future__ import division
from __future__ import print_function

import argparse
import math
import time

import torch

TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000}

parser = argparse.ArgumentParser()
parser.add_argument('example', choices=['py', 'cpp', 'cuda'])
parser.add_argument('-b', '--batch-size', type=int, default=16)
parser.add_argument('-f', '--features', type=int, default=32)
parser.add_argument('-s', '--state-size', type=int, default=128)
parser.add_argument('-r', '--runs', type=int, default=100)
parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
parser.add_argument('-c', '--cuda', action='store_true')
options = parser.parse_args()

if options.example == 'py':
    from python.lltm import LLTM
elif options.example == 'cpp':
    from cpp.lltm import LLTM
else:
    from cuda.lltm import LLTM
    options.cuda = True

X = torch.randn(options.batch_size, options.features)
h = torch.randn(options.batch_size, options.state_size)
C = torch.randn(options.batch_size, options.state_size)
rnn = LLTM(options.features, options.state_size)

if options.cuda:
    X = X.cuda()
    h = h.cuda()
    C = C.cuda()
    rnn.cuda()

# Force CUDA initialization
new_h, new_C = rnn(X, (h, C))
(new_h.sum() + new_C.sum()).backward()

forward_min = math.inf
forward_time = 0
backward_min = math.inf
backward_time = 0
for _ in range(options.runs):
    rnn.zero_grad()

    start = time.time()
    new_h, new_C = rnn(X, (h, C))
    elapsed = time.time() - start
    forward_min = min(forward_min, elapsed)
    forward_time += elapsed

    start = time.time()
    (new_h.sum() + new_C.sum()).backward()
    elapsed = time.time() - start
    backward_min = min(backward_min, elapsed)
    backward_time += elapsed

scale = TIME_SCALES[options.scale]
forward_min *= scale
backward_min *= scale
forward_average = forward_time / options.runs * scale
backward_average = backward_time / options.runs * scale

print('Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}'.format(
    forward_min, forward_average, backward_min, backward_average,
    options.scale))
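One caveat about the timing loop above: CUDA kernels launch asynchronously, so wrapping a forward call in time.time() alone can measure little more than the launch overhead. A minimal sketch of a synchronized timer, under the same setup as benchmark.py (the timed helper is illustrative, not part of this commit):

import time

import torch


def timed(fn, use_cuda=False):
    # Drain pending kernels before and after timing, so the interval
    # covers the actual GPU work rather than just the kernel launch.
    if use_cuda:
        torch.cuda.synchronize()
    start = time.time()
    result = fn()
    if use_cuda:
        torch.cuda.synchronize()
    return result, time.time() - start

Inside the loop this would read, for example, (new_h, new_C), elapsed = timed(lambda: rnn(X, (h, C)), options.cuda).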
cpp/__init__.py
Empty file.
cpp/jit.py
@@ -0,0 +1,3 @@
from torch.utils.cpp_extension import load
lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"], verbose=True)
help(lltm_cpp)
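load() compiles lltm.cpp on the fly and returns a plain Python module exposing the pybind11 bindings, so the raw ops can be called directly. A hedged usage sketch with made-up sizes, as if appended to cpp/jit.py; the shapes follow lltm_forward in cpp/lltm.cpp, where weights is (3 * state, features + state) because the gates act on the concatenation of old_h and input:

import torch

batch, features, state = 4, 8, 16
X = torch.randn(batch, features)
h = torch.randn(batch, state)
C = torch.randn(batch, state)
W = torch.randn(3 * state, features + state)  # one weight block per gate
b = torch.randn(3 * state)

# forward returns [new_h, new_cell, ...saved intermediates...]
new_h, new_C = lltm_cpp.forward(X, W, b, h, C)[:2]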
cpp/lltm.cpp
@@ -0,0 +1,90 @@
#include <torch/torch.h>

#include <vector>

// s'(z) = (1 - s(z)) * s(z)
at::Tensor d_sigmoid(at::Tensor z) {
  auto s = at::sigmoid(z);
  return (1 - s) * s;
}

// tanh'(z) = 1 - tanh^2(z)
at::Tensor d_tanh(at::Tensor z) {
  return 1 - z.tanh().pow(2);
}

// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}

std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  auto X = at::cat({old_h, input}, /*dim=*/1);

  auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
  auto gates = gate_weights.chunk(3, /*dim=*/1);

  auto input_gate = at::sigmoid(gates[0]);
  auto output_gate = at::sigmoid(gates[1]);
  auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);

  auto new_cell = old_cell + candidate_cell * input_gate;
  auto new_h = at::tanh(new_cell) * output_gate;

  return {new_h,
          new_cell,
          input_gate,
          output_gate,
          candidate_cell,
          X,
          gate_weights};
}

std::vector<at::Tensor> lltm_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights) {
  auto d_output_gate = at::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;

  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);

  auto d_gates =
      at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

  auto d_X = d_gates.mm(weights);
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward");
  m.def("backward", &lltm_backward, "LLTM backward");
}
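A hand-written backward like the one above is easy to get subtly wrong, so it pays to compare it against numerical gradients with torch.autograd.gradcheck in double precision. A sketch, assuming the wrapper in cpp/lltm.py is importable and lltm_cpp is built (upstream extension-cpp ships a similar grad_check.py; the sizes here are arbitrary):

import torch
from cpp.lltm import LLTMFunction

batch, features, state = 3, 5, 7
kwargs = {'dtype': torch.float64, 'requires_grad': True}
inputs = [
    torch.randn(batch, features, **kwargs),              # input
    torch.randn(3 * state, features + state, **kwargs),  # weights
    torch.randn(3 * state, **kwargs),                    # bias
    torch.randn(batch, state, **kwargs),                 # old_h
    torch.randn(batch, state, **kwargs),                 # old_cell
]
# Compares the C++ backward against finite-difference gradients.
print(torch.autograd.gradcheck(LLTMFunction.apply, inputs))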
cpp/lltm.py
@@ -0,0 +1,44 @@
import math
from torch import nn
from torch.autograd import Function
import torch

import lltm_cpp

torch.manual_seed(42)


class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm_cpp.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cpp.backward(
            grad_h, grad_cell, *ctx.saved_tensors)
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.Tensor(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
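Because LLTM is an ordinary nn.Module, it drops into standard PyTorch code. A usage sketch mirroring what benchmark.py does, assuming lltm_cpp has been built and is importable:

import torch
from cpp.lltm import LLTM

batch, features, state = 16, 32, 128
X = torch.randn(batch, features)
h = torch.randn(batch, state)
C = torch.randn(batch, state)

rnn = LLTM(features, state)
new_h, new_C = rnn(X, (h, C))           # LLTMFunction.apply under the hood
(new_h.sum() + new_C.sum()).backward()  # gradients via lltm_cpp.backward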
cpp/setup.py
@@ -0,0 +1,11 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='lltm_cpp',
    ext_modules=[
        CppExtension('lltm_cpp', ['lltm.cpp']),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
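This setup script is the ahead-of-time counterpart to cpp/jit.py: running python setup.py install from cpp/ compiles lltm.cpp once and installs lltm_cpp as a regular importable module. A quick post-install smoke test (illustrative, not part of this commit):

import lltm_cpp

# lltm.cpp binds exactly these two functions via PYBIND11_MODULE.
assert hasattr(lltm_cpp, 'forward')
assert hasattr(lltm_cpp, 'backward')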
cuda/__init__.py
Empty file.
cuda/jit.py
@@ -0,0 +1,4 @@
from torch.utils.cpp_extension import load
lltm_cuda = load(
    'lltm_cuda', ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'], verbose=True)
help(lltm_cuda)
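The only difference from cpp/jit.py is that nvcc is invoked for the .cu file and every tensor must already live on the GPU; the CHECK_CUDA guards in lltm_cuda.cpp below enforce that. A minimal sketch with made-up sizes, as if appended to cuda/jit.py:

import torch

assert torch.cuda.is_available()
batch, features, state = 4, 8, 16
X = torch.randn(batch, features, device='cuda')
h = torch.randn(batch, state, device='cuda')
C = torch.randn(batch, state, device='cuda')
W = torch.randn(3 * state, features + state, device='cuda')
b = torch.randn(3 * state, device='cuda')

new_h, new_C = lltm_cuda.forward(X, W, b, h, C)[:2]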
cuda/lltm.py
@@ -0,0 +1,44 @@
import math
from torch import nn
from torch.autograd import Function
import torch

import lltm_cuda

torch.manual_seed(42)


class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm_cuda.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cuda.backward(
            grad_h, grad_cell, *ctx.saved_tensors)
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.Tensor(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
cuda/lltm_cuda.cpp
@@ -0,0 +1,78 @@
#include <torch/torch.h>

#include <vector>

// CUDA forward declarations

std::vector<at::Tensor> lltm_cuda_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell);

std::vector<at::Tensor> lltm_cuda_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")

std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  CHECK_CUDA(input);
  CHECK_CUDA(weights);
  CHECK_CUDA(bias);
  CHECK_CUDA(old_h);
  CHECK_CUDA(old_cell);

  return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
}

std::vector<at::Tensor> lltm_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights) {
  CHECK_CUDA(grad_h);
  CHECK_CUDA(grad_cell);
  CHECK_CUDA(new_cell);
  CHECK_CUDA(input_gate);
  CHECK_CUDA(output_gate);
  CHECK_CUDA(candidate_cell);
  CHECK_CUDA(X);
  CHECK_CUDA(gate_weights);
  CHECK_CUDA(weights);

  return lltm_cuda_backward(
      grad_h,
      grad_cell,
      new_cell,
      input_gate,
      output_gate,
      candidate_cell,
      X,
      gate_weights,
      weights);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward (CUDA)");
  m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
}
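The kernel file lltm_cuda_kernel.cu that these declarations refer to is part of the same commit but is not shown above. One practical sanity check is to run the C++ and CUDA extensions with identical parameters and compare outputs; a hedged sketch, assuming both extensions are built (the tolerance is illustrative):

import torch
from cpp.lltm import LLTM as LLTMCpp
from cuda.lltm import LLTM as LLTMCuda

features, state = 8, 16
cpu_rnn = LLTMCpp(features, state)
gpu_rnn = LLTMCuda(features, state).cuda()
# Copy the CPU parameters so both modules compute the same function.
gpu_rnn.load_state_dict(
    {k: v.cuda() for k, v in cpu_rnn.state_dict().items()})

X = torch.randn(4, features)
h = torch.randn(4, state)
C = torch.randn(4, state)
new_h_cpu, _ = cpu_rnn(X, (h, C))
new_h_gpu, _ = gpu_rnn(X.cuda(), (h.cuda(), C.cuda()))
print((new_h_cpu - new_h_gpu.cpu()).abs().max())  # expect roughly 1e-6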