Skip to content

Commit

Permalink
update utils.logsumexp
Browse files Browse the repository at this point in the history
  • Loading branch information
konas122 committed Mar 4, 2024
1 parent 947a1c3 commit 2caccd2
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion dazero/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,7 @@ def dropout(x, dropout_ratio=0.5):
return x


# ================== batch_norm / embed_id / layer_norm =====================
# ================== batch_norm / embedding / layer_norm =====================

class BatchNorm2d(Function):
def __init__(self, mean, var, decay, eps):
Expand Down
2 changes: 1 addition & 1 deletion dazero/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def reshape_sum_backward(gy, x_shape, axis, keepdims):

def logsumexp(x, axis=1):
xp = cuda.get_array_module(x)
m = x.max(axis=axis, keepdims=True)
m = x.max(axis=axis, keepdims=True).astype(np.float32)
y = x - m
xp.exp(y, out=y)
s = y.sum(axis=axis, keepdims=True)
Expand Down

0 comments on commit 2caccd2

Please sign in to comment.