Skip to content

Commit

Permalink
Merge pull request pytorch#33 from ClementPinard/packed_accessor
Browse files Browse the repository at this point in the history
Use packed_accessor
  • Loading branch information
soumith authored Apr 17, 2019
2 parents 07b1598 + 11ce647 commit 85bc391
Showing 1 changed file with 68 additions and 73 deletions.
141 changes: 68 additions & 73 deletions cuda/lltm_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,64 +37,59 @@ __device__ __forceinline__ scalar_t d_elu(scalar_t z, scalar_t alpha = 1.0) {

template <typename scalar_t>
__global__ void lltm_cuda_forward_kernel(
const scalar_t* __restrict__ gates,
const scalar_t* __restrict__ old_cell,
scalar_t* __restrict__ new_h,
scalar_t* __restrict__ new_cell,
scalar_t* __restrict__ input_gate,
scalar_t* __restrict__ output_gate,
scalar_t* __restrict__ candidate_cell,
size_t state_size) {
const int column = blockIdx.x * blockDim.x + threadIdx.x;
const int index = blockIdx.y * state_size + column;
const int gates_row = blockIdx.y * (state_size * 3);
if (column < state_size) {
input_gate[index] = sigmoid(gates[gates_row + column]);
output_gate[index] = sigmoid(gates[gates_row + state_size + column]);
candidate_cell[index] = elu(gates[gates_row + 2 * state_size + column]);
new_cell[index] =
old_cell[index] + candidate_cell[index] * input_gate[index];
new_h[index] = tanh(new_cell[index]) * output_gate[index];
const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gates,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> old_cell,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_h,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell) {
//batch index
const int n = blockIdx.y;
// column index
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < gates.size(2)){
input_gate[n][c] = sigmoid(gates[n][0][c]);
output_gate[n][c] = sigmoid(gates[n][1][c]);
candidate_cell[n][c] = elu(gates[n][2][c]);
new_cell[n][c] =
old_cell[n][c] + candidate_cell[n][c] * input_gate[n][c];
new_h[n][c] = tanh(new_cell[n][c]) * output_gate[n][c];
}
}

template <typename scalar_t>
__global__ void lltm_cuda_backward_kernel(
scalar_t* __restrict__ d_old_cell,
scalar_t* __restrict__ d_gates,
const scalar_t* __restrict__ grad_h,
const scalar_t* __restrict__ grad_cell,
const scalar_t* __restrict__ new_cell,
const scalar_t* __restrict__ input_gate,
const scalar_t* __restrict__ output_gate,
const scalar_t* __restrict__ candidate_cell,
const scalar_t* __restrict__ gate_weights,
size_t state_size) {
const int column = blockIdx.x * blockDim.x + threadIdx.x;
const int index = blockIdx.y * state_size + column;
const int gates_row = blockIdx.y * (state_size * 3);
if (column < state_size) {
const auto d_output_gate = tanh(new_cell[index]) * grad_h[index];
const auto d_tanh_new_cell = output_gate[index] * grad_h[index];
torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> d_old_cell,
torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> d_gates,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_h,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> grad_cell,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> new_cell,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> input_gate,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> output_gate,
const torch::PackedTensorAccessor<scalar_t,2,torch::RestrictPtrTraits,size_t> candidate_cell,
const torch::PackedTensorAccessor<scalar_t,3,torch::RestrictPtrTraits,size_t> gate_weights) {
//batch index
const int n = blockIdx.y;
// column index
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < d_gates.size(2)){
const auto d_output_gate = tanh(new_cell[n][c]) * grad_h[n][c];
const auto d_tanh_new_cell = output_gate[n][c] * grad_h[n][c];
const auto d_new_cell =
d_tanh(new_cell[index]) * d_tanh_new_cell + grad_cell[index];
d_tanh(new_cell[n][c]) * d_tanh_new_cell + grad_cell[n][c];


d_old_cell[index] = d_new_cell;
const auto d_candidate_cell = input_gate[index] * d_new_cell;
const auto d_input_gate = candidate_cell[index] * d_new_cell;
d_old_cell[n][c] = d_new_cell;
const auto d_candidate_cell = input_gate[n][c] * d_new_cell;
const auto d_input_gate = candidate_cell[n][c] * d_new_cell;


const auto input_gate_index = gates_row + column;
const auto output_gate_index = gates_row + state_size + column;
const auto candidate_cell_index = gates_row + 2 * state_size + column;

d_gates[input_gate_index] =
d_input_gate * d_sigmoid(gate_weights[input_gate_index]);
d_gates[output_gate_index] =
d_output_gate * d_sigmoid(gate_weights[output_gate_index]);
d_gates[candidate_cell_index] =
d_candidate_cell * d_elu(gate_weights[candidate_cell_index]);
d_gates[n][0][c] =
d_input_gate * d_sigmoid(gate_weights[n][0][c]);
d_gates[n][1][c] =
d_output_gate * d_sigmoid(gate_weights[n][1][c]);
d_gates[n][2][c] =
d_candidate_cell * d_elu(gate_weights[n][2][c]);
}
}
} // namespace
Expand All @@ -106,11 +101,12 @@ std::vector<torch::Tensor> lltm_cuda_forward(
torch::Tensor old_h,
torch::Tensor old_cell) {
auto X = torch::cat({old_h, input}, /*dim=*/1);
auto gates = torch::addmm(bias, X, weights.transpose(0, 1));
auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));

const auto batch_size = old_cell.size(0);
const auto state_size = old_cell.size(1);

auto gates = gate_weights.reshape({batch_size, 3, state_size});
auto new_h = torch::zeros_like(old_cell);
auto new_cell = torch::zeros_like(old_cell);
auto input_gate = torch::zeros_like(old_cell);
Expand All @@ -122,14 +118,13 @@ std::vector<torch::Tensor> lltm_cuda_forward(

AT_DISPATCH_FLOATING_TYPES(gates.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_forward_kernel<scalar_t><<<blocks, threads>>>(
gates.data<scalar_t>(),
old_cell.data<scalar_t>(),
new_h.data<scalar_t>(),
new_cell.data<scalar_t>(),
input_gate.data<scalar_t>(),
output_gate.data<scalar_t>(),
candidate_cell.data<scalar_t>(),
state_size);
gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
new_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>());
}));

return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
Expand All @@ -143,10 +138,10 @@ std::vector<torch::Tensor> lltm_cuda_backward(
torch::Tensor output_gate,
torch::Tensor candidate_cell,
torch::Tensor X,
torch::Tensor gate_weights,
torch::Tensor gates,
torch::Tensor weights) {
auto d_old_cell = torch::zeros_like(new_cell);
auto d_gates = torch::zeros_like(gate_weights);
auto d_gates = torch::zeros_like(gates);

const auto batch_size = new_cell.size(0);
const auto state_size = new_cell.size(1);
Expand All @@ -156,22 +151,22 @@ std::vector<torch::Tensor> lltm_cuda_backward(

AT_DISPATCH_FLOATING_TYPES(X.type(), "lltm_forward_cuda", ([&] {
lltm_cuda_backward_kernel<scalar_t><<<blocks, threads>>>(
d_old_cell.data<scalar_t>(),
d_gates.data<scalar_t>(),
grad_h.data<scalar_t>(),
grad_cell.data<scalar_t>(),
new_cell.data<scalar_t>(),
input_gate.data<scalar_t>(),
output_gate.data<scalar_t>(),
candidate_cell.data<scalar_t>(),
gate_weights.data<scalar_t>(),
state_size);
d_old_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
d_gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>(),
grad_h.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
grad_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
new_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
input_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
output_gate.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
candidate_cell.packed_accessor<scalar_t,2,torch::RestrictPtrTraits,size_t>(),
gates.packed_accessor<scalar_t,3,torch::RestrictPtrTraits,size_t>());
}));

auto d_weights = d_gates.t().mm(X);
auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
auto d_gate_weights = d_gates.flatten(1, 2);
auto d_weights = d_gate_weights.t().mm(X);
auto d_bias = d_gate_weights.sum(/*dim=*/0, /*keepdim=*/true);

auto d_X = d_gates.mm(weights);
auto d_X = d_gate_weights.mm(weights);
auto d_old_h = d_X.slice(/*dim=*/1, 0, state_size);
auto d_input = d_X.slice(/*dim=*/1, state_size);

Expand Down

0 comments on commit 85bc391

Please sign in to comment.