Skip to content

Commit

Permalink
Merge pull request pytorch#17 from dhpollack/fix_deprecated_warnings
Browse files Browse the repository at this point in the history
change sigmoid and tanh from nn.functional to torch
  • Loading branch information
goldsborough authored Aug 9, 2018
2 parents ae4a541 + 5be8789 commit eea6d31
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
6 changes: 3 additions & 3 deletions python/lltm.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@ def forward(self, input, state):
# Split the combined gate weight matrix into its components.
gates = gate_weights.chunk(3, dim=1)

input_gate = F.sigmoid(gates[0])
output_gate = F.sigmoid(gates[1])
input_gate = torch.sigmoid(gates[0])
output_gate = torch.sigmoid(gates[1])
# Here we use an ELU instead of the usual tanh.
candidate_cell = F.elu(gates[2])

# Compute the new cell state.
new_cell = old_cell + candidate_cell * input_gate
# Compute the new hidden state and output.
new_h = F.tanh(new_cell) * output_gate
new_h = torch.tanh(new_cell) * output_gate

return new_h, new_cell
12 changes: 6 additions & 6 deletions python/lltm_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@


def d_sigmoid(z):
s = F.sigmoid(z)
s = torch.sigmoid(z)
return (1 - s) * s


def d_tanh(z):
t = F.tanh(z)
t = torch.tanh(z)
return 1 - (t * t)


Expand All @@ -32,12 +32,12 @@ def forward(ctx, input, weights, bias, old_h, old_cell):
gate_weights = F.linear(X, weights, bias)
gates = gate_weights.chunk(3, dim=1)

input_gate = F.sigmoid(gates[0])
output_gate = F.sigmoid(gates[1])
input_gate = torch.sigmoid(gates[0])
output_gate = torch.sigmoid(gates[1])
candidate_cell = F.elu(gates[2])

new_cell = old_cell + candidate_cell * input_gate
new_h = F.tanh(new_cell) * output_gate
new_h = torch.tanh(new_cell) * output_gate

ctx.save_for_backward(X, weights, input_gate, output_gate, old_cell,
new_cell, candidate_cell, gate_weights)
Expand All @@ -51,7 +51,7 @@ def backward(ctx, grad_h, grad_cell):

d_input = d_weights = d_bias = d_old_h = d_old_cell = None

d_output_gate = F.tanh(new_cell) * grad_h
d_output_gate = torch.tanh(new_cell) * grad_h
d_tanh_new_cell = output_gate * grad_h
d_new_cell = d_tanh(new_cell) * d_tanh_new_cell + grad_cell

Expand Down

0 comments on commit eea6d31

Please sign in to comment.