
Commit

Better handling of contiguous tensors
goldsborough committed Mar 17, 2018
1 parent 88e2a9a commit 63f005f
Showing 2 changed files with 24 additions and 22 deletions.
28 changes: 15 additions & 13 deletions cuda/lltm_cuda.cpp
@@ -25,18 +25,20 @@ std::vector<at::Tensor> lltm_cuda_backward(
 // C++ interface
 
 #define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
 std::vector<at::Tensor> lltm_forward(
     at::Tensor input,
     at::Tensor weights,
     at::Tensor bias,
     at::Tensor old_h,
     at::Tensor old_cell) {
-  CHECK_CUDA(input);
-  CHECK_CUDA(weights);
-  CHECK_CUDA(bias);
-  CHECK_CUDA(old_h);
-  CHECK_CUDA(old_cell);
+  CHECK_INPUT(input);
+  CHECK_INPUT(weights);
+  CHECK_INPUT(bias);
+  CHECK_INPUT(old_h);
+  CHECK_INPUT(old_cell);
 
   return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
 }
@@ -51,14 +53,14 @@ std::vector<at::Tensor> lltm_backward(
     at::Tensor X,
     at::Tensor gate_weights,
     at::Tensor weights) {
-  CHECK_CUDA(grad_h);
-  CHECK_CUDA(grad_cell);
-  CHECK_CUDA(input_gate);
-  CHECK_CUDA(output_gate);
-  CHECK_CUDA(candidate_cell);
-  CHECK_CUDA(X);
-  CHECK_CUDA(gate_weights);
-  CHECK_CUDA(weights);
+  CHECK_INPUT(grad_h);
+  CHECK_INPUT(grad_cell);
+  CHECK_INPUT(input_gate);
+  CHECK_INPUT(output_gate);
+  CHECK_INPUT(candidate_cell);
+  CHECK_INPUT(X);
+  CHECK_INPUT(gate_weights);
+  CHECK_INPUT(weights);
 
   return lltm_cuda_backward(
       grad_h,
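For readers following along, here is a minimal standalone sketch of what the new macros expand to. It is not part of the commit; it simply spells out the two assertions that CHECK_INPUT chains together, assuming the ATen API used in the diff above, where AT_ASSERT takes a condition and a message.

#include <ATen/ATen.h>

#define CHECK_CUDA(x) AT_ASSERT(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERT(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

void check_example(at::Tensor input) {
  // CHECK_INPUT(input) expands to both assertions in sequence:
  //   AT_ASSERT(input.type().is_cuda(), "input must be a CUDA tensor");
  //   AT_ASSERT(input.is_contiguous(), "input must be contiguous");
  // so a CPU tensor or a non-contiguous view (e.g. a transpose) now fails
  // at the extension boundary instead of reaching the kernel launch.
  CHECK_INPUT(input);
}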
18 changes: 9 additions & 9 deletions cuda/lltm_cuda_kernel.cu
@@ -122,8 +122,8 @@ std::vector<at::Tensor> lltm_cuda_forward(
 
   AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
     lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
-        gates.contiguous().data<scalar_t>(),
-        old_cell.contiguous().data<scalar_t>(),
+        gates.data<scalar_t>(),
+        old_cell.data<scalar_t>(),
         new_h.data<scalar_t>(),
         new_cell.data<scalar_t>(),
         input_gate.data<scalar_t>(),
@@ -158,13 +158,13 @@ std::vector<at::Tensor> lltm_cuda_backward(
     lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
         d_old_cell.data<scalar_t>(),
         d_gates.data<scalar_t>(),
-        grad_h.contiguous().data<scalar_t>(),
-        grad_cell.contiguous().data<scalar_t>(),
-        new_cell.contiguous().data<scalar_t>(),
-        input_gate.contiguous().data<scalar_t>(),
-        output_gate.contiguous().data<scalar_t>(),
-        candidate_cell.contiguous().data<scalar_t>(),
-        gate_weights.contiguous().data<scalar_t>(),
+        grad_h.data<scalar_t>(),
+        grad_cell.data<scalar_t>(),
+        new_cell.data<scalar_t>(),
+        input_gate.data<scalar_t>(),
+        output_gate.data<scalar_t>(),
+        candidate_cell.data<scalar_t>(),
+        gate_weights.data<scalar_t>(),
         state_size);
   }));
 
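With the per-pointer .contiguous() calls removed from the kernel file, the kernels read their data<scalar_t>() pointers assuming a dense layout, and the responsibility for passing contiguous CUDA tensors now rests with the caller of the C++ interface. Below is a hedged caller-side sketch; it is illustrative only, not part of the commit, and the wrapper name is made up.

#include <ATen/ATen.h>
#include <vector>

// Declared in cuda/lltm_cuda.cpp above.
std::vector<at::Tensor> lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell);

std::vector<at::Tensor> call_lltm_forward(
    at::Tensor input,
    at::Tensor weights,
    at::Tensor bias,
    at::Tensor old_h,
    at::Tensor old_cell) {
  // contiguous() returns the tensor unchanged when it is already dense and
  // otherwise materializes a packed copy, so strided views (slices,
  // transposes) are normalized once here rather than at every
  // data<scalar_t>() call site inside the kernels. The tensors must still
  // live on a CUDA device to satisfy CHECK_CUDA.
  return lltm_forward(input.contiguous(),
                      weights.contiguous(),
                      bias.contiguous(),
                      old_h.contiguous(),
                      old_cell.contiguous());
}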
