-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconv_gru_cell.py
31 lines (24 loc) · 1.1 KB
/
conv_gru_cell.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import tensorflow as tf
import tensorflow.contrib.layers as tl
from tensorflow.contrib.rnn import RNNCell
class ConvGRUCell(RNNCell):
    """Convolutional GRU cell for ``tf.contrib.rnn``-style RNN drivers.

    A GRU whose update gate, reset gate and candidate state are computed with
    2-D convolutions (``tf.contrib.layers.conv2d``) instead of dense matrix
    products, so the hidden state keeps the spatial layout
    ``dims + [hidden_channels]``.

    NOTE(review): ``inputs`` and ``state`` are concatenated on axis 3 and fed
    to ``conv2d``, so both are presumably 4-D channels-last tensors
    ``(batch, height, width, channels)`` with matching spatial size — confirm
    against callers.
    """

    def __init__(self, hidden_channels, dims, kernel_size=None):
        """Create the cell.

        Args:
            hidden_channels: number of channels in the hidden state / output.
            dims: spatial dimensions of the hidden state (presumably
                ``[height, width]``); used to build ``state_size``.
            kernel_size: kernel size shared by all three convolutions. When
                ``None`` (the default) it falls back to ``dims``, which
                reproduces the original behaviour of using a kernel as large
                as the whole state.
        """
        # Run the base-class constructor: in later TF 1.x releases RNNCell is
        # a Layer subclass whose constructor sets up internal bookkeeping.
        super().__init__()
        self._output_size = tf.TensorShape([*dims, hidden_channels])
        self.dims = dims
        self.hidden_channels = hidden_channels
        self.kernel_size = dims if kernel_size is None else kernel_size

    @property
    def output_size(self):
        """Shape of one step's output (excluding batch): dims + [hidden_channels]."""
        return self._output_size

    @property
    def state_size(self):
        """Shape of the hidden state (excluding batch); identical to output_size."""
        return self._output_size

    def __call__(self, inputs, state, scope=None):
        """Run one time step of the convolutional GRU.

        Args:
            inputs: input tensor for this time step (see class NOTE on shape).
            state: previous hidden state, shaped like ``state_size`` plus batch.
            scope: optional variable-scope name; defaults to the class name.

        Returns:
            ``(h, h)`` — the new hidden state, used both as the step output and
            as the next state.
        """
        with tf.variable_scope(scope or type(self).__name__):
            # Channels-last concatenation of previous state and current input.
            input_and_hidden = tf.concat([state, inputs], 3)
            # NOTE: the conv2d calls below are unscoped, so their variable
            # names presumably depend on call order — keep the order
            # (z, r, h_tilde) stable to stay checkpoint-compatible.
            # Update gate z and reset gate r (no bias terms).
            z = tl.conv2d(input_and_hidden, self.hidden_channels, self.kernel_size,
                          biases_initializer=None, activation_fn=tf.sigmoid)
            r = tl.conv2d(input_and_hidden, self.hidden_channels, self.kernel_size,
                          biases_initializer=None, activation_fn=tf.sigmoid)
            # Candidate state sees the previous state only through the reset gate.
            input_and_updated = tf.concat([r * state, inputs], 3)
            h_tilde = tl.conv2d(input_and_updated, self.hidden_channels, self.kernel_size,
                                biases_initializer=None, activation_fn=tf.tanh)
            # Interpolate between the previous state and the candidate via z.
            h = (1 - z) * state + z * h_tilde
            return h, h