diff --git a/benchmark.py b/benchmark.py index 25779eb..212da08 100644 --- a/benchmark.py +++ b/benchmark.py @@ -17,6 +17,7 @@ parser.add_argument('-r', '--runs', type=int, default=100) parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us') parser.add_argument('-c', '--cuda', action='store_true') +parser.add_argument('-d', '--double', action='store_true') options = parser.parse_args() if options.example == 'py': @@ -27,16 +28,16 @@ from cuda.lltm import LLTM options.cuda = True -X = torch.randn(options.batch_size, options.features) -h = torch.randn(options.batch_size, options.state_size) -C = torch.randn(options.batch_size, options.state_size) -rnn = LLTM(options.features, options.state_size) +device = torch.device("cuda") if options.cuda else torch.device("cpu") +dtype = torch.float64 if options.double else torch.float32 -if options.cuda: - X = X.cuda() - h = h.cuda() - C = C.cuda() - rnn.cuda() +kwargs = {'dtype': dtype, + 'device': device, + 'requires_grad': True} +X = torch.randn(options.batch_size, options.features, **kwargs) +h = torch.randn(options.batch_size, options.state_size, **kwargs) +C = torch.randn(options.batch_size, options.state_size, **kwargs) +rnn = LLTM(options.features, options.state_size).to(device, dtype) # Force CUDA initialization new_h, new_C = rnn(X, (h, C)) diff --git a/check.py b/check.py index 4c0a5b0..8fad6d1 100644 --- a/check.py +++ b/check.py @@ -5,8 +5,6 @@ import numpy as np import torch -from torch.autograd import Variable - import python.lltm_baseline import cpp.lltm @@ -85,21 +83,23 @@ def check_backward(variables, with_cuda, verbose): if options.cuda: import cuda.lltm - options.cuda = True - -X = torch.randn(options.batch_size, options.features) -h = torch.randn(options.batch_size, options.state_size) -C = torch.randn(options.batch_size, options.state_size) -W = torch.randn(3 * options.state_size, options.features + options.state_size) -b = torch.randn(1, 3 * options.state_size) + device = torch.device("cuda") +else: + device = torch.device("cpu") + +kwargs = {'dtype': torch.float64, + 'device': device, + 'requires_grad': True} +X = torch.randn(options.batch_size, + options.features, + **kwargs) +h = torch.randn(options.batch_size, options.state_size, **kwargs) +C = torch.randn(options.batch_size, options.state_size, **kwargs) +W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs) +b = torch.randn(1, 3 * options.state_size, **kwargs) variables = [X, W, b, h, C] -for i, var in enumerate(variables): - if options.cuda: - var = var.cuda() - variables[i] = Variable(var.double(), requires_grad=True) - if 'forward' in options.direction: check_forward(variables, options.cuda, options.verbose) diff --git a/cpp/lltm.cpp b/cpp/lltm.cpp index a94d1d4..9bdfe0c 100644 --- a/cpp/lltm.cpp +++ b/cpp/lltm.cpp @@ -1,42 +1,42 @@ -#include +#include #include // s'(z) = (1 - s(z)) * s(z) -at::Tensor d_sigmoid(at::Tensor z) { - auto s = at::sigmoid(z); +torch::Tensor d_sigmoid(torch::Tensor z) { + auto s = torch::sigmoid(z); return (1 - s) * s; } // tanh'(z) = 1 - tanh^2(z) -at::Tensor d_tanh(at::Tensor z) { +torch::Tensor d_tanh(torch::Tensor z) { return 1 - z.tanh().pow(2); } // elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0} -at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) { +torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) { auto e = z.exp(); auto mask = (alpha * (e - 1)) < 0; return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e); } -std::vector lltm_forward( - at::Tensor input, - at::Tensor weights, - at::Tensor bias, - at::Tensor old_h, - at::Tensor old_cell) { - auto X = at::cat({old_h, input}, /*dim=*/1); +std::vector lltm_forward( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + torch::Tensor old_h, + torch::Tensor old_cell) { + auto X = torch::cat({old_h, input}, /*dim=*/1); - auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1)); + auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1)); auto gates = gate_weights.chunk(3, /*dim=*/1); - auto input_gate = at::sigmoid(gates[0]); - auto output_gate = at::sigmoid(gates[1]); - auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0); + auto input_gate = torch::sigmoid(gates[0]); + auto output_gate = torch::sigmoid(gates[1]); + auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0); auto new_cell = old_cell + candidate_cell * input_gate; - auto new_h = at::tanh(new_cell) * output_gate; + auto new_h = torch::tanh(new_cell) * output_gate; return {new_h, new_cell, @@ -47,17 +47,17 @@ std::vector lltm_forward( gate_weights}; } -std::vector lltm_backward( - at::Tensor grad_h, - at::Tensor grad_cell, - at::Tensor new_cell, - at::Tensor input_gate, - at::Tensor output_gate, - at::Tensor candidate_cell, - at::Tensor X, - at::Tensor gate_weights, - at::Tensor weights) { - auto d_output_gate = at::tanh(new_cell) * grad_h; +std::vector lltm_backward( + torch::Tensor grad_h, + torch::Tensor grad_cell, + torch::Tensor new_cell, + torch::Tensor input_gate, + torch::Tensor output_gate, + torch::Tensor candidate_cell, + torch::Tensor X, + torch::Tensor gate_weights, + torch::Tensor weights) { + auto d_output_gate = torch::tanh(new_cell) * grad_h; auto d_tanh_new_cell = output_gate * grad_h; auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell; @@ -71,7 +71,7 @@ std::vector lltm_backward( d_candidate_cell *= d_elu(gates[2]); auto d_gates = - at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1); + torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1); auto d_weights = d_gates.t().mm(X); auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true); diff --git a/cuda/lltm_cuda.cpp b/cuda/lltm_cuda.cpp index 781a667..2434776 100644 --- a/cuda/lltm_cuda.cpp +++ b/cuda/lltm_cuda.cpp @@ -1,26 +1,26 @@ -#include +#include #include // CUDA forward declarations -std::vector lltm_cuda_forward( - at::Tensor input, - at::Tensor weights, - at::Tensor bias, - at::Tensor old_h, - at::Tensor old_cell); +std::vector lltm_cuda_forward( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + torch::Tensor old_h, + torch::Tensor old_cell); -std::vector lltm_cuda_backward( - at::Tensor grad_h, - at::Tensor grad_cell, - at::Tensor new_cell, - at::Tensor input_gate, - at::Tensor output_gate, - at::Tensor candidate_cell, - at::Tensor X, - at::Tensor gate_weights, - at::Tensor weights); +std::vector lltm_cuda_backward( + torch::Tensor grad_h, + torch::Tensor grad_cell, + torch::Tensor new_cell, + torch::Tensor input_gate, + torch::Tensor output_gate, + torch::Tensor candidate_cell, + torch::Tensor X, + torch::Tensor gate_weights, + torch::Tensor weights); // C++ interface @@ -29,12 +29,12 @@ std::vector lltm_cuda_backward( #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous") #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) -std::vector lltm_forward( - at::Tensor input, - at::Tensor weights, - at::Tensor bias, - at::Tensor old_h, - at::Tensor old_cell) { +std::vector lltm_forward( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + torch::Tensor old_h, + torch::Tensor old_cell) { CHECK_INPUT(input); CHECK_INPUT(weights); CHECK_INPUT(bias); @@ -44,16 +44,16 @@ std::vector lltm_forward( return lltm_cuda_forward(input, weights, bias, old_h, old_cell); } -std::vector lltm_backward( - at::Tensor grad_h, - at::Tensor grad_cell, - at::Tensor new_cell, - at::Tensor input_gate, - at::Tensor output_gate, - at::Tensor candidate_cell, - at::Tensor X, - at::Tensor gate_weights, - at::Tensor weights) { +std::vector lltm_backward( + torch::Tensor grad_h, + torch::Tensor grad_cell, + torch::Tensor new_cell, + torch::Tensor input_gate, + torch::Tensor output_gate, + torch::Tensor candidate_cell, + torch::Tensor X, + torch::Tensor gate_weights, + torch::Tensor weights) { CHECK_INPUT(grad_h); CHECK_INPUT(grad_cell); CHECK_INPUT(input_gate); diff --git a/cuda/lltm_cuda_kernel.cu b/cuda/lltm_cuda_kernel.cu index 9e1ad04..e8759fb 100644 --- a/cuda/lltm_cuda_kernel.cu +++ b/cuda/lltm_cuda_kernel.cu @@ -1,4 +1,4 @@ -#include +#include #include #include @@ -99,23 +99,23 @@ __global__ void lltm_cuda_backward_kernel( } } // namespace -std::vector lltm_cuda_forward( - at::Tensor input, - at::Tensor weights, - at::Tensor bias, - at::Tensor old_h, - at::Tensor old_cell) { - auto X = at::cat({old_h, input}, /*dim=*/1); - auto gates = at::addmm(bias, X, weights.transpose(0, 1)); +std::vector lltm_cuda_forward( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + torch::Tensor old_h, + torch::Tensor old_cell) { + auto X = torch::cat({old_h, input}, /*dim=*/1); + auto gates = torch::addmm(bias, X, weights.transpose(0, 1)); const auto batch_size = old_cell.size(0); const auto state_size = old_cell.size(1); - auto new_h = at::zeros_like(old_cell); - auto new_cell = at::zeros_like(old_cell); - auto input_gate = at::zeros_like(old_cell); - auto output_gate = at::zeros_like(old_cell); - auto candidate_cell = at::zeros_like(old_cell); + auto new_h = torch::zeros_like(old_cell); + auto new_cell = torch::zeros_like(old_cell); + auto input_gate = torch::zeros_like(old_cell); + auto output_gate = torch::zeros_like(old_cell); + auto candidate_cell = torch::zeros_like(old_cell); const int threads = 1024; const dim3 blocks((state_size + threads - 1) / threads, batch_size); @@ -135,18 +135,18 @@ std::vector lltm_cuda_forward( return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates}; } -std::vector lltm_cuda_backward( - at::Tensor grad_h, - at::Tensor grad_cell, - at::Tensor new_cell, - at::Tensor input_gate, - at::Tensor output_gate, - at::Tensor candidate_cell, - at::Tensor X, - at::Tensor gate_weights, - at::Tensor weights) { - auto d_old_cell = at::zeros_like(new_cell); - auto d_gates = at::zeros_like(gate_weights); +std::vector lltm_cuda_backward( + torch::Tensor grad_h, + torch::Tensor grad_cell, + torch::Tensor new_cell, + torch::Tensor input_gate, + torch::Tensor output_gate, + torch::Tensor candidate_cell, + torch::Tensor X, + torch::Tensor gate_weights, + torch::Tensor weights) { + auto d_old_cell = torch::zeros_like(new_cell); + auto d_gates = torch::zeros_like(gate_weights); const auto batch_size = new_cell.size(0); const auto state_size = new_cell.size(1); diff --git a/grad_check.py b/grad_check.py index 1e0523b..caf3b36 100644 --- a/grad_check.py +++ b/grad_check.py @@ -3,8 +3,7 @@ import argparse import torch - -from torch.autograd import Variable, gradcheck +from torch.autograd import gradcheck parser = argparse.ArgumentParser() parser.add_argument('example', choices=['py', 'cpp', 'cuda']) @@ -22,18 +21,20 @@ from cuda.lltm import LLTMFunction options.cuda = True -X = torch.randn(options.batch_size, options.features) -h = torch.randn(options.batch_size, options.state_size) -C = torch.randn(options.batch_size, options.state_size) -W = torch.randn(3 * options.state_size, options.features + options.state_size) -b = torch.randn(1, 3 * options.state_size) +device = torch.device("cuda") if options.cuda else torch.device("cpu") + +kwargs = {'dtype': torch.float64, + 'device': device, + 'requires_grad': True} + +X = torch.randn(options.batch_size, options.features, **kwargs) +h = torch.randn(options.batch_size, options.state_size, **kwargs) +C = torch.randn(options.batch_size, options.state_size, **kwargs) +W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs) +b = torch.randn(1, 3 * options.state_size, **kwargs) variables = [X, W, b, h, C] -for i, var in enumerate(variables): - if options.cuda: - var = var.cuda() - variables[i] = Variable(var.double(), requires_grad=True) if gradcheck(LLTMFunction.apply, variables): print('Ok')