diff --git a/dazero/functions.py b/dazero/functions.py
index 1c4161e..de6adc2 100644
--- a/dazero/functions.py
+++ b/dazero/functions.py
@@ -489,7 +489,7 @@ def dropout(x, dropout_ratio=0.5):
         return x
 
 
-# ================== batch_norm / embed_id / layer_norm =====================
+# ================== batch_norm / embedding / layer_norm =====================
 
 class BatchNorm2d(Function):
     def __init__(self, mean, var, decay, eps):
diff --git a/dazero/utils.py b/dazero/utils.py
index 1178c95..b0d3362 100644
--- a/dazero/utils.py
+++ b/dazero/utils.py
@@ -203,7 +203,7 @@ def reshape_sum_backward(gy, x_shape, axis, keepdims):
 
 def logsumexp(x, axis=1):
     xp = cuda.get_array_module(x)
-    m = x.max(axis=axis, keepdims=True)
+    m = x.max(axis=axis, keepdims=True).astype(np.float32)
     y = x - m
     xp.exp(y, out=y)
     s = y.sum(axis=axis, keepdims=True)
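
Note on the utils.py hunk: the new .astype(np.float32) cast on the row maxima assumes utils.py already imports numpy as np. The surrounding function is the standard numerically stable log-sum-exp: for any m, log(sum(exp(x))) == m + log(sum(exp(x - m))), and choosing m = max(x, axis) keeps exp() from overflowing. Below is a minimal NumPy-only sketch of the same idea, a hypothetical standalone version that drops the cuda.get_array_module dispatch and the in-place exp used in dazero:

import numpy as np

def logsumexp(x, axis=1):
    # Stable log-sum-exp: subtracting the per-axis max before exponentiating
    # keeps exp() from overflowing without changing the result.
    m = x.max(axis=axis, keepdims=True)
    y = np.exp(x - m)
    s = y.sum(axis=axis, keepdims=True)
    return m + np.log(s)

x = np.array([[1000.0, 1000.5], [-2.0, 3.0]])
print(logsumexp(x))  # finite values; naive np.log(np.exp(x).sum(axis=1)) overflows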