From 63f005fea0cac46488551979ffaa9f4b8ce7d117 Mon Sep 17 00:00:00 2001
From: Peter Goldsborough <peter@goldsborough.me>
Date: Fri, 16 Mar 2018 17:42:33 -0700
Subject: [PATCH] Better handling of contiguous tensors

---
 cuda/lltm_cuda.cpp       | 28 +++++++++++++++-------------
 cuda/lltm_cuda_kernel.cu | 18 +++++++++---------
 2 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/cuda/lltm_cuda.cpp b/cuda/lltm_cuda.cpp
index 6539162..e0c9c23 100644
--- a/cuda/lltm_cuda.cpp
+++ b/cuda/lltm_cuda.cpp
@@ -25,6 +25,8 @@ std::vector<at::Tensor> lltm_cuda_backward(
 // C++ interface
 
 #define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
 std::vector<at::Tensor> lltm_forward(
     at::Tensor input,
@@ -32,11 +34,11 @@ std::vector<at::Tensor> lltm_forward(
     at::Tensor bias,
     at::Tensor old_h,
     at::Tensor old_cell) {
-  CHECK_CUDA(input);
-  CHECK_CUDA(weights);
-  CHECK_CUDA(bias);
-  CHECK_CUDA(old_h);
-  CHECK_CUDA(old_cell);
+  CHECK_INPUT(input);
+  CHECK_INPUT(weights);
+  CHECK_INPUT(bias);
+  CHECK_INPUT(old_h);
+  CHECK_INPUT(old_cell);
 
   return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
 }
@@ -51,14 +53,14 @@ std::vector<at::Tensor> lltm_backward(
     at::Tensor X,
     at::Tensor gate_weights,
     at::Tensor weights) {
-  CHECK_CUDA(grad_h);
-  CHECK_CUDA(grad_cell);
-  CHECK_CUDA(input_gate);
-  CHECK_CUDA(output_gate);
-  CHECK_CUDA(candidate_cell);
-  CHECK_CUDA(X);
-  CHECK_CUDA(gate_weights);
-  CHECK_CUDA(weights);
+  CHECK_INPUT(grad_h);
+  CHECK_INPUT(grad_cell);
+  CHECK_INPUT(input_gate);
+  CHECK_INPUT(output_gate);
+  CHECK_INPUT(candidate_cell);
+  CHECK_INPUT(X);
+  CHECK_INPUT(gate_weights);
+  CHECK_INPUT(weights);
 
   return lltm_cuda_backward(
       grad_h,
diff --git a/cuda/lltm_cuda_kernel.cu b/cuda/lltm_cuda_kernel.cu
index 304da47..9e1ad04 100644
--- a/cuda/lltm_cuda_kernel.cu
+++ b/cuda/lltm_cuda_kernel.cu
@@ -122,8 +122,8 @@ std::vector<at::Tensor> lltm_cuda_forward(
 
   AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.contiguous().data<scalar_t>(),
-        old_cell.contiguous().data<scalar_t>(),
+        gates.data<scalar_t>(),
+        old_cell.data<scalar_t>(),
         new_h.data<scalar_t>(),
         new_cell.data<scalar_t>(),
         input_gate.data<scalar_t>(),
@@ -158,13 +158,13 @@ std::vector<at::Tensor> lltm_cuda_backward(
     lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
         d_old_cell.data<scalar_t>(),
         d_gates.data<scalar_t>(),
-        grad_h.contiguous().data<scalar_t>(),
-        grad_cell.contiguous().data<scalar_t>(),
-        new_cell.contiguous().data<scalar_t>(),
-        input_gate.contiguous().data<scalar_t>(),
-        output_gate.contiguous().data<scalar_t>(),
-        candidate_cell.contiguous().data<scalar_t>(),
-        gate_weights.contiguous().data<scalar_t>(),
+        grad_h.data<scalar_t>(),
+        grad_cell.data<scalar_t>(),
+        new_cell.data<scalar_t>(),
+        input_gate.data<scalar_t>(),
+        output_gate.data<scalar_t>(),
+        candidate_cell.data<scalar_t>(),
+        gate_weights.data<scalar_t>(),
         state_size);
   }));
 