Initial commit
goldsborough committed Mar 4, 2018
0 parents commit c081299
Showing 17 changed files with 721 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -0,0 +1,3 @@
# pytorch-cpp-extension

An example of writing a C++ extension for PyTorch.
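
The commit contains a plain C++ version (cpp/) and a CUDA version (cuda/) of the LLTM cell, each wrapped in a Python autograd Function (lltm.py) and loadable just-in-time (jit.py); the C++ version also ships a setuptools build (cpp/setup.py), and benchmark.py times the implementations against each other. A minimal end-to-end usage sketch, mirroring benchmark.py (not part of this commit; it assumes the lltm_cpp extension has already been built, e.g. via cpp/setup.py):

# Illustrative sketch: using the C++-backed LLTM module, as benchmark.py does.
import torch
from cpp.lltm import LLTM

batch_size, features, state_size = 16, 32, 128
X = torch.randn(batch_size, features)
h = torch.randn(batch_size, state_size)
C = torch.randn(batch_size, state_size)

rnn = LLTM(features, state_size)
new_h, new_C = rnn(X, (h, C))              # forward runs in C++
(new_h.sum() + new_C.sum()).backward()     # backward runs in C++ as well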
72 changes: 72 additions & 0 deletions benchmark.py
@@ -0,0 +1,72 @@
from __future__ import division
from __future__ import print_function

import argparse
import math
import time

import torch

TIME_SCALES = {'s': 1, 'ms': 1000, 'us': 1000000}

parser = argparse.ArgumentParser()
parser.add_argument('example', choices=['py', 'cpp', 'cuda'])
parser.add_argument('-b', '--batch-size', type=int, default=16)
parser.add_argument('-f', '--features', type=int, default=32)
parser.add_argument('-s', '--state-size', type=int, default=128)
parser.add_argument('-r', '--runs', type=int, default=100)
parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
parser.add_argument('-c', '--cuda', action='store_true')
options = parser.parse_args()

if options.example == 'py':
    from python.lltm import LLTM
elif options.example == 'cpp':
    from cpp.lltm import LLTM
else:
    from cuda.lltm import LLTM
    options.cuda = True

X = torch.randn(options.batch_size, options.features)
h = torch.randn(options.batch_size, options.state_size)
C = torch.randn(options.batch_size, options.state_size)
rnn = LLTM(options.features, options.state_size)

if options.cuda:
    X = X.cuda()
    h = h.cuda()
    C = C.cuda()
    rnn.cuda()

# Warm-up run; also forces CUDA initialization when running with --cuda
new_h, new_C = rnn(X, (h, C))
(new_h.sum() + new_C.sum()).backward()

forward_min = math.inf
forward_time = 0
backward_min = math.inf
backward_time = 0
for _ in range(options.runs):
    rnn.zero_grad()

    start = time.time()
    new_h, new_C = rnn(X, (h, C))
    elapsed = time.time() - start
    forward_min = min(forward_min, elapsed)
    forward_time += elapsed

    start = time.time()
    (new_h.sum() + new_C.sum()).backward()
    elapsed = time.time() - start
    backward_min = min(backward_min, elapsed)
    backward_time += elapsed

scale = TIME_SCALES[options.scale]
forward_min *= scale
backward_min *= scale
forward_average = forward_time / options.runs * scale
backward_average = backward_time / options.runs * scale

print('Forward: {0:.3f}/{1:.3f} {4} | Backward {2:.3f}/{3:.3f} {4}'.format(
    forward_min, forward_average, backward_min, backward_average,
    options.scale))
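
For reference, the script is invoked as `python benchmark.py {py,cpp,cuda}` with optional `-b/--batch-size`, `-f/--features`, `-s/--state-size`, `-r/--runs`, `--scale {s,ms,us}`, and `-c/--cuda`; choosing `cuda` forces CUDA tensors. The printed numbers are the minimum and average per-run times for the forward and backward passes, in the chosen time scale.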
Empty file added cpp/__init__.py
Empty file.
3 changes: 3 additions & 0 deletions cpp/jit.py
@@ -0,0 +1,3 @@
from torch.utils.cpp_extension import load
lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"], verbose=True)
help(lltm_cpp)
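
load() compiles lltm.cpp on the fly and returns a module exposing the two functions bound at the bottom of lltm.cpp, so they can also be called directly with plain tensors, bypassing the autograd wrapper in cpp/lltm.py. A small sketch (not part of this commit; shapes are illustrative, and the relative sources path assumes the script runs from inside cpp/):

import torch
from torch.utils.cpp_extension import load

lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"], verbose=True)

batch, features, state = 4, 8, 16
X = torch.randn(batch, features)
W = torch.randn(3 * state, features + state)
b = torch.randn(3 * state)
h = torch.randn(batch, state)
C = torch.randn(batch, state)

# forward returns [new_h, new_cell, input_gate, output_gate, candidate_cell, X, gate_weights]
outputs = lltm_cpp.forward(X, W, b, h, C)
new_h, new_cell = outputs[:2]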
90 changes: 90 additions & 0 deletions cpp/lltm.cpp
@@ -0,0 +1,90 @@
#include <torch/torch.h>

#include <vector>

// s'(z) = (1 - s(z)) * s(z)
at::Tensor d_sigmoid(at::Tensor z) {
  auto s = at::sigmoid(z);
  return (1 - s) * s;
}

// tanh'(z) = 1 - tanh^2(z)
at::Tensor d_tanh(at::Tensor z) {
  return 1 - z.tanh().pow(2);
}

// elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
  auto e = z.exp();
  auto mask = (alpha * (e - 1)) < 0;
  return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
}

std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  auto X = at::cat({old_h, input}, /*dim=*/1);

  auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
  auto gates = gate_weights.chunk(3, /*dim=*/1);

  auto input_gate = at::sigmoid(gates[0]);
  auto output_gate = at::sigmoid(gates[1]);
  auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);

  auto new_cell = old_cell + candidate_cell * input_gate;
  auto new_h = at::tanh(new_cell) * output_gate;

  return {new_h,
          new_cell,
          input_gate,
          output_gate,
          candidate_cell,
          X,
          gate_weights};
}

std::vector<at::Tensor> lltm_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights) {
  auto d_output_gate = at::tanh(new_cell) * grad_h;
  auto d_tanh_new_cell = output_gate * grad_h;
  auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;

  auto d_old_cell = d_new_cell;
  auto d_candidate_cell = input_gate * d_new_cell;
  auto d_input_gate = candidate_cell * d_new_cell;

  auto gates = gate_weights.chunk(3, /*dim=*/1);
  d_input_gate *= d_sigmoid(gates[0]);
  d_output_gate *= d_sigmoid(gates[1]);
  d_candidate_cell *= d_elu(gates[2]);

  auto d_gates =
      at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);

  auto d_weights = d_gates.t().mm(X);
  auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);

  auto d_X = d_gates.mm(weights);
  const auto state_size = grad_h.size(1);
  auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
  auto d_input = d_X.slice(/*dim=*/1, state_size);

  return {d_old_h, d_input, d_weights, d_bias, d_old_cell};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward");
  m.def("backward", &lltm_backward, "LLTM backward");
}
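
For orientation, the forward pass above concatenates the previous hidden state with the input, multiplies by the weight matrix, splits the result into three gates (sigmoid, sigmoid, ELU), and forms the new cell and hidden states from them. A line-by-line pure-PyTorch reference of the same computation (a sketch, not code from this commit):

import torch
import torch.nn.functional as F

def lltm_forward_reference(input, weights, bias, old_h, old_cell):
    # X = [old_h, input]; gate_weights = X @ W^T + b, split into three gates
    X = torch.cat([old_h, input], dim=1)
    gate_weights = torch.addmm(bias, X, weights.transpose(0, 1))
    gates = gate_weights.chunk(3, dim=1)

    input_gate = torch.sigmoid(gates[0])
    output_gate = torch.sigmoid(gates[1])
    candidate_cell = F.elu(gates[2], alpha=1.0)

    new_cell = old_cell + candidate_cell * input_gate   # additive cell update
    new_h = torch.tanh(new_cell) * output_gate
    return new_h, new_cell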
44 changes: 44 additions & 0 deletions cpp/lltm.py
@@ -0,0 +1,44 @@
import math
from torch import nn
from torch.autograd import Function
import torch

import lltm_cpp

torch.manual_seed(42)


class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm_cpp.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cpp.backward(
            grad_h, grad_cell, *ctx.saved_variables)
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.Tensor(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
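
Note that the C++ backward returns gradients in the order (d_old_h, d_input, d_weights, d_bias, d_old_cell), while Function.backward must return them in the order of forward's inputs, hence the reordering in the return statement above. A custom backward like this is typically verified with torch.autograd.gradcheck; an illustrative sketch (not part of this commit; it assumes the built lltm_cpp extension is importable and a PyTorch version where torch.randn accepts dtype= and requires_grad=):

import torch
from torch.autograd import gradcheck
from cpp.lltm import LLTMFunction

batch, features, state = 3, 5, 7
kwargs = {'dtype': torch.double, 'requires_grad': True}   # gradcheck wants double precision
inputs = (
    torch.randn(batch, features, **kwargs),               # input
    torch.randn(3 * state, features + state, **kwargs),   # weights
    torch.randn(3 * state, **kwargs),                      # bias
    torch.randn(batch, state, **kwargs),                   # old_h
    torch.randn(batch, state, **kwargs),                   # old_cell
)
print(gradcheck(LLTMFunction.apply, inputs))               # True if gradients match numerically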
11 changes: 11 additions & 0 deletions cpp/setup.py
@@ -0,0 +1,11 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name='lltm_cpp',
    ext_modules=[
        CppExtension('lltm_cpp', ['lltm.cpp']),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
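
With this script, the extension is built ahead of time in the usual setuptools way (for example, running `python setup.py install` inside cpp/). CppExtension and BuildExtension supply the PyTorch include paths and compiler flags, and the result is an importable lltm_cpp module, which is exactly what cpp/lltm.py imports. A quick check after installation (illustrative):

import lltm_cpp
help(lltm_cpp)   # lists the bound forward/backward functions, as cpp/jit.py does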
Empty file added cuda/__init__.py
Empty file.
4 changes: 4 additions & 0 deletions cuda/jit.py
@@ -0,0 +1,4 @@
from torch.utils.cpp_extension import load
lltm_cuda = load(
    'lltm_cuda', ['lltm_cuda.cpp', 'lltm_cuda_kernel.cu'], verbose=True)
help(lltm_cuda)
44 changes: 44 additions & 0 deletions cuda/lltm.py
@@ -0,0 +1,44 @@
import math
from torch import nn
from torch.autograd import Function
import torch

import lltm_cuda

torch.manual_seed(42)


class LLTMFunction(Function):
    @staticmethod
    def forward(ctx, input, weights, bias, old_h, old_cell):
        outputs = lltm_cuda.forward(input, weights, bias, old_h, old_cell)
        new_h, new_cell = outputs[:2]
        variables = outputs[1:] + [weights]
        ctx.save_for_backward(*variables)

        return new_h, new_cell

    @staticmethod
    def backward(ctx, grad_h, grad_cell):
        d_old_h, d_input, d_weights, d_bias, d_old_cell = lltm_cuda.backward(
            grad_h, grad_cell, *ctx.saved_variables)
        return d_input, d_weights, d_bias, d_old_h, d_old_cell


class LLTM(nn.Module):
    def __init__(self, input_features, state_size):
        super(LLTM, self).__init__()
        self.input_features = input_features
        self.state_size = state_size
        self.weights = nn.Parameter(
            torch.Tensor(3 * state_size, input_features + state_size))
        self.bias = nn.Parameter(torch.Tensor(3 * state_size))
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.state_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, +stdv)

    def forward(self, input, state):
        return LLTMFunction.apply(input, self.weights, self.bias, *state)
78 changes: 78 additions & 0 deletions cuda/lltm_cuda.cpp
@@ -0,0 +1,78 @@
#include <torch/torch.h>

#include <vector>

// CUDA forward declarations

std::vector<at::Tensor> lltm_cuda_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell);

std::vector<at::Tensor> lltm_cuda_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights);

// C++ interface

#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")

std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  CHECK_CUDA(input);
  CHECK_CUDA(weights);
  CHECK_CUDA(bias);
  CHECK_CUDA(old_h);
  CHECK_CUDA(old_cell);

  return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
}

std::vector<at::Tensor> lltm_backward(
    at::Tensor grad_h,
    at::Tensor grad_cell,
    at::Tensor new_cell,
    at::Tensor input_gate,
    at::Tensor output_gate,
    at::Tensor candidate_cell,
    at::Tensor X,
    at::Tensor gate_weights,
    at::Tensor weights) {
  CHECK_CUDA(grad_h);
  CHECK_CUDA(grad_cell);
  CHECK_CUDA(new_cell);
  CHECK_CUDA(input_gate);
  CHECK_CUDA(output_gate);
  CHECK_CUDA(candidate_cell);
  CHECK_CUDA(X);
  CHECK_CUDA(gate_weights);
  CHECK_CUDA(weights);

  return lltm_cuda_backward(
      grad_h,
      grad_cell,
      new_cell,
      input_gate,
      output_gate,
      candidate_cell,
      X,
      gate_weights,
      weights);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &lltm_forward, "LLTM forward (CUDA)");
  m.def("backward", &lltm_backward, "LLTM backward (CUDA)");
}
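
The lltm_cuda_forward and lltm_cuda_backward functions declared at the top of this file are implemented in lltm_cuda_kernel.cu, the second source file listed in cuda/jit.py. This .cpp file only checks that all tensors live on the GPU and then forwards the call, which keeps the pybind11 binding itself free of CUDA code.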