Merge pull request pytorch#32 from ClementPinard/fix-deprecated
Fix deprecated functions
soumith authored Apr 16, 2019
2 parents eea6d31 + 4a86842 commit 07b1598
Showing 6 changed files with 124 additions and 122 deletions.
19 changes: 10 additions & 9 deletions benchmark.py
@@ -17,6 +17,7 @@
 parser.add_argument('-r', '--runs', type=int, default=100)
 parser.add_argument('--scale', choices=['s', 'ms', 'us'], default='us')
 parser.add_argument('-c', '--cuda', action='store_true')
+parser.add_argument('-d', '--double', action='store_true')
 options = parser.parse_args()
 
 if options.example == 'py':
@@ -27,16 +28,16 @@
     from cuda.lltm import LLTM
     options.cuda = True
 
-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-rnn = LLTM(options.features, options.state_size)
+device = torch.device("cuda") if options.cuda else torch.device("cpu")
+dtype = torch.float64 if options.double else torch.float32
 
-if options.cuda:
-    X = X.cuda()
-    h = h.cuda()
-    C = C.cuda()
-    rnn.cuda()
+kwargs = {'dtype': dtype,
+          'device': device,
+          'requires_grad': True}
+X = torch.randn(options.batch_size, options.features, **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+rnn = LLTM(options.features, options.state_size).to(device, dtype)
 
 # Force CUDA initialization
 new_h, new_C = rnn(X, (h, C))
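The benchmark.py change above replaces post-hoc .cuda() calls with device/dtype keyword arguments passed straight to the tensor factories. A minimal sketch of that idiom, independent of the LLTM benchmark itself (the sizes and the torch.nn.Linear module below are placeholders):

import torch

# Pick device and dtype once, then create tensors directly on the target
# device instead of moving them afterwards with .cuda().
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
dtype = torch.float32

kwargs = {'dtype': dtype, 'device': device, 'requires_grad': True}
X = torch.randn(16, 32, **kwargs)                    # placeholder batch of inputs
model = torch.nn.Linear(32, 128).to(device, dtype)   # module moved and cast in one call
out = model(X)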
28 changes: 14 additions & 14 deletions check.py
@@ -5,8 +5,6 @@
 import numpy as np
 import torch
 
-from torch.autograd import Variable
-
 import python.lltm_baseline
 import cpp.lltm
 
@@ -85,21 +83,23 @@ def check_backward(variables, with_cuda, verbose):
 
 if options.cuda:
     import cuda.lltm
-    options.cuda = True
-
-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-W = torch.randn(3 * options.state_size, options.features + options.state_size)
-b = torch.randn(1, 3 * options.state_size)
+    device = torch.device("cuda")
+else:
+    device = torch.device("cpu")
+
+kwargs = {'dtype': torch.float64,
+          'device': device,
+          'requires_grad': True}
+X = torch.randn(options.batch_size,
+                options.features,
+                **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs)
+b = torch.randn(1, 3 * options.state_size, **kwargs)
 
 variables = [X, W, b, h, C]
 
-for i, var in enumerate(variables):
-    if options.cuda:
-        var = var.cuda()
-    variables[i] = Variable(var.double(), requires_grad=True)
-
 if 'forward' in options.direction:
     check_forward(variables, options.cuda, options.verbose)
 
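check.py now builds its inputs as float64 leaf tensors with requires_grad=True instead of wrapping them in the deprecated Variable. A hedged aside on why double precision helps here: the script compares results numerically against the Python baseline, and a central-difference estimate loses roughly half its precision to the step size, so float32 would force much looser tolerances. A stand-in example (the function f is not the LLTM):

import torch

def f(x):
    # stand-in differentiable function, not part of the repository
    return (x.sigmoid() * x.tanh()).sum()

x = torch.randn(8, dtype=torch.float64, requires_grad=True)
f(x).backward()                                  # analytic gradient via autograd

eps = 1e-6
i = 0
x_plus = x.detach().clone();  x_plus[i] += eps
x_minus = x.detach().clone(); x_minus[i] -= eps
numeric = (f(x_plus) - f(x_minus)) / (2 * eps)   # central difference
print(torch.allclose(x.grad[i], numeric, atol=1e-7))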
58 changes: 29 additions & 29 deletions cpp/lltm.cpp
@@ -1,42 +1,42 @@
-#include <torch/torch.h>
+#include <torch/extension.h>
 
 #include <vector>
 
 // s'(z) = (1 - s(z)) * s(z)
-at::Tensor d_sigmoid(at::Tensor z) {
-  auto s = at::sigmoid(z);
+torch::Tensor d_sigmoid(torch::Tensor z) {
+  auto s = torch::sigmoid(z);
   return (1 - s) * s;
 }
 
 // tanh'(z) = 1 - tanh^2(z)
-at::Tensor d_tanh(at::Tensor z) {
+torch::Tensor d_tanh(torch::Tensor z) {
   return 1 - z.tanh().pow(2);
 }
 
 // elu'(z) = relu'(z) + { alpha * exp(z) if (alpha * (exp(z) - 1)) < 0, else 0}
-at::Tensor d_elu(at::Tensor z, at::Scalar alpha = 1.0) {
+torch::Tensor d_elu(torch::Tensor z, torch::Scalar alpha = 1.0) {
   auto e = z.exp();
   auto mask = (alpha * (e - 1)) < 0;
   return (z > 0).type_as(z) + mask.type_as(z) * (alpha * e);
 }
 
-std::vector<at::Tensor> lltm_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
-  auto X = at::cat({old_h, input}, /*dim=*/1);
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
 
-  auto gate_weights = at::addmm(bias, X, weights.transpose(0, 1));
+  auto gate_weights = torch::addmm(bias, X, weights.transpose(0, 1));
   auto gates = gate_weights.chunk(3, /*dim=*/1);
 
-  auto input_gate = at::sigmoid(gates[0]);
-  auto output_gate = at::sigmoid(gates[1]);
-  auto candidate_cell = at::elu(gates[2], /*alpha=*/1.0);
+  auto input_gate = torch::sigmoid(gates[0]);
+  auto output_gate = torch::sigmoid(gates[1]);
+  auto candidate_cell = torch::elu(gates[2], /*alpha=*/1.0);
 
   auto new_cell = old_cell + candidate_cell * input_gate;
-  auto new_h = at::tanh(new_cell) * output_gate;
+  auto new_h = torch::tanh(new_cell) * output_gate;
 
   return {new_h,
           new_cell,
@@ -47,17 +47,17 @@ std::vector<at::Tensor> lltm_forward(
           gate_weights};
 }
 
-std::vector<at::Tensor> lltm_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
-  auto d_output_gate = at::tanh(new_cell) * grad_h;
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_output_gate = torch::tanh(new_cell) * grad_h;
   auto d_tanh_new_cell = output_gate * grad_h;
   auto d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell;
 
@@ -71,7 +71,7 @@ std::vector<at::Tensor> lltm_backward(
   d_candidate_cell *= d_elu(gates[2]);
 
   auto d_gates =
-      at::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
+      torch::cat({d_input_gate, d_output_gate, d_candidate_cell}, /*dim=*/1);
 
   auto d_weights = d_gates.t().mm(X);
   auto d_bias = d_gates.sum(/*dim=*/0, /*keepdim=*/true);
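The C++ file now includes <torch/extension.h> and calls everything through the torch:: namespace instead of at::. One way to try the updated file is the JIT loader in torch.utils.cpp_extension; the sketch below assumes the tutorial's usual pybind bindings (forward/backward) are present further down in lltm.cpp and that it is run from the repository root:

import torch
from torch.utils.cpp_extension import load

# JIT-compile cpp/lltm.cpp into an importable module (assumed paths/names).
lltm_cpp = load(name="lltm_cpp", sources=["cpp/lltm.cpp"], verbose=True)

batch, features, state = 4, 17, 5
X = torch.randn(batch, features)
h = torch.randn(batch, state)
C = torch.randn(batch, state)
W = torch.randn(3 * state, features + state)
b = torch.randn(1, 3 * state)

# Argument order follows lltm_forward(input, weights, bias, old_h, old_cell).
new_h, new_cell = lltm_cpp.forward(X, W, b, h, C)[:2]
print(new_h.shape, new_cell.shape)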
66 changes: 33 additions & 33 deletions cuda/lltm_cuda.cpp
@@ -1,26 +1,26 @@
-#include <torch/torch.h>
+#include <torch/extension.h>
 
 #include <vector>
 
 // CUDA forward declarations
 
-std::vector<at::Tensor> lltm_cuda_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell);
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell);
 
-std::vector<at::Tensor> lltm_cuda_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights);
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights);
 
 // C++ interface
 
@@ -29,12 +29,12 @@ std::vector<at::Tensor> lltm_cuda_backward(
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
 
-std::vector<at::Tensor> lltm_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
+std::vector<torch::Tensor> lltm_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
   CHECK_INPUT(input);
   CHECK_INPUT(weights);
   CHECK_INPUT(bias);
@@ -44,16 +44,16 @@ std::vector<at::Tensor> lltm_forward(
   return lltm_cuda_forward(input, weights, bias, old_h, old_cell);
 }
 
-std::vector<at::Tensor> lltm_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
+std::vector<torch::Tensor> lltm_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
   CHECK_INPUT(grad_h);
   CHECK_INPUT(grad_cell);
   CHECK_INPUT(input_gate);
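CHECK_INPUT above rejects tensors that are not CUDA tensors or not contiguous. A small illustration of what that means on the Python side (names and shapes are placeholders):

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    W = torch.randn(22, 15, device=device).t()   # a transposed view: CUDA, but not contiguous
    print(W.is_cuda, W.is_contiguous())          # True False -> would trip CHECK_CONTIGUOUS
    W = W.contiguous()                           # copy into contiguous memory
    print(W.is_cuda, W.is_contiguous())          # True True  -> passes CHECK_INPUT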
52 changes: 26 additions & 26 deletions cuda/lltm_cuda_kernel.cu
@@ -1,4 +1,4 @@
-#include <ATen/ATen.h>
+#include <torch/extension.h>
 
 #include <cuda.h>
 #include <cuda_runtime.h>
@@ -99,23 +99,23 @@ __global__ void lltm_cuda_backward_kernel(
 }
 } // namespace
 
-std::vector<at::Tensor> lltm_cuda_forward(
-    at::Tensor input,
-    at::Tensor weights,
-    at::Tensor bias,
-    at::Tensor old_h,
-    at::Tensor old_cell) {
-  auto X = at::cat({old_h, input}, /*dim=*/1);
-  auto gates = at::addmm(bias, X, weights.transpose(0, 1));
+std::vector<torch::Tensor> lltm_cuda_forward(
+    torch::Tensor input,
+    torch::Tensor weights,
+    torch::Tensor bias,
+    torch::Tensor old_h,
+    torch::Tensor old_cell) {
+  auto X = torch::cat({old_h, input}, /*dim=*/1);
+  auto gates = torch::addmm(bias, X, weights.transpose(0, 1));
 
   const auto batch_size = old_cell.size(0);
   const auto state_size = old_cell.size(1);
 
-  auto new_h = at::zeros_like(old_cell);
-  auto new_cell = at::zeros_like(old_cell);
-  auto input_gate = at::zeros_like(old_cell);
-  auto output_gate = at::zeros_like(old_cell);
-  auto candidate_cell = at::zeros_like(old_cell);
+  auto new_h = torch::zeros_like(old_cell);
+  auto new_cell = torch::zeros_like(old_cell);
+  auto input_gate = torch::zeros_like(old_cell);
+  auto output_gate = torch::zeros_like(old_cell);
+  auto candidate_cell = torch::zeros_like(old_cell);
 
   const int threads = 1024;
   const dim3 blocks((state_size + threads - 1) / threads, batch_size);
@@ -135,18 +135,18 @@ std::vector<at::Tensor> lltm_cuda_forward(
   return {new_h, new_cell, input_gate, output_gate, candidate_cell, X, gates};
 }
 
-std::vector<at::Tensor> lltm_cuda_backward(
-    at::Tensor grad_h,
-    at::Tensor grad_cell,
-    at::Tensor new_cell,
-    at::Tensor input_gate,
-    at::Tensor output_gate,
-    at::Tensor candidate_cell,
-    at::Tensor X,
-    at::Tensor gate_weights,
-    at::Tensor weights) {
-  auto d_old_cell = at::zeros_like(new_cell);
-  auto d_gates = at::zeros_like(gate_weights);
+std::vector<torch::Tensor> lltm_cuda_backward(
+    torch::Tensor grad_h,
+    torch::Tensor grad_cell,
+    torch::Tensor new_cell,
+    torch::Tensor input_gate,
+    torch::Tensor output_gate,
+    torch::Tensor candidate_cell,
+    torch::Tensor X,
+    torch::Tensor gate_weights,
+    torch::Tensor weights) {
+  auto d_old_cell = torch::zeros_like(new_cell);
+  auto d_gates = torch::zeros_like(gate_weights);
 
   const auto batch_size = new_cell.size(0);
   const auto state_size = new_cell.size(1);
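The launch configuration above uses a 2-D grid: the x-dimension covers the state with 1024-thread blocks via ceiling division, and the y-dimension indexes the batch. The same arithmetic in Python, with arbitrary example numbers:

def launch_dims(batch_size, state_size, threads=1024):
    # mirrors: const dim3 blocks((state_size + threads - 1) / threads, batch_size);
    blocks_x = (state_size + threads - 1) // threads   # ceil(state_size / threads)
    return (blocks_x, batch_size), threads

print(launch_dims(16, 128))    # ((1, 16), 1024)
print(launch_dims(16, 2500))   # ((3, 16), 1024)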
23 changes: 12 additions & 11 deletions grad_check.py
@@ -3,8 +3,7 @@
 
 import argparse
 import torch
-
-from torch.autograd import Variable, gradcheck
+from torch.autograd import gradcheck
 
 parser = argparse.ArgumentParser()
 parser.add_argument('example', choices=['py', 'cpp', 'cuda'])
@@ -22,18 +21,20 @@
     from cuda.lltm import LLTMFunction
     options.cuda = True
 
-X = torch.randn(options.batch_size, options.features)
-h = torch.randn(options.batch_size, options.state_size)
-C = torch.randn(options.batch_size, options.state_size)
-W = torch.randn(3 * options.state_size, options.features + options.state_size)
-b = torch.randn(1, 3 * options.state_size)
+device = torch.device("cuda") if options.cuda else torch.device("cpu")
+
+kwargs = {'dtype': torch.float64,
+          'device': device,
+          'requires_grad': True}
+
+X = torch.randn(options.batch_size, options.features, **kwargs)
+h = torch.randn(options.batch_size, options.state_size, **kwargs)
+C = torch.randn(options.batch_size, options.state_size, **kwargs)
+W = torch.randn(3 * options.state_size, options.features + options.state_size, **kwargs)
+b = torch.randn(1, 3 * options.state_size, **kwargs)
 
 variables = [X, W, b, h, C]
 
-for i, var in enumerate(variables):
-    if options.cuda:
-        var = var.cuda()
-    variables[i] = Variable(var.double(), requires_grad=True)
 
 if gradcheck(LLTMFunction.apply, variables):
     print('Ok')
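gradcheck compares a Function's backward() against finite differences, which is why grad_check.py now feeds it float64 leaf tensors with requires_grad=True rather than Variable wrappers. A self-contained sketch with a stand-in Function (not the LLTM one):

import torch
from torch.autograd import gradcheck

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return 2 * x * grad_out

# Double-precision inputs keep the numerical comparison tight.
x = torch.randn(5, dtype=torch.float64, requires_grad=True)
if gradcheck(Square.apply, (x,)):
    print('Ok')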
